From e7eac272414dd6458e111c5274e432df18f9f1b6 Mon Sep 17 00:00:00 2001 From: Mason Daugherty Date: Mon, 7 Jul 2025 17:48:01 -0400 Subject: [PATCH] ruff: more rules across the board & fixes (#31898) * standardizes ruff dep version across all `pyproject.toml` files * cli: ruff rules and corrections * langchain: rules and corrections --- Makefile | 1 - libs/cli/langchain_cli/cli.py | 16 +- libs/cli/langchain_cli/dev_scripts.py | 6 +- .../integration_template/Makefile | 2 +- libs/cli/langchain_cli/namespaces/app.py | 45 ++-- .../langchain_cli/namespaces/integration.py | 17 +- .../namespaces/migrate/generate/generic.py | 58 ++-- .../namespaces/migrate/generate/grit.py | 2 +- .../namespaces/migrate/generate/partner.py | 1 + .../namespaces/migrate/generate/utils.py | 17 +- .../langchain_cli/namespaces/migrate/main.py | 12 +- libs/cli/langchain_cli/namespaces/template.py | 26 +- libs/cli/langchain_cli/utils/events.py | 2 +- libs/cli/langchain_cli/utils/git.py | 11 +- libs/cli/langchain_cli/utils/github.py | 4 +- libs/cli/langchain_cli/utils/packages.py | 1 + libs/cli/langchain_cli/utils/pyproject.py | 8 +- libs/cli/pyproject.toml | 43 ++- libs/cli/scripts/generate_migrations.py | 12 +- .../unit_tests/migrate/cli_runner/file.py | 2 +- .../unit_tests/migrate/cli_runner/folder.py | 4 +- .../unit_tests/migrate/cli_runner/test_cli.py | 8 +- .../generate/test_langchain_migration.py | 21 +- .../generate/test_partner_migrations.py | 6 +- .../unit_tests/migrate/generate/test_utils.py | 4 +- libs/cli/tests/unit_tests/test_events.py | 7 +- libs/cli/tests/unit_tests/test_utils.py | 50 +++- libs/core/pyproject.toml | 4 +- libs/langchain/langchain/__init__.py | 40 ++- libs/langchain/langchain/agents/__init__.py | 4 +- libs/langchain/langchain/agents/agent.py | 109 +++++--- .../langchain/agents/agent_iterator.py | 43 ++- .../agent_toolkits/ainetwork/toolkit.py | 2 +- .../agents/agent_toolkits/amadeus/toolkit.py | 2 +- .../azure_cognitive_services.py | 2 +- .../agents/agent_toolkits/clickup/toolkit.py | 2 +- .../openai_functions.py | 6 +- .../file_management/__init__.py | 2 +- .../agent_toolkits/file_management/toolkit.py | 2 +- .../agents/agent_toolkits/gitlab/toolkit.py | 2 +- .../agents/agent_toolkits/json/base.py | 2 +- .../agents/agent_toolkits/multion/toolkit.py | 2 +- .../agent_toolkits/office365/toolkit.py | 2 +- .../agents/agent_toolkits/openapi/base.py | 2 +- .../agent_toolkits/playwright/__init__.py | 2 +- .../agent_toolkits/playwright/toolkit.py | 2 +- .../agents/agent_toolkits/powerbi/base.py | 2 +- .../agent_toolkits/powerbi/chat_base.py | 2 +- .../agents/agent_toolkits/powerbi/toolkit.py | 2 +- .../agents/agent_toolkits/spark_sql/base.py | 2 +- .../agent_toolkits/spark_sql/toolkit.py | 2 +- .../agents/agent_toolkits/sql/toolkit.py | 2 +- .../agent_toolkits/vectorstore/prompt.py | 6 +- .../agent_toolkits/vectorstore/toolkit.py | 9 +- .../agents/agent_toolkits/zapier/toolkit.py | 2 +- libs/langchain/langchain/agents/chat/base.py | 12 +- .../langchain/agents/chat/output_parser.py | 4 +- .../langchain/langchain/agents/chat/prompt.py | 7 +- .../langchain/agents/conversational/base.py | 14 +- .../agents/conversational/output_parser.py | 3 +- .../langchain/agents/conversational/prompt.py | 5 +- .../agents/conversational_chat/base.py | 12 +- .../agents/conversational_chat/prompt.py | 11 +- .../format_scratchpad/log_to_messages.py | 2 +- .../format_scratchpad/openai_functions.py | 6 +- .../agents/format_scratchpad/tools.py | 3 +- libs/langchain/langchain/agents/initialize.py | 10 +- 
.../langchain/agents/json_chat/base.py | 7 +- .../langchain/agents/json_chat/prompt.py | 5 +- libs/langchain/langchain/agents/load_tools.py | 3 +- libs/langchain/langchain/agents/loading.py | 13 +- libs/langchain/langchain/agents/mrkl/base.py | 7 +- .../langchain/agents/mrkl/output_parser.py | 10 +- .../langchain/langchain/agents/mrkl/prompt.py | 3 +- .../langchain/agents/openai_assistant/base.py | 49 ++-- .../agents/openai_functions_agent/base.py | 25 +- .../openai_functions_multi_agent/base.py | 19 +- .../langchain/agents/openai_tools/base.py | 8 +- .../agents/output_parsers/openai_functions.py | 8 +- .../agents/output_parsers/openai_tools.py | 7 +- .../output_parsers/react_json_single_input.py | 4 +- .../output_parsers/react_single_input.py | 7 +- .../langchain/agents/output_parsers/tools.py | 12 +- .../langchain/langchain/agents/react/agent.py | 2 +- .../agents/react/textworld_prompt.py | 1 - .../langchain/agents/react/wiki_prompt.py | 29 +- libs/langchain/langchain/agents/schema.py | 5 +- .../agents/self_ask_with_search/base.py | 10 +- .../agents/self_ask_with_search/prompt.py | 1 - .../langchain/agents/structured_chat/base.py | 13 +- .../agents/structured_chat/output_parser.py | 7 +- .../agents/structured_chat/prompt.py | 7 +- .../agents/tool_calling_agent/base.py | 4 +- libs/langchain/langchain/agents/xml/base.py | 4 +- libs/langchain/langchain/agents/xml/prompt.py | 3 +- .../langchain/callbacks/argilla_callback.py | 2 +- .../langchain/callbacks/arize_callback.py | 2 +- .../langchain/callbacks/arthur_callback.py | 2 +- .../langchain/callbacks/clearml_callback.py | 2 +- .../langchain/callbacks/comet_ml_callback.py | 2 +- .../langchain/callbacks/confident_callback.py | 2 +- .../langchain/callbacks/context_callback.py | 2 +- .../langchain/callbacks/flyte_callback.py | 2 +- .../langchain/callbacks/infino_callback.py | 2 +- .../langchain/callbacks/llmonitor_callback.py | 2 +- .../langchain/callbacks/openai_info.py | 2 +- .../callbacks/promptlayer_callback.py | 2 +- .../langchain/callbacks/sagemaker_callback.py | 2 +- .../langchain/callbacks/streaming_aiter.py | 5 +- .../callbacks/streaming_aiter_final_only.py | 5 +- .../callbacks/streaming_stdout_final_only.py | 5 +- .../langchain/callbacks/tracers/logging.py | 2 +- .../langchain/callbacks/trubrics_callback.py | 2 +- .../langchain/callbacks/wandb_callback.py | 2 +- .../langchain/callbacks/whylabs_callback.py | 2 +- libs/langchain/langchain/chains/api/base.py | 21 +- .../langchain/chains/api/news_docs.py | 3 +- .../langchain/chains/api/open_meteo_docs.py | 3 +- .../langchain/chains/api/podcast_docs.py | 3 +- libs/langchain/langchain/chains/api/prompt.py | 3 +- .../langchain/chains/api/tmdb_docs.py | 3 +- libs/langchain/langchain/chains/base.py | 36 ++- .../chains/chat_vector_db/prompts.py | 5 +- .../chains/combine_documents/base.py | 28 +- .../chains/combine_documents/map_reduce.py | 13 +- .../chains/combine_documents/map_rerank.py | 16 +- .../chains/combine_documents/reduce.py | 41 ++- .../chains/combine_documents/refine.py | 18 +- .../chains/combine_documents/stuff.py | 14 +- .../chains/constitutional_ai/base.py | 3 +- .../chains/constitutional_ai/principles.py | 47 ++-- .../chains/constitutional_ai/prompts.py | 53 ++-- .../langchain/chains/conversation/prompt.py | 10 +- .../chains/conversational_retrieval/base.py | 43 ++- .../conversational_retrieval/prompts.py | 5 +- .../chains/elasticsearch_database/base.py | 4 +- .../chains/elasticsearch_database/prompts.py | 5 +- .../langchain/chains/example_generator.py | 4 +- 
libs/langchain/langchain/chains/flare/base.py | 31 ++- libs/langchain/langchain/chains/hyde/base.py | 5 +- .../langchain/chains/hyde/prompts.py | 11 +- libs/langchain/langchain/chains/llm.py | 52 ++-- .../langchain/chains/llm_checker/base.py | 8 +- .../langchain/chains/llm_checker/prompt.py | 7 +- .../langchain/chains/llm_math/base.py | 8 +- .../langchain/chains/llm_math/prompt.py | 3 +- .../chains/llm_summarization_checker/base.py | 5 +- libs/langchain/langchain/chains/loading.py | 91 ++++--- libs/langchain/langchain/chains/mapreduce.py | 5 +- libs/langchain/langchain/chains/moderation.py | 4 +- .../langchain/langchain/chains/natbot/base.py | 7 +- .../langchain/chains/natbot/crawler.py | 134 ++++----- .../langchain/chains/natbot/prompt.py | 5 +- .../langchain/chains/openai_functions/base.py | 7 +- .../openai_functions/citation_fuzzy_match.py | 10 +- .../chains/openai_functions/extraction.py | 8 +- .../chains/openai_functions/openapi.py | 16 +- .../openai_functions/qa_with_structure.py | 12 +- .../chains/openai_tools/extraction.py | 2 +- .../langchain/chains/qa_generation/base.py | 5 +- .../langchain/chains/qa_generation/prompt.py | 10 +- .../langchain/chains/qa_with_sources/base.py | 18 +- .../chains/qa_with_sources/loading.py | 4 +- .../qa_with_sources/map_reduce_prompt.py | 17 +- .../chains/qa_with_sources/refine_prompts.py | 1 - .../chains/qa_with_sources/retrieval.py | 19 +- .../chains/qa_with_sources/stuff_prompt.py | 13 +- .../chains/qa_with_sources/vector_db.py | 19 +- .../chains/query_constructor/base.py | 31 ++- .../chains/query_constructor/parser.py | 2 +- .../chains/query_constructor/prompt.py | 11 +- .../chains/question_answering/chain.py | 11 +- .../question_answering/map_reduce_prompt.py | 28 +- .../question_answering/map_rerank_prompt.py | 6 +- .../question_answering/refine_prompts.py | 7 +- .../chains/question_answering/stuff_prompt.py | 10 +- .../langchain/chains/retrieval_qa/base.py | 36 ++- .../langchain/chains/retrieval_qa/prompt.py | 3 +- .../langchain/langchain/chains/router/base.py | 22 +- .../chains/router/embedding_router.py | 20 +- .../langchain/chains/router/llm_router.py | 5 +- .../langchain/chains/router/multi_prompt.py | 2 +- .../chains/router/multi_retrieval_qa.py | 9 +- libs/langchain/langchain/chains/sequential.py | 20 +- .../langchain/chains/sql_database/prompt.py | 30 +-- .../langchain/chains/sql_database/query.py | 2 +- .../chains/structured_output/base.py | 22 +- .../langchain/chains/summarize/chain.py | 4 +- .../chains/summarize/map_reduce_prompt.py | 1 - .../chains/summarize/stuff_prompt.py | 1 - libs/langchain/langchain/chains/transform.py | 2 +- .../chat_models/baidu_qianfan_endpoint.py | 2 +- libs/langchain/langchain/chat_models/base.py | 63 +++-- libs/langchain/langchain/chat_models/meta.py | 2 +- .../langchain/chat_models/pai_eas_endpoint.py | 2 +- .../chat_models/promptlayer_openai.py | 2 +- .../azure_blob_storage_container.py | 2 +- .../azure_blob_storage_file.py | 2 +- .../document_loaders/baiducloud_bos_file.py | 2 +- .../document_loaders/college_confidential.py | 2 +- .../langchain/document_loaders/markdown.py | 2 +- .../langchain/document_loaders/org_mode.py | 2 +- .../document_loaders/parsers/generic.py | 2 +- .../document_loaders/parsers/html/__init__.py | 2 +- .../document_loaders/parsers/html/bs4.py | 2 +- .../parsers/language/cobol.py | 2 +- .../parsers/language/python.py | 2 +- .../document_loaders/parsers/msword.py | 2 +- .../document_loaders/parsers/registry.py | 2 +- .../langchain/document_loaders/powerpoint.py | 2 +- 
.../document_loaders/pyspark_dataframe.py | 2 +- .../document_loaders/tencent_cos_directory.py | 2 +- .../beautiful_soup_transformer.py | 2 +- .../doctran_text_extract.py | 2 +- .../document_transformers/doctran_text_qa.py | 2 +- .../doctran_text_translate.py | 2 +- .../document_transformers/google_translate.py | 2 +- .../document_transformers/html2text.py | 2 +- .../nuclia_text_transform.py | 2 +- .../langchain/embeddings/__init__.py | 4 +- libs/langchain/langchain/embeddings/base.py | 4 +- libs/langchain/langchain/embeddings/cache.py | 17 +- .../agents/trajectory_eval_chain.py | 12 +- .../agents/trajectory_eval_prompt.py | 15 +- .../evaluation/comparison/eval_chain.py | 4 +- .../langchain/evaluation/comparison/prompt.py | 3 +- .../evaluation/criteria/eval_chain.py | 6 +- .../langchain/evaluation/criteria/prompt.py | 5 +- .../evaluation/embedding_distance/base.py | 13 +- .../langchain/langchain/evaluation/loading.py | 3 +- .../langchain/evaluation/qa/eval_prompt.py | 15 +- .../langchain/evaluation/qa/generate_chain.py | 3 +- .../evaluation/qa/generate_prompt.py | 6 +- libs/langchain/langchain/evaluation/schema.py | 10 +- .../evaluation/scoring/eval_chain.py | 4 +- .../langchain/evaluation/scoring/prompt.py | 1 - .../evaluation/string_distance/base.py | 7 +- .../langchain/indexes/_sql_record_manager.py | 24 +- .../indexes/prompts/entity_extraction.py | 3 +- .../indexes/prompts/entity_summarization.py | 3 +- .../prompts/knowledge_triplet_extraction.py | 4 +- .../langchain/indexes/vectorstore.py | 34 ++- .../langchain/llms/cloudflare_workersai.py | 2 +- libs/langchain/langchain/memory/buffer.py | 6 +- .../langchain/langchain/memory/chat_memory.py | 16 +- .../memory/chat_message_histories/astradb.py | 2 +- .../chat_message_histories/cassandra.py | 2 +- .../chat_message_histories/cosmos_db.py | 2 +- .../memory/chat_message_histories/dynamodb.py | 2 +- .../chat_message_histories/elasticsearch.py | 2 +- .../memory/chat_message_histories/file.py | 2 +- .../chat_message_histories/firestore.py | 2 +- .../memory/chat_message_histories/momento.py | 2 +- .../memory/chat_message_histories/mongodb.py | 2 +- .../memory/chat_message_histories/neo4j.py | 2 +- .../memory/chat_message_histories/postgres.py | 2 +- .../memory/chat_message_histories/redis.py | 2 +- .../chat_message_histories/rocksetdb.py | 2 +- .../chat_message_histories/singlestoredb.py | 2 +- .../chat_message_histories/streamlit.py | 2 +- .../chat_message_histories/upstash_redis.py | 2 +- .../memory/chat_message_histories/xata.py | 2 +- .../memory/chat_message_histories/zep.py | 2 +- libs/langchain/langchain/memory/combined.py | 5 +- libs/langchain/langchain/memory/entity.py | 10 +- libs/langchain/langchain/memory/prompt.py | 21 +- libs/langchain/langchain/memory/summary.py | 14 +- .../langchain/memory/summary_buffer.py | 22 +- .../langchain/langchain/memory/vectorstore.py | 17 +- .../memory/vectorstore_token_buffer_memory.py | 4 +- libs/langchain/langchain/model_laboratory.py | 4 +- .../langchain/output_parsers/__init__.py | 2 +- .../langchain/output_parsers/datetime.py | 4 +- .../langchain/langchain/output_parsers/fix.py | 8 +- .../output_parsers/format_instructions.py | 14 +- .../output_parsers/pandas_dataframe.py | 15 +- .../langchain/output_parsers/prompts.py | 3 +- .../langchain/output_parsers/rail_parser.py | 2 +- .../langchain/output_parsers/retry.py | 13 +- .../langchain/output_parsers/structured.py | 9 +- .../langchain/output_parsers/yaml.py | 3 +- libs/langchain/langchain/prompts/__init__.py | 4 +- 
.../prompts/example_selector/__init__.py | 4 +- .../retrievers/contextual_compression.py | 16 +- .../retrievers/document_compressors/base.py | 12 +- .../document_compressors/chain_extract.py | 4 +- .../chain_extract_prompt.py | 5 +- .../document_compressors/chain_filter.py | 6 +- .../chain_filter_prompt.py | 3 +- .../document_compressors/cohere_rerank.py | 21 +- .../document_compressors/embeddings_filter.py | 10 +- .../document_compressors/flashrank_rerank.py | 2 +- .../document_compressors/listwise_rerank.py | 5 +- .../langchain/retrievers/ensemble.py | 22 +- .../google_cloud_documentai_warehouse.py | 2 +- .../langchain/retrievers/merger_retriever.py | 10 +- .../langchain/retrievers/multi_query.py | 30 ++- .../langchain/retrievers/multi_vector.py | 25 +- .../retrievers/parent_document_retriever.py | 8 +- .../langchain/retrievers/re_phraser.py | 6 +- .../langchain/retrievers/self_query/base.py | 34 ++- .../retrievers/time_weighted_retriever.py | 23 +- .../langchain/smith/evaluation/config.py | 11 +- .../langchain/smith/evaluation/progress.py | 6 +- .../smith/evaluation/runner_utils.py | 79 ++++-- .../smith/evaluation/string_run_evaluator.py | 12 +- .../langchain/storage/encoder_backed.py | 8 +- .../tools/google_finance/__init__.py | 2 +- .../langchain/tools/google_finance/tool.py | 2 +- .../tools/google_scholar/__init__.py | 2 +- .../langchain/tools/google_scholar/tool.py | 2 +- .../langchain/tools/google_trends/__init__.py | 2 +- .../langchain/tools/google_trends/tool.py | 2 +- .../utilities/dalle_image_generator.py | 2 +- .../utilities/dataforseo_api_search.py | 2 +- .../langchain/utilities/reddit_search.py | 2 +- .../langchain/utilities/tavily_search.py | 2 +- libs/langchain/pyproject.toml | 62 ++++- .../cache/fake_embeddings.py | 2 +- .../chains/openai_functions/test_openapi.py | 4 +- .../chat_models/test_base.py | 6 +- .../embedding_distance/test_embedding.py | 7 +- .../test_listwise_rerank.py | 3 +- .../agents/format_scratchpad/test_log.py | 4 +- .../format_scratchpad/test_log_to_messages.py | 7 +- .../test_openai_functions.py | 21 +- .../format_scratchpad/test_openai_tools.py | 6 +- .../output_parsers/test_openai_functions.py | 4 +- .../test_react_json_single_input.py | 7 +- .../output_parsers/test_react_single_input.py | 7 +- .../agents/output_parsers/test_self_ask.py | 12 +- .../tests/unit_tests/agents/test_agent.py | 170 ++++++------ .../unit_tests/agents/test_agent_async.py | 20 +- .../unit_tests/agents/test_agent_iterator.py | 42 ++- .../tests/unit_tests/agents/test_mrkl.py | 2 +- .../agents/test_openai_functions_multi.py | 4 +- .../unit_tests/agents/test_structured_chat.py | 10 +- .../callbacks/tracers/test_logging.py | 3 +- .../tests/unit_tests/chains/test_base.py | 3 +- .../chains/test_combine_documents.py | 6 +- .../unit_tests/chains/test_conversation.py | 2 +- .../chains/test_conversation_retrieval.py | 24 +- .../chains/test_history_aware_retriever.py | 2 +- .../tests/unit_tests/chains/test_hyde.py | 8 +- .../unit_tests/chains/test_llm_checker.py | 2 +- .../chains/test_llm_summarization_checker.py | 2 - .../tests/unit_tests/chains/test_memory.py | 2 +- .../unit_tests/chains/test_qa_with_sources.py | 4 +- .../unit_tests/chains/test_sequential.py | 24 +- .../chains/test_summary_buffer_memory.py | 8 +- .../tests/unit_tests/chat_models/test_base.py | 28 +- libs/langchain/tests/unit_tests/conftest.py | 16 +- .../tests/unit_tests/embeddings/test_base.py | 9 +- .../unit_tests/embeddings/test_caching.py | 24 +- .../evaluation/agents/test_eval_chain.py | 14 +- 
.../evaluation/criteria/test_eval_chain.py | 13 +- .../evaluation/exact_match/test_base.py | 12 +- .../evaluation/parsing/test_base.py | 8 +- .../evaluation/parsing/test_json_distance.py | 18 +- .../evaluation/parsing/test_json_schema.py | 9 +- .../evaluation/qa/test_eval_chain.py | 13 +- .../evaluation/regex_match/test_base.py | 9 +- .../evaluation/string_distance/test_base.py | 22 +- .../tests/unit_tests/indexes/test_indexing.py | 254 ++++++++++++------ .../tests/unit_tests/llms/fake_chat_model.py | 19 +- .../unit_tests/llms/test_fake_chat_model.py | 19 +- .../tests/unit_tests/load/test_dump.py | 3 +- .../tests/unit_tests/load/test_load.py | 10 +- .../unit_tests/memory/test_combined_memory.py | 5 +- .../output_parsers/test_combining_parser.py | 10 +- .../unit_tests/output_parsers/test_fix.py | 7 +- .../unit_tests/output_parsers/test_json.py | 32 ++- .../test_pandas_dataframe_parser.py | 2 +- .../output_parsers/test_regex_dict.py | 6 +- .../unit_tests/output_parsers/test_retry.py | 14 +- .../output_parsers/test_yaml_parser.py | 7 +- .../test_chain_extract.py | 16 +- .../document_compressors/test_chain_filter.py | 6 +- .../retrievers/self_query/test_base.py | 21 +- .../unit_tests/retrievers/test_ensemble.py | 12 +- .../retrievers/test_multi_vector.py | 18 +- .../retrievers/test_parent_document.py | 8 +- .../test_time_weighted_retriever.py | 16 +- .../tests/unit_tests/runnables/test_hub.py | 8 +- .../runnables/test_openai_functions.py | 13 +- .../smith/evaluation/test_runner_utils.py | 8 +- .../unit_tests/storage/test_filesystem.py | 7 +- libs/langchain/tests/unit_tests/stubs.py | 4 +- .../tests/unit_tests/test_dependencies.py | 6 +- .../langchain/tests/unit_tests/test_schema.py | 2 +- .../tests/unit_tests/utils/test_iter.py | 4 +- .../unit_tests/utils/test_openai_functions.py | 2 +- libs/partners/anthropic/pyproject.toml | 2 +- libs/partners/chroma/pyproject.toml | 2 +- libs/partners/deepseek/pyproject.toml | 2 +- libs/partners/exa/pyproject.toml | 2 +- libs/partners/fireworks/pyproject.toml | 2 +- libs/partners/groq/pyproject.toml | 2 +- libs/partners/huggingface/pyproject.toml | 2 +- libs/partners/mistralai/pyproject.toml | 2 +- libs/partners/nomic/pyproject.toml | 2 +- libs/partners/ollama/pyproject.toml | 2 +- libs/partners/openai/pyproject.toml | 2 +- libs/partners/perplexity/pyproject.toml | 2 +- libs/partners/prompty/pyproject.toml | 2 +- libs/partners/qdrant/pyproject.toml | 2 +- libs/partners/xai/pyproject.toml | 2 +- pyproject.toml | 4 +- 408 files changed, 2783 insertions(+), 1671 deletions(-) diff --git a/Makefile b/Makefile index f5e750c6371..62b97f600a7 100644 --- a/Makefile +++ b/Makefile @@ -71,7 +71,6 @@ spell_fix: lint lint_package lint_tests: uv run --group lint ruff check docs cookbook uv run --group lint ruff format docs cookbook cookbook --diff - uv run --group lint ruff check docs cookbook git --no-pager grep 'from langchain import' docs cookbook | grep -vE 'from langchain import (hub)' && echo "Error: no importing langchain from root in docs, except for hub" && exit 1 || exit 0 git --no-pager grep 'api.python.langchain.com' -- docs/docs ':!docs/docs/additional_resources/arxiv_references.mdx' ':!docs/docs/integrations/document_loaders/sitemap.ipynb' || exit 0 && \ diff --git a/libs/cli/langchain_cli/cli.py b/libs/cli/langchain_cli/cli.py index b60035d18a7..d7a500d593b 100644 --- a/libs/cli/langchain_cli/cli.py +++ b/libs/cli/langchain_cli/cli.py @@ -11,7 +11,9 @@ from langchain_cli.utils.packages import get_langserve_export, get_package_root app = 
typer.Typer(no_args_is_help=True, add_completion=False) app.add_typer( - template_namespace.package_cli, name="template", help=template_namespace.__doc__ + template_namespace.package_cli, + name="template", + help=template_namespace.__doc__, ) app.add_typer(app_namespace.app_cli, name="app", help=app_namespace.__doc__) app.add_typer( @@ -32,7 +34,7 @@ app.command( ) -def version_callback(show_version: bool) -> None: +def version_callback(show_version: bool) -> None: # noqa: FBT001 if show_version: typer.echo(f"langchain-cli {__version__}") raise typer.Exit @@ -40,8 +42,8 @@ def version_callback(show_version: bool) -> None: @app.callback() def main( - version: bool = typer.Option( - False, + version: bool = typer.Option( # noqa: FBT001 + False, # noqa: FBT003 "--version", "-v", help="Print the current CLI version.", @@ -56,10 +58,12 @@ def main( def serve( *, port: Annotated[ - Optional[int], typer.Option(help="The port to run the server on") + Optional[int], + typer.Option(help="The port to run the server on"), ] = None, host: Annotated[ - Optional[str], typer.Option(help="The host to run the server on") + Optional[str], + typer.Option(help="The host to run the server on"), ] = None, ) -> None: """Start the LangServe app, whether it's a template or an app.""" diff --git a/libs/cli/langchain_cli/dev_scripts.py b/libs/cli/langchain_cli/dev_scripts.py index 31c4259339c..605ea949f92 100644 --- a/libs/cli/langchain_cli/dev_scripts.py +++ b/libs/cli/langchain_cli/dev_scripts.py @@ -1,7 +1,7 @@ -# type: ignore """Development Scripts for template packages.""" from collections.abc import Sequence +from typing import Literal from fastapi import FastAPI from langserve import add_routes @@ -12,9 +12,9 @@ from langchain_cli.utils.packages import get_langserve_export, get_package_root def create_demo_server( *, config_keys: Sequence[str] = (), - playground_type: str = "default", + playground_type: Literal["default", "chat"] = "default", ): - """Creates a demo server for the current template.""" + """Create a demo server for the current template.""" app = FastAPI() package_root = get_package_root() pyproject = package_root / "pyproject.toml" diff --git a/libs/cli/langchain_cli/integration_template/Makefile b/libs/cli/langchain_cli/integration_template/Makefile index 4c4bbf237d9..f660290fd1b 100644 --- a/libs/cli/langchain_cli/integration_template/Makefile +++ b/libs/cli/langchain_cli/integration_template/Makefile @@ -39,7 +39,7 @@ lint lint_diff lint_package lint_tests: format format_diff: [ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) - [ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I --fix $(PYTHON_FILES) + [ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --fix $(PYTHON_FILES) spell_check: poetry run codespell --toml pyproject.toml diff --git a/libs/cli/langchain_cli/namespaces/app.py b/libs/cli/langchain_cli/namespaces/app.py index a259cf48aa2..a2eefe07ac5 100644 --- a/libs/cli/langchain_cli/namespaces/app.py +++ b/libs/cli/langchain_cli/namespaces/app.py @@ -79,7 +79,9 @@ def new( package_prompt = "What package would you like to add? 
(leave blank to skip)" while True: package_str = typer.prompt( - package_prompt, default="", show_default=False + package_prompt, + default="", + show_default=False, ) if not package_str: break @@ -121,26 +123,29 @@ def new( typer.echo("Then add templates with commands like:\n") typer.echo(" langchain app add extraction-openai-functions") typer.echo( - " langchain app add git+ssh://git@github.com/efriis/simple-pirate.git\n\n" + " langchain app add git+ssh://git@github.com/efriis/simple-pirate.git\n\n", ) @app_cli.command() def add( dependencies: Annotated[ - Optional[list[str]], typer.Argument(help="The dependency to add") + Optional[list[str]], + typer.Argument(help="The dependency to add"), ] = None, *, api_path: Annotated[list[str], typer.Option(help="API paths to add")] = [], project_dir: Annotated[ - Optional[Path], typer.Option(help="The project directory") + Optional[Path], + typer.Option(help="The project directory"), ] = None, repo: Annotated[ list[str], typer.Option(help="Install templates from a specific github repo instead"), ] = [], branch: Annotated[ - list[str], typer.Option(help="Install templates from a specific branch") + list[str], + typer.Option(help="Install templates from a specific branch"), ] = [], pip: Annotated[ bool, @@ -152,13 +157,12 @@ def add( ), ], ) -> None: - """Adds the specified template to the current LangServe app. + """Add the specified template to the current LangServe app. e.g.: langchain app add extraction-openai-functions langchain app add git+ssh://git@github.com/efriis/simple-pirate.git """ - if not branch and not repo: warnings.warn( "Adding templates from the default branch and repo is deprecated." @@ -173,7 +177,7 @@ def add( package_dir = project_root / "packages" create_events( - [{"event": "serve add", "properties": {"parsed_dep": d}} for d in parsed_deps] + [{"event": "serve add", "properties": {"parsed_dep": d}} for d in parsed_deps], ) # group by repo/ref @@ -248,7 +252,7 @@ def add( cmd = ["pip", "install", "-e", *installed_destination_strs] cmd_str = " \\\n ".join(installed_destination_strs) typer.echo(f"Running: pip install -e \\\n {cmd_str}") - subprocess.run(cmd, cwd=cwd) + subprocess.run(cmd, cwd=cwd) # noqa: S603 chain_names = [] for e in installed_exports: @@ -296,10 +300,11 @@ def remove( api_paths: Annotated[list[str], typer.Argument(help="The API paths to remove")], *, project_dir: Annotated[ - Optional[Path], typer.Option(help="The project directory") + Optional[Path], + typer.Option(help="The project directory"), ] = None, ) -> None: - """Removes the specified package from the current LangServe app.""" + """Remove the specified package from the current LangServe app.""" project_root = get_package_root(project_dir) project_pyproject = project_root / "pyproject.toml" @@ -320,7 +325,7 @@ def remove( shutil.rmtree(package_dir) remove_deps.append(api_path) - except Exception: + except Exception: # noqa: S110 pass try: @@ -334,16 +339,19 @@ def remove( def serve( *, port: Annotated[ - Optional[int], typer.Option(help="The port to run the server on") + Optional[int], + typer.Option(help="The port to run the server on"), ] = None, host: Annotated[ - Optional[str], typer.Option(help="The host to run the server on") + Optional[str], + typer.Option(help="The host to run the server on"), ] = None, app: Annotated[ - Optional[str], typer.Option(help="The app to run, e.g. `app.server:app`") + Optional[str], + typer.Option(help="The app to run, e.g. 
`app.server:app`"), ] = None, ) -> None: - """Starts the LangServe app.""" + """Start the LangServe app.""" # add current dir as first entry of path sys.path.append(str(Path.cwd())) @@ -353,5 +361,8 @@ def serve( import uvicorn uvicorn.run( - app_str, host=host_str, port=port if port is not None else 8000, reload=True + app_str, + host=host_str, + port=port if port is not None else 8000, + reload=True, ) diff --git a/libs/cli/langchain_cli/namespaces/integration.py b/libs/cli/langchain_cli/namespaces/integration.py index e92ed283848..9f4cfef574d 100644 --- a/libs/cli/langchain_cli/namespaces/integration.py +++ b/libs/cli/langchain_cli/namespaces/integration.py @@ -66,7 +66,7 @@ def new( Optional[str], typer.Option( help="The name of the integration in PascalCase. e.g. `MyIntegration`." - " This is used to name classes like `MyIntegrationVectorStore`" + " This is used to name classes like `MyIntegrationVectorStore`", ), ] = None, src: Annotated[ @@ -85,7 +85,7 @@ def new( ), ] = None, ) -> None: - """Creates a new integration package.""" + """Create a new integration package.""" try: replacements = _process_name(name) except ValueError as e: @@ -96,13 +96,14 @@ def new( if not re.match(r"^[A-Z][a-zA-Z0-9]*$", name_class): typer.echo( "Name should only contain letters (a-z, A-Z), numbers, and underscores" - ", and start with a capital letter." + ", and start with a capital letter.", ) raise typer.Exit(code=1) replacements["__ModuleName__"] = name_class else: replacements["__ModuleName__"] = typer.prompt( - "Name of integration in PascalCase", default=replacements["__ModuleName__"] + "Name of integration in PascalCase", + default=replacements["__ModuleName__"], ) project_template_dir = Path(__file__).parents[1] / "integration_template" @@ -124,7 +125,7 @@ def new( # poetry install subprocess.run( - ["poetry", "install", "--with", "lint,test,typing,test_integration"], + ["poetry", "install", "--with", "lint,test,typing,test_integration"], # noqa: S607 cwd=destination_dir, ) else: @@ -152,7 +153,7 @@ def new( if len(dst_paths) != len(set(dst_paths)): typer.echo( "Duplicate destination paths provided or computed - please " - "specify them explicitly with --dst." + "specify them explicitly with --dst.", ) raise typer.Exit(code=1) @@ -224,10 +225,10 @@ def create_doc( ), ] = "docs/docs/integrations/chat/", ) -> None: - """Creates a new integration doc.""" + """Create a new integration doc.""" if component_type not in TEMPLATE_MAP: typer.echo( - f"Unrecognized {component_type=}. Expected one of {_component_types_str}." + f"Unrecognized {component_type=}. Expected one of {_component_types_str}.", ) raise typer.Exit(code=1) diff --git a/libs/cli/langchain_cli/namespaces/migrate/generate/generic.py b/libs/cli/langchain_cli/namespaces/migrate/generate/generic.py index e2fa0fa4879..24b866c7a70 100644 --- a/libs/cli/langchain_cli/namespaces/migrate/generate/generic.py +++ b/libs/cli/langchain_cli/namespaces/migrate/generate/generic.py @@ -6,14 +6,17 @@ import pkgutil def generate_raw_migrations( - from_package: str, to_package: str, filter_by_all: bool = False + from_package: str, + to_package: str, + filter_by_all: bool = False, # noqa: FBT001, FBT002 ) -> list[tuple[str, str]]: """Scan the `langchain` package and generate migrations for all modules.""" package = importlib.import_module(from_package) items = [] for _importer, modname, _ispkg in pkgutil.walk_packages( - package.__path__, package.__name__ + "." 
+ package.__path__, + package.__name__ + ".", ): try: module = importlib.import_module(modname) except ModuleNotFoundError: continue @@ -34,33 +37,35 @@ def generate_raw_migrations( obj = getattr(module, name, None) except ImportError: continue - if obj and (inspect.isclass(obj) or inspect.isfunction(obj)): - if obj.__module__.startswith(to_package): - items.append( - (f"{modname}.{name}", f"{obj.__module__}.{obj.__name__}") - ) + if ( + obj + and (inspect.isclass(obj) or inspect.isfunction(obj)) + and obj.__module__.startswith(to_package) + ): + items.append( + (f"{modname}.{name}", f"{obj.__module__}.{obj.__name__}"), ) if not filter_by_all: # Iterate over all members of the module for name, obj in inspect.getmembers(module): # Check if it's a class or function - if inspect.isclass(obj) or inspect.isfunction(obj): - # Check if the module name of the obj starts with - # 'langchain_community' - if obj.__module__.startswith(to_package): - items.append( - (f"{modname}.{name}", f"{obj.__module__}.{obj.__name__}") - ) + # Check if the module name of the obj starts with + # 'langchain_community' + if ( + inspect.isclass(obj) or inspect.isfunction(obj) + ) and obj.__module__.startswith(to_package): + items.append( + (f"{modname}.{name}", f"{obj.__module__}.{obj.__name__}"), ) return items def generate_top_level_imports(pkg: str) -> list[tuple[str, str]]: - """This code will look at all the top level modules in langchain_community. + """Look at all the top level modules in langchain_community. - It'll attempt to import everything from each __init__ file - - for example, + Attempt to import everything from each ``__init__`` file. For example, langchain_community/ chat_models/ @@ -74,10 +79,10 @@ def generate_top_level_imports(pkg: str) -> list[tuple[str, str]]: Each tuple will contain the fully qualified path of the class / function to where its logic is defined - (e.g., langchain_community.chat_models.xyz_implementation.ver2.XYZ) + (e.g., ``langchain_community.chat_models.xyz_implementation.ver2.XYZ``) and the second tuple will contain the path to importing it from the top level namespaces - (e.g., langchain_community.chat_models.XYZ) + (e.g., ``langchain_community.chat_models.XYZ``) """ package = importlib.import_module(pkg) @@ -98,7 +103,7 @@ def generate_top_level_imports(pkg: str) -> list[tuple[str, str]]: top_level_import = f"{module_name}.{name}" # Append the tuple with original and top-level paths items.append( - (f"{original_module}.{original_name}", top_level_import) + (f"{original_module}.{original_name}", top_level_import), ) # Handle the package itself (root level) # Only iterate through top-level modules/packages for _finder, modname, ispkg in pkgutil.iter_modules( - package.__path__, package.__name__ + "." 
+ package.__path__, + package.__name__ + ".", ): if ispkg: try: @@ -119,11 +125,15 @@ def generate_top_level_imports(pkg: str) -> list[tuple[str, str]]: def generate_simplified_migrations( - from_package: str, to_package: str, filter_by_all: bool = True + from_package: str, + to_package: str, + filter_by_all: bool = True, # noqa: FBT001, FBT002 ) -> list[tuple[str, str]]: """Get all the raw migrations, then simplify them if possible.""" raw_migrations = generate_raw_migrations( - from_package, to_package, filter_by_all=filter_by_all + from_package, + to_package, + filter_by_all=filter_by_all, ) top_level_simplifications = generate_top_level_imports(to_package) top_level_dict = dict(top_level_simplifications) diff --git a/libs/cli/langchain_cli/namespaces/migrate/generate/grit.py b/libs/cli/langchain_cli/namespaces/migrate/generate/grit.py index 3da2c9933b0..501b9f74426 100644 --- a/libs/cli/langchain_cli/namespaces/migrate/generate/grit.py +++ b/libs/cli/langchain_cli/namespaces/migrate/generate/grit.py @@ -17,7 +17,7 @@ def dump_migrations_as_grit(name: str, migration_pairs: list[tuple[str, str]]) - ] """ for from_module, to_module in migration_pairs - ] + ], ) pattern_name = f"langchain_migrate_{name}" return f""" diff --git a/libs/cli/langchain_cli/namespaces/migrate/generate/partner.py b/libs/cli/langchain_cli/namespaces/migrate/generate/partner.py index ef8a34409db..912548ddde1 100644 --- a/libs/cli/langchain_cli/namespaces/migrate/generate/partner.py +++ b/libs/cli/langchain_cli/namespaces/migrate/generate/partner.py @@ -28,6 +28,7 @@ def get_migrations_for_partner_package(pkg_name: str) -> list[tuple[str, str]]: Returns: List of 2-tuples containing old and new import paths. + """ package = importlib.import_module(pkg_name) classes_ = find_subclasses_in_module( diff --git a/libs/cli/langchain_cli/namespaces/migrate/generate/utils.py b/libs/cli/langchain_cli/namespaces/migrate/generate/utils.py index 2fbd65f8b8f..3ea55f015c3 100644 --- a/libs/cli/langchain_cli/namespaces/migrate/generate/utils.py +++ b/libs/cli/langchain_cli/namespaces/migrate/generate/utils.py @@ -20,7 +20,7 @@ class ImportExtractor(ast.NodeVisitor): self.imports: list = [] self.package = from_package - def visit_ImportFrom(self, node) -> None: + def visit_ImportFrom(self, node) -> None: # noqa: N802 if node.module and ( self.package is None or str(node.module).startswith(self.package) ): @@ -39,7 +39,7 @@ def _get_class_names(code: str) -> list[str]: # Define a node visitor class to collect class names class ClassVisitor(ast.NodeVisitor): - def visit_ClassDef(self, node) -> None: + def visit_ClassDef(self, node) -> None: # noqa: N802 class_names.append(node.name) self.generic_visit(node) @@ -79,7 +79,9 @@ def _get_all_classnames_from_file(file: Path, pkg: str) -> list[tuple[str, str]] def identify_all_imports_in_file( - file: str, *, from_package: Optional[str] = None + file: str, + *, + from_package: Optional[str] = None, ) -> list[tuple[str, str]]: """Let's also identify all the imports in the given file.""" with open(file, encoding="utf-8") as f: @@ -96,10 +98,13 @@ def identify_pkg_source(pkg_root: str) -> pathlib.Path: Returns: Returns the path to the source code for the package. + """ dirs = [d for d in Path(pkg_root).iterdir() if d.is_dir()] matching_dirs = [d for d in dirs if d.name.startswith("langchain_")] - assert len(matching_dirs) == 1, "There should be only one langchain package." + if len(matching_dirs) != 1: + msg = "There should be only one langchain package." 
+ raise ValueError(msg) return matching_dirs[0] @@ -134,7 +139,9 @@ def list_init_imports_by_package(pkg_root: str) -> list[tuple[str, str]]: def find_imports_from_package( - code: str, *, from_package: Optional[str] = None + code: str, + *, + from_package: Optional[str] = None, ) -> list[tuple[str, str]]: # Parse the code into an AST tree = ast.parse(code) diff --git a/libs/cli/langchain_cli/namespaces/migrate/main.py b/libs/cli/langchain_cli/namespaces/migrate/main.py index f875d344f08..54160fb4c12 100644 --- a/libs/cli/langchain_cli/namespaces/migrate/main.py +++ b/libs/cli/langchain_cli/namespaces/migrate/main.py @@ -4,7 +4,7 @@ from pathlib import Path import rich import typer -from gritql import run # type: ignore +from gritql import run # type: ignore[import] from typer import Option @@ -17,13 +17,13 @@ def get_gritdir_path() -> Path: def migrate( ctx: typer.Context, # Using diff instead of dry-run for backwards compatibility with the old CLI - diff: bool = Option( - False, + diff: bool = Option( # noqa: FBT001 + False, # noqa: FBT003 "--diff", help="Show the changes that would be made without applying them.", ), - interactive: bool = Option( - False, + interactive: bool = Option( # noqa: FBT001 + False, # noqa: FBT003 "--interactive", help="Prompt for confirmation before making each change", ), @@ -54,7 +54,7 @@ def migrate( '⚠️ This script is a "best-effort", and is likely to make some ' "mistakes.\n\n" "🛡️ Backup your code prior to running the migration script -- it will " - "modify your files!\n\n" + "modify your files!\n\n", ) rich.print("-" * 10) rich.print() diff --git a/libs/cli/langchain_cli/namespaces/template.py b/libs/cli/langchain_cli/namespaces/template.py index 1f348e35178..7ad5c14f494 100644 --- a/libs/cli/langchain_cli/namespaces/template.py +++ b/libs/cli/langchain_cli/namespaces/template.py @@ -16,12 +16,12 @@ package_cli = typer.Typer(no_args_is_help=True, add_completion=False) @package_cli.command() def new( name: Annotated[str, typer.Argument(help="The name of the folder to create")], - with_poetry: Annotated[ + with_poetry: Annotated[ # noqa: FBT002 bool, typer.Option("--with-poetry/--no-poetry", help="Don't run poetry install"), ] = False, ) -> None: - """Creates a new template package.""" + """Create a new template package.""" computed_name = name if name != "." else Path.cwd().name destination_dir = Path.cwd() / name if name != "." 
else Path.cwd() @@ -53,8 +53,9 @@ def new( pyproject_contents = pyproject.read_text() pyproject.write_text( pyproject_contents.replace("__package_name__", package_name).replace( - "__module_name__", module_name - ) + "__module_name__", + module_name, + ), ) # move module folder @@ -71,23 +72,26 @@ def new( readme_contents = readme.read_text() readme.write_text( readme_contents.replace("__package_name__", package_name).replace( - "__app_route_code__", app_route_code - ) + "__app_route_code__", + app_route_code, + ), ) # poetry install if with_poetry: - subprocess.run(["poetry", "install"], cwd=destination_dir) + subprocess.run(["poetry", "install"], cwd=destination_dir) # noqa: S607 @package_cli.command() def serve( *, port: Annotated[ - Optional[int], typer.Option(help="The port to run the server on") + Optional[int], + typer.Option(help="The port to run the server on"), ] = None, host: Annotated[ - Optional[str], typer.Option(help="The host to run the server on") + Optional[str], + typer.Option(help="The host to run the server on"), ] = None, configurable: Annotated[ Optional[bool], @@ -104,7 +108,7 @@ def serve( ), ] = False, ) -> None: - """Starts a demo app for this template.""" + """Start a demo app for this template.""" # load pyproject.toml project_dir = get_package_root() pyproject = project_dir / "pyproject.toml" @@ -136,7 +140,7 @@ def serve( @package_cli.command() -def list(contains: Annotated[Optional[str], typer.Argument()] = None) -> None: +def list(contains: Annotated[Optional[str], typer.Argument()] = None) -> None: # noqa: A001 """List all or search for available templates.""" from langchain_cli.utils.github import list_packages diff --git a/libs/cli/langchain_cli/utils/events.py b/libs/cli/langchain_cli/utils/events.py index 3518634f60e..695bc3ca447 100644 --- a/libs/cli/langchain_cli/utils/events.py +++ b/libs/cli/langchain_cli/utils/events.py @@ -20,7 +20,7 @@ def create_events(events: list[EventDict]) -> Optional[Any]: "properties": event.get("properties"), } for event in events - ] + ], } conn = http.client.HTTPSConnection("app.firstpartyhq.com") diff --git a/libs/cli/langchain_cli/utils/git.py b/libs/cli/langchain_cli/utils/git.py index cef0ba639bc..f45a2889270 100644 --- a/libs/cli/langchain_cli/utils/git.py +++ b/libs/cli/langchain_cli/utils/git.py @@ -122,7 +122,9 @@ def parse_dependencies( api_path: list[str], ) -> list[DependencySource]: num_deps = max( - len(dependencies) if dependencies is not None else 0, len(repo), len(branch) + len(dependencies) if dependencies is not None else 0, + len(repo), + len(branch), ) if ( (dependencies and len(dependencies) != num_deps) @@ -143,7 +145,10 @@ def parse_dependencies( return [ parse_dependency_string(iter_dep, iter_repo, iter_branch, iter_api_path) for iter_dep, iter_repo, iter_branch, iter_api_path in zip( - inner_deps, inner_repos, inner_branches, inner_api_paths + inner_deps, + inner_repos, + inner_branches, + inner_api_paths, ) ] @@ -186,7 +191,7 @@ def copy_repo( source: Path, destination: Path, ) -> None: - """Copies a repo, ignoring git folders. + """Copy a repo, ignoring git folders. 
Raises FileNotFound error if it can't find source """ diff --git a/libs/cli/langchain_cli/utils/github.py b/libs/cli/langchain_cli/utils/github.py index fe2c6b3d638..0c3ef3118d2 100644 --- a/libs/cli/langchain_cli/utils/github.py +++ b/libs/cli/langchain_cli/utils/github.py @@ -13,7 +13,9 @@ def list_packages(*, contains: Optional[str] = None): } conn.request( - "GET", "/repos/langchain-ai/langchain/contents/templates", headers=headers + "GET", + "/repos/langchain-ai/langchain/contents/templates", + headers=headers, ) res = conn.getresponse() diff --git a/libs/cli/langchain_cli/utils/packages.py b/libs/cli/langchain_cli/utils/packages.py index c8d00722579..f45caec1ff2 100644 --- a/libs/cli/langchain_cli/utils/packages.py +++ b/libs/cli/langchain_cli/utils/packages.py @@ -26,6 +26,7 @@ class LangServeExport(TypedDict): module: The module to import from, tool.langserve.export_module attr: The attribute to import from the module, tool.langserve.export_attr package_name: The name of the package, tool.poetry.name + """ module: str diff --git a/libs/cli/langchain_cli/utils/pyproject.py b/libs/cli/langchain_cli/utils/pyproject.py index 61a86c45a93..237676d636f 100644 --- a/libs/cli/langchain_cli/utils/pyproject.py +++ b/libs/cli/langchain_cli/utils/pyproject.py @@ -14,7 +14,8 @@ def _get_dep_inline_table(path: Path) -> InlineTable: def add_dependencies_to_pyproject_toml( - pyproject_toml: Path, local_editable_dependencies: Iterable[tuple[str, Path]] + pyproject_toml: Path, + local_editable_dependencies: Iterable[tuple[str, Path]], ) -> None: """Add dependencies to pyproject.toml.""" with open(pyproject_toml, encoding="utf-8") as f: @@ -24,14 +25,15 @@ def add_dependencies_to_pyproject_toml( { name: _get_dep_inline_table(loc.relative_to(pyproject_toml.parent)) for name, loc in local_editable_dependencies - } + }, ) with open(pyproject_toml, "w", encoding="utf-8") as f: dump(pyproject, f) def remove_dependencies_from_pyproject_toml( - pyproject_toml: Path, local_editable_dependencies: Iterable[str] + pyproject_toml: Path, + local_editable_dependencies: Iterable[str], ) -> None: """Remove dependencies from pyproject.toml.""" with open(pyproject_toml, encoding="utf-8") as f: diff --git a/libs/cli/pyproject.toml b/libs/cli/pyproject.toml index 1e9d78e3174..a989d304c03 100644 --- a/libs/cli/pyproject.toml +++ b/libs/cli/pyproject.toml @@ -49,12 +49,45 @@ exclude = [ [tool.ruff.lint] select = [ - "E", # pycodestyle - "F", # pyflakes - "I", # isort - "T201", # print - "UP", # pyupgrade + "A", # flake8-builtins + "ARG", # flake8-unused-arguments + "ASYNC", # flake8-async + "C4", # flake8-comprehensions + "COM", # flake8-commas + "D", # pydocstyle + "DOC", # pydoclint + "E", # pycodestyle error + "EM", # flake8-errmsg + "F", # pyflakes + "FA", # flake8-future-annotations + "FBT", # flake8-boolean-trap + "FLY", # flake8-flynt + "I", # isort + "ICN", # flake8-import-conventions + "INT", # flake8-gettext + "ISC", # flake8-implicit-str-concat + "N", # pep8-naming + "PT", # flake8-pytest-style + "PGH", # pygrep-hooks + "PIE", # flake8-pie + "PERF", # perflint + "PYI", # flake8-pyi + "Q", # flake8-quotes + "RET", # flake8-return + "RSE", # flake8-raise + "RUF", # ruff + "S", # flake8-bandit + "SLF", # flake8-self + "SLOT", # flake8-slots + "SIM", # flake8-simplify + "T10", # flake8-debugger + "T20", # flake8-print + "TID", # flake8-tidy-imports + "UP", # pyupgrade + "W", # pycodestyle warning + "YTT", # flake8-2020 ] +ignore = ["D100", "D101", "D102", "D103", "D104", "D105", "D107"] 
pyupgrade.keep-runtime-typing = true [tool.mypy] diff --git a/libs/cli/scripts/generate_migrations.py b/libs/cli/scripts/generate_migrations.py index 1dcd5650a2b..f5f915dafbc 100644 --- a/libs/cli/scripts/generate_migrations.py +++ b/libs/cli/scripts/generate_migrations.py @@ -1,4 +1,3 @@ -# type: ignore """Script to generate migrations for the migration script.""" import json @@ -45,12 +44,17 @@ def cli() -> None: ) @click.option( "--format", + "format_", type=click.Choice(["json", "grit"], case_sensitive=False), default="json", help="The output format for the migration script (json or grit).", ) def generic( - pkg1: str, pkg2: str, output: str, filter_by_all: bool, format: str + pkg1: str, + pkg2: str, + output: str, + filter_by_all: bool, # noqa: FBT001 + format_: str, ) -> None: """Generate a migration script.""" click.echo("Migration script generated.") @@ -62,9 +66,9 @@ def generic( name = f"{pkg1}_to_{pkg2}" if output is None: - output = f"{name}.json" if format == "json" else f"{name}.grit" + output = f"{name}.json" if format_ == "json" else f"{name}.grit" - if format == "json": + if format_ == "json": dumped = json.dumps(migrations, indent=2, sort_keys=True) else: dumped = dump_migrations_as_grit(name, migrations) diff --git a/libs/cli/tests/unit_tests/migrate/cli_runner/file.py b/libs/cli/tests/unit_tests/migrate/cli_runner/file.py index 5cf303095ed..1e0a5b4b497 100644 --- a/libs/cli/tests/unit_tests/migrate/cli_runner/file.py +++ b/libs/cli/tests/unit_tests/migrate/cli_runner/file.py @@ -6,7 +6,7 @@ class File: self.name = name self.content = "\n".join(content or []) - def __eq__(self, __value: object) -> bool: + def __eq__(self, __value: object, /) -> bool: if not isinstance(__value, File): return NotImplemented diff --git a/libs/cli/tests/unit_tests/migrate/cli_runner/folder.py b/libs/cli/tests/unit_tests/migrate/cli_runner/folder.py index 40d5d696e1c..d40d494d35c 100644 --- a/libs/cli/tests/unit_tests/migrate/cli_runner/folder.py +++ b/libs/cli/tests/unit_tests/migrate/cli_runner/folder.py @@ -34,12 +34,12 @@ class Folder: files.append(cls.from_structure(path)) else: files.append( - File(path.name, path.read_text(encoding="utf-8").splitlines()) + File(path.name, path.read_text(encoding="utf-8").splitlines()), ) return Folder(name, *files) - def __eq__(self, __value: object) -> bool: + def __eq__(self, __value: object, /) -> bool: if isinstance(__value, File): return False diff --git a/libs/cli/tests/unit_tests/migrate/cli_runner/test_cli.py b/libs/cli/tests/unit_tests/migrate/cli_runner/test_cli.py index e435547b16b..d53dc4b392c 100644 --- a/libs/cli/tests/unit_tests/migrate/cli_runner/test_cli.py +++ b/libs/cli/tests/unit_tests/migrate/cli_runner/test_cli.py @@ -34,7 +34,7 @@ def find_issue(current: Folder, expected: Folder) -> str: expected_file.content.splitlines(), fromfile=current_file.name, tofile=expected_file.name, - ) + ), ) return "Unknown" @@ -47,8 +47,10 @@ def test_command_line(tmp_path: Path) -> None: before.create_structure(root=Path(td)) # The input is used to force through the confirmation. 
result = runner.invoke(app, ["migrate", before.name, "--force"]) - assert result.exit_code == 0, result.output + if result.exit_code != 0: + raise RuntimeError(result.output) after = Folder.from_structure(Path(td) / before.name) - assert after == expected, find_issue(after, expected) + if after != expected: + raise ValueError(find_issue(after, expected)) diff --git a/libs/cli/tests/unit_tests/migrate/generate/test_langchain_migration.py b/libs/cli/tests/unit_tests/migrate/generate/test_langchain_migration.py index 1fd8a88161a..c566df21799 100644 --- a/libs/cli/tests/unit_tests/migrate/generate/test_langchain_migration.py +++ b/libs/cli/tests/unit_tests/migrate/generate/test_langchain_migration.py @@ -12,14 +12,15 @@ def test_create_json_agent_migration() -> None: """Test the migration of create_json_agent from langchain to langchain_community.""" with sup1(), sup2(): raw_migrations = generate_simplified_migrations( - from_package="langchain", to_package="langchain_community" + from_package="langchain", + to_package="langchain_community", ) json_agent_migrations = [ migration for migration in raw_migrations if "create_json_agent" in migration[0] ] - assert json_agent_migrations == [ + if json_agent_migrations != [ ( "langchain.agents.create_json_agent", "langchain_community.agent_toolkits.create_json_agent", @@ -32,7 +33,9 @@ def test_create_json_agent_migration() -> None: "langchain.agents.agent_toolkits.json.base.create_json_agent", "langchain_community.agent_toolkits.create_json_agent", ), - ] + ]: + msg = "json_agent_migrations did not match the expected value" + raise ValueError(msg) @pytest.mark.xfail(reason="Unknown reason") @@ -40,15 +43,21 @@ def test_create_single_store_retriever_db() -> None: """Test migration from langchain to langchain_core.""" with sup1(), sup2(): raw_migrations = generate_simplified_migrations( - from_package="langchain", to_package="langchain_core" + from_package="langchain", + to_package="langchain_core", ) # SingleStore was an old name for VectorStoreRetriever single_store_migration = [ migration for migration in raw_migrations if "SingleStore" in migration[0] ] - assert single_store_migration == [ + if single_store_migration != [ ( "langchain.vectorstores.singlestoredb.SingleStoreDBRetriever", "langchain_core.vectorstores.VectorStoreRetriever", ), - ] + ]: + msg = ( + "Unexpected migration: single_store_migration does not match expected " + "value" + ) + raise ValueError(msg) diff --git a/libs/cli/tests/unit_tests/migrate/generate/test_partner_migrations.py b/libs/cli/tests/unit_tests/migrate/generate/test_partner_migrations.py index a05386fe4a9..a636ce48e62 100644 --- a/libs/cli/tests/unit_tests/migrate/generate/test_partner_migrations.py +++ b/libs/cli/tests/unit_tests/migrate/generate/test_partner_migrations.py @@ -9,7 +9,7 @@ pytest.importorskip(modname="langchain_openai") def test_generate_migrations() -> None: migrations = get_migrations_for_partner_package("langchain_openai") - assert migrations == [ + if migrations != [ ("langchain_community.llms.openai.OpenAI", "langchain_openai.OpenAI"), ("langchain_community.llms.openai.AzureOpenAI", "langchain_openai.AzureOpenAI"), ( @@ -43,4 +43,6 @@ def test_generate_migrations() -> None: "langchain_openai.AzureChatOpenAI", ), ("langchain_community.chat_models.ChatOpenAI", "langchain_openai.ChatOpenAI"), - ] + ]: + msg = "Migrations do not match expected result" + raise ValueError(msg) diff --git a/libs/cli/tests/unit_tests/migrate/generate/test_utils.py 
b/libs/cli/tests/unit_tests/migrate/generate/test_utils.py index 38974ff79a9..cc73e3bf2de 100644 --- a/libs/cli/tests/unit_tests/migrate/generate/test_utils.py +++ b/libs/cli/tests/unit_tests/migrate/generate/test_utils.py @@ -2,4 +2,6 @@ from langchain_cli.namespaces.migrate.generate.utils import PKGS_ROOT def test_root() -> None: - assert PKGS_ROOT.name == "libs" + if PKGS_ROOT.name != "libs": + msg = "Expected PKGS_ROOT.name to be 'libs'." + raise ValueError(msg) diff --git a/libs/cli/tests/unit_tests/test_events.py b/libs/cli/tests/unit_tests/test_events.py index 7b5b3c6c1aa..52a793e1e5f 100644 --- a/libs/cli/tests/unit_tests/test_events.py +++ b/libs/cli/tests/unit_tests/test_events.py @@ -5,6 +5,7 @@ from langchain_cli.utils.events import EventDict, create_events @pytest.mark.xfail(reason="Unknown reason") def test_create_events() -> None: - assert create_events( - [EventDict(event="Test Event", properties={"test": "test"})] - ) == {"status": "success"} + result = create_events([EventDict(event="Test Event", properties={"test": "test"})]) + if result != {"status": "success"}: + msg = "Expected {'status': 'success'}, got " + repr(result) + raise ValueError(msg) diff --git a/libs/cli/tests/unit_tests/test_utils.py b/libs/cli/tests/unit_tests/test_utils.py index 424be057b5c..f85ab6bd9c0 100644 --- a/libs/cli/tests/unit_tests/test_utils.py +++ b/libs/cli/tests/unit_tests/test_utils.py @@ -18,17 +18,37 @@ def _assert_dependency_equals( subdirectory: Optional[str] = None, event_metadata: Optional[dict] = None, ) -> None: - assert dep["git"] == git - assert dep["ref"] == ref - assert dep["subdirectory"] == subdirectory - if event_metadata is not None: - assert dep["event_metadata"] == event_metadata + if dep["git"] != git: + msg = f"Expected git to be {git} but got {dep['git']}" + raise ValueError(msg) + if dep["ref"] != ref: + msg = f"Expected ref to be {ref} but got {dep['ref']}" + raise ValueError(msg) + if dep["subdirectory"] != subdirectory: + msg = ( + f"Expected subdirectory to be {subdirectory} but got {dep['subdirectory']}" + ) + raise ValueError(msg) + if dep["subdirectory"] != subdirectory: + msg = ( + f"Expected subdirectory to be {subdirectory} but got {dep['subdirectory']}" + ) + raise ValueError(msg) + if event_metadata is not None and dep["event_metadata"] != event_metadata: + msg = ( + f"Expected event_metadata to be {event_metadata} " + f"but got {dep['event_metadata']}" + ) + raise ValueError(msg) def test_dependency_string() -> None: _assert_dependency_equals( parse_dependency_string( - "git+ssh://git@github.com/efriis/myrepo.git", None, None, None + "git+ssh://git@github.com/efriis/myrepo.git", + None, + None, + None, ), git="ssh://git@github.com/efriis/myrepo.git", ref=None, @@ -49,7 +69,10 @@ def test_dependency_string() -> None: _assert_dependency_equals( parse_dependency_string( - "git+ssh://git@github.com:efriis/myrepo.git#develop", None, None, None + "git+ssh://git@github.com:efriis/myrepo.git#develop", + None, + None, + None, ), git="ssh://git@github.com:efriis/myrepo.git", ref="develop", @@ -59,7 +82,10 @@ def test_dependency_string() -> None: # also support a slash in ssh _assert_dependency_equals( parse_dependency_string( - "git+ssh://git@github.com/efriis/myrepo.git#develop", None, None, None + "git+ssh://git@github.com/efriis/myrepo.git#develop", + None, + None, + None, ), git="ssh://git@github.com/efriis/myrepo.git", ref="develop", @@ -69,7 +95,10 @@ def test_dependency_string() -> None: # looks like poetry supports both an @ and a # 
_assert_dependency_equals( parse_dependency_string( - "git+ssh://git@github.com:efriis/myrepo.git@develop", None, None, None + "git+ssh://git@github.com:efriis/myrepo.git@develop", + None, + None, + None, ), git="ssh://git@github.com:efriis/myrepo.git", ref="develop", @@ -100,7 +129,8 @@ def test_dependency_string_both() -> None: def test_dependency_string_invalids() -> None: # expect error for wrong order - with pytest.raises(ValueError): + # Bypassing validation since the ValueError message is dynamic + with pytest.raises(ValueError): # noqa: PT011 parse_dependency_string( "git+https://github.com/efriis/myrepo.git#subdirectory=src@branch", None, diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index 06587170d4c..c30453595f3 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -26,9 +26,7 @@ readme = "README.md" repository = "https://github.com/langchain-ai/langchain" [dependency-groups] -lint = [ - "ruff<0.13.0,>=0.12.2", -] +lint = ["ruff<0.13,>=0.12.2"] typing = [ "mypy<1.16,>=1.15", "types-pyyaml<7.0.0.0,>=6.0.12.2", diff --git a/libs/langchain/langchain/__init__.py b/libs/langchain/langchain/__init__.py index 63de57ef4c2..ccae1437f68 100644 --- a/libs/langchain/langchain/__init__.py +++ b/libs/langchain/langchain/__init__.py @@ -28,11 +28,11 @@ def _warn_on_import(name: str, replacement: Optional[str] = None) -> None: if replacement: warnings.warn( f"Importing {name} from langchain root module is no longer supported. " - f"Please use {replacement} instead." + f"Please use {replacement} instead.", ) else: warnings.warn( - f"Importing {name} from langchain root module is no longer supported." + f"Importing {name} from langchain root module is no longer supported.", ) @@ -170,7 +170,8 @@ def __getattr__(name: str) -> Any: from langchain_community.llms import HuggingFaceTextGenInference _warn_on_import( - name, replacement="langchain_community.llms.HuggingFaceTextGenInference" + name, + replacement="langchain_community.llms.HuggingFaceTextGenInference", ) return HuggingFaceTextGenInference @@ -235,7 +236,8 @@ def __getattr__(name: str) -> Any: from langchain_core.prompts import FewShotPromptTemplate _warn_on_import( - name, replacement="langchain_core.prompts.FewShotPromptTemplate" + name, + replacement="langchain_core.prompts.FewShotPromptTemplate", ) return FewShotPromptTemplate @@ -263,7 +265,8 @@ def __getattr__(name: str) -> Any: from langchain_community.utilities import ArxivAPIWrapper _warn_on_import( - name, replacement="langchain_community.utilities.ArxivAPIWrapper" + name, + replacement="langchain_community.utilities.ArxivAPIWrapper", ) return ArxivAPIWrapper @@ -271,7 +274,8 @@ def __getattr__(name: str) -> Any: from langchain_community.utilities import GoldenQueryAPIWrapper _warn_on_import( - name, replacement="langchain_community.utilities.GoldenQueryAPIWrapper" + name, + replacement="langchain_community.utilities.GoldenQueryAPIWrapper", ) return GoldenQueryAPIWrapper @@ -279,7 +283,8 @@ def __getattr__(name: str) -> Any: from langchain_community.utilities import GoogleSearchAPIWrapper _warn_on_import( - name, replacement="langchain_community.utilities.GoogleSearchAPIWrapper" + name, + replacement="langchain_community.utilities.GoogleSearchAPIWrapper", ) return GoogleSearchAPIWrapper @@ -287,7 +292,8 @@ def __getattr__(name: str) -> Any: from langchain_community.utilities import GoogleSerperAPIWrapper _warn_on_import( - name, replacement="langchain_community.utilities.GoogleSerperAPIWrapper" + name, + 
replacement="langchain_community.utilities.GoogleSerperAPIWrapper", ) return GoogleSerperAPIWrapper @@ -295,7 +301,8 @@ def __getattr__(name: str) -> Any: from langchain_community.utilities import PowerBIDataset _warn_on_import( - name, replacement="langchain_community.utilities.PowerBIDataset" + name, + replacement="langchain_community.utilities.PowerBIDataset", ) return PowerBIDataset @@ -303,7 +310,8 @@ def __getattr__(name: str) -> Any: from langchain_community.utilities import SearxSearchWrapper _warn_on_import( - name, replacement="langchain_community.utilities.SearxSearchWrapper" + name, + replacement="langchain_community.utilities.SearxSearchWrapper", ) return SearxSearchWrapper @@ -311,7 +319,8 @@ def __getattr__(name: str) -> Any: from langchain_community.utilities import WikipediaAPIWrapper _warn_on_import( - name, replacement="langchain_community.utilities.WikipediaAPIWrapper" + name, + replacement="langchain_community.utilities.WikipediaAPIWrapper", ) return WikipediaAPIWrapper @@ -319,7 +328,8 @@ def __getattr__(name: str) -> Any: from langchain_community.utilities import WolframAlphaAPIWrapper _warn_on_import( - name, replacement="langchain_community.utilities.WolframAlphaAPIWrapper" + name, + replacement="langchain_community.utilities.WolframAlphaAPIWrapper", ) return WolframAlphaAPIWrapper @@ -339,7 +349,8 @@ def __getattr__(name: str) -> Any: from langchain_community.vectorstores import ElasticVectorSearch _warn_on_import( - name, replacement="langchain_community.vectorstores.ElasticVectorSearch" + name, + replacement="langchain_community.vectorstores.ElasticVectorSearch", ) return ElasticVectorSearch @@ -348,7 +359,8 @@ def __getattr__(name: str) -> Any: from langchain_community.utilities import SerpAPIWrapper _warn_on_import( - name, replacement="langchain_community.utilities.SerpAPIWrapper" + name, + replacement="langchain_community.utilities.SerpAPIWrapper", ) return SerpAPIWrapper diff --git a/libs/langchain/langchain/agents/__init__.py b/libs/langchain/langchain/agents/__init__.py index 73533226520..f919f5fbdf4 100644 --- a/libs/langchain/langchain/agents/__init__.py +++ b/libs/langchain/langchain/agents/__init__.py @@ -122,7 +122,9 @@ def __getattr__(name: str) -> Any: # Get directory of langchain package HERE = Path(__file__).parents[1] relative_path = as_import_path( - Path(__file__).parent, suffix=name, relative_to=HERE + Path(__file__).parent, + suffix=name, + relative_to=HERE, ) old_path = "langchain." + relative_path new_path = "langchain_experimental." + relative_path diff --git a/libs/langchain/langchain/agents/agent.py b/libs/langchain/langchain/agents/agent.py index f860d40361c..d51de7d89d2 100644 --- a/libs/langchain/langchain/agents/agent.py +++ b/libs/langchain/langchain/agents/agent.py @@ -135,7 +135,8 @@ class BaseSingleActionAgent(BaseModel): if early_stopping_method == "force": # `force` just returns a constant string return AgentFinish( - {"output": "Agent stopped due to iteration limit or time limit."}, "" + {"output": "Agent stopped due to iteration limit or time limit."}, + "", ) msg = f"Got unsupported early_stopping_method `{early_stopping_method}`" raise ValueError(msg) @@ -375,7 +376,7 @@ class AgentOutputParser(BaseOutputParser[Union[AgentAction, AgentFinish]]): class MultiActionAgentOutputParser( - BaseOutputParser[Union[list[AgentAction], AgentFinish]] + BaseOutputParser[Union[list[AgentAction], AgentFinish]], ): """Base class for parsing agent output into agent actions/finish. 
@@ -491,7 +492,8 @@ class RunnableAgent(BaseSingleActionAgent): # Because the response from the plan is not a generator, we need to # accumulate the output into final output and return that. async for chunk in self.runnable.astream( - inputs, config={"callbacks": callbacks} + inputs, + config={"callbacks": callbacks}, ): if final_output is None: final_output = chunk @@ -499,7 +501,8 @@ class RunnableAgent(BaseSingleActionAgent): final_output += chunk else: final_output = await self.runnable.ainvoke( - inputs, config={"callbacks": callbacks} + inputs, + config={"callbacks": callbacks}, ) return final_output @@ -607,7 +610,8 @@ class RunnableMultiActionAgent(BaseMultiActionAgent): # Because the response from the plan is not a generator, we need to # accumulate the output into final output and return that. async for chunk in self.runnable.astream( - inputs, config={"callbacks": callbacks} + inputs, + config={"callbacks": callbacks}, ): if final_output is None: final_output = chunk @@ -615,7 +619,8 @@ class RunnableMultiActionAgent(BaseMultiActionAgent): final_output += chunk else: final_output = await self.runnable.ainvoke( - inputs, config={"callbacks": callbacks} + inputs, + config={"callbacks": callbacks}, ) return final_output @@ -764,7 +769,8 @@ class Agent(BaseSingleActionAgent): ] def _construct_scratchpad( - self, intermediate_steps: list[tuple[AgentAction, str]] + self, + intermediate_steps: list[tuple[AgentAction, str]], ) -> Union[str, list[BaseMessage]]: """Construct the scratchpad that lets the agent continue its thought process.""" thoughts = "" @@ -816,7 +822,9 @@ class Agent(BaseSingleActionAgent): return await self.output_parser.aparse(full_output) def get_full_inputs( - self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs: Any + self, + intermediate_steps: list[tuple[AgentAction, str]], + **kwargs: Any, ) -> builtins.dict[str, Any]: """Create the full inputs for the LLMChain from intermediate steps. @@ -858,7 +866,7 @@ class Agent(BaseSingleActionAgent): if "agent_scratchpad" not in prompt.input_variables: logger.warning( "`agent_scratchpad` should be a variable in prompt.input_variables." - " Did not find it, so adding it at the end." + " Did not find it, so adding it at the end.", ) prompt.input_variables.append("agent_scratchpad") if isinstance(prompt, PromptTemplate): @@ -964,7 +972,8 @@ class Agent(BaseSingleActionAgent): if early_stopping_method == "force": # `force` just returns a constant string return AgentFinish( - {"output": "Agent stopped due to iteration limit or time limit."}, "" + {"output": "Agent stopped due to iteration limit or time limit."}, + "", ) if early_stopping_method == "generate": # Generate does one final forward pass @@ -1072,7 +1081,8 @@ class AgentExecutor(Chain): as an observation. """ trim_intermediate_steps: Union[ - int, Callable[[list[tuple[AgentAction, str]]], list[tuple[AgentAction, str]]] + int, + Callable[[list[tuple[AgentAction, str]]], list[tuple[AgentAction, str]]], ] = -1 """How to trim the intermediate steps before returning them. Defaults to -1, which means no trimming. 
@@ -1153,11 +1163,13 @@ class AgentExecutor(Chain): stream_runnable = values.pop("stream_runnable", True) if multi_action: values["agent"] = RunnableMultiActionAgent( - runnable=agent, stream_runnable=stream_runnable + runnable=agent, + stream_runnable=stream_runnable, ) else: values["agent"] = RunnableAgent( - runnable=agent, stream_runnable=stream_runnable + runnable=agent, + stream_runnable=stream_runnable, ) return values @@ -1281,7 +1293,9 @@ class AgentExecutor(Chain): ) -> dict[str, Any]: if run_manager: await run_manager.on_agent_finish( - output, color="green", verbose=self.verbose + output, + color="green", + verbose=self.verbose, ) final_output = output.return_values if self.return_intermediate_steps: @@ -1289,7 +1303,8 @@ class AgentExecutor(Chain): return final_output def _consume_next_step( - self, values: NextStepOutput + self, + values: NextStepOutput, ) -> Union[AgentFinish, list[tuple[AgentAction, str]]]: if isinstance(values[-1], AgentFinish): if len(values) != 1: @@ -1314,8 +1329,8 @@ class AgentExecutor(Chain): inputs, intermediate_steps, run_manager, - ) - ) + ), + ), ) def _iter_next_step( @@ -1391,7 +1406,10 @@ class AgentExecutor(Chain): yield agent_action for agent_action in actions: yield self._perform_agent_action( - name_to_tool_map, color_mapping, agent_action, run_manager + name_to_tool_map, + color_mapping, + agent_action, + run_manager, ) def _perform_agent_action( @@ -1451,7 +1469,7 @@ class AgentExecutor(Chain): intermediate_steps, run_manager, ) - ] + ], ) async def _aiter_next_step( @@ -1528,7 +1546,10 @@ class AgentExecutor(Chain): result = await asyncio.gather( *[ self._aperform_agent_action( - name_to_tool_map, color_mapping, agent_action, run_manager + name_to_tool_map, + color_mapping, + agent_action, + run_manager, ) for agent_action in actions ], @@ -1547,7 +1568,9 @@ class AgentExecutor(Chain): ) -> AgentStep: if run_manager: await run_manager.on_agent_action( - agent_action, verbose=self.verbose, color="green" + agent_action, + verbose=self.verbose, + color="green", ) # Otherwise we lookup the tool if agent_action.tool in name_to_tool_map: @@ -1589,7 +1612,8 @@ class AgentExecutor(Chain): name_to_tool_map = {tool.name: tool for tool in self.tools} # We construct a mapping from each tool to a color, used for logging. 
color_mapping = get_color_mapping( - [tool.name for tool in self.tools], excluded_colors=["green", "red"] + [tool.name for tool in self.tools], + excluded_colors=["green", "red"], ) intermediate_steps: list[tuple[AgentAction, str]] = [] # Let's start tracking the number of iterations and time elapsed @@ -1607,7 +1631,9 @@ class AgentExecutor(Chain): ) if isinstance(next_step_output, AgentFinish): return self._return( - next_step_output, intermediate_steps, run_manager=run_manager + next_step_output, + intermediate_steps, + run_manager=run_manager, ) intermediate_steps.extend(next_step_output) @@ -1617,12 +1643,16 @@ class AgentExecutor(Chain): tool_return = self._get_tool_return(next_step_action) if tool_return is not None: return self._return( - tool_return, intermediate_steps, run_manager=run_manager + tool_return, + intermediate_steps, + run_manager=run_manager, ) iterations += 1 time_elapsed = time.time() - start_time output = self._action_agent.return_stopped_response( - self.early_stopping_method, intermediate_steps, **inputs + self.early_stopping_method, + intermediate_steps, + **inputs, ) return self._return(output, intermediate_steps, run_manager=run_manager) @@ -1636,7 +1666,8 @@ class AgentExecutor(Chain): name_to_tool_map = {tool.name: tool for tool in self.tools} # We construct a mapping from each tool to a color, used for logging. color_mapping = get_color_mapping( - [tool.name for tool in self.tools], excluded_colors=["green"] + [tool.name for tool in self.tools], + excluded_colors=["green"], ) intermediate_steps: list[tuple[AgentAction, str]] = [] # Let's start tracking the number of iterations and time elapsed @@ -1668,28 +1699,39 @@ class AgentExecutor(Chain): tool_return = self._get_tool_return(next_step_action) if tool_return is not None: return await self._areturn( - tool_return, intermediate_steps, run_manager=run_manager + tool_return, + intermediate_steps, + run_manager=run_manager, ) iterations += 1 time_elapsed = time.time() - start_time output = self._action_agent.return_stopped_response( - self.early_stopping_method, intermediate_steps, **inputs + self.early_stopping_method, + intermediate_steps, + **inputs, ) return await self._areturn( - output, intermediate_steps, run_manager=run_manager + output, + intermediate_steps, + run_manager=run_manager, ) except (TimeoutError, asyncio.TimeoutError): # stop early when interrupted by the async timeout output = self._action_agent.return_stopped_response( - self.early_stopping_method, intermediate_steps, **inputs + self.early_stopping_method, + intermediate_steps, + **inputs, ) return await self._areturn( - output, intermediate_steps, run_manager=run_manager + output, + intermediate_steps, + run_manager=run_manager, ) def _get_tool_return( - self, next_step_output: tuple[AgentAction, str] + self, + next_step_output: tuple[AgentAction, str], ) -> Optional[AgentFinish]: """Check if the tool is a returning tool.""" agent_action, observation = next_step_output @@ -1709,7 +1751,8 @@ class AgentExecutor(Chain): return None def _prepare_intermediate_steps( - self, intermediate_steps: list[tuple[AgentAction, str]] + self, + intermediate_steps: list[tuple[AgentAction, str]], ) -> list[tuple[AgentAction, str]]: if ( isinstance(self.trim_intermediate_steps, int) diff --git a/libs/langchain/langchain/agents/agent_iterator.py b/libs/langchain/langchain/agents/agent_iterator.py index 2e07b29805c..f249d7d1153 100644 --- a/libs/langchain/langchain/agents/agent_iterator.py +++ b/libs/langchain/langchain/agents/agent_iterator.py @@ -147,7 
+147,7 @@ class AgentExecutorIterator: self.iterations += 1 self.time_elapsed = time.time() - self.start_time logger.debug( - f"Agent Iterations: {self.iterations} ({self.time_elapsed:.2f}s elapsed)" + f"Agent Iterations: {self.iterations} ({self.time_elapsed:.2f}s elapsed)", ) def make_final_outputs( @@ -160,8 +160,10 @@ class AgentExecutorIterator: prepared_outputs = AddableDict( self.agent_executor.prep_outputs( - self.inputs, outputs, return_only_outputs=True - ) + self.inputs, + outputs, + return_only_outputs=True, + ), ) if self.include_run_info: prepared_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id) @@ -187,7 +189,8 @@ class AgentExecutorIterator: ) try: while self.agent_executor._should_continue( - self.iterations, self.time_elapsed + self.iterations, + self.time_elapsed, ): # take the next step: this plans next action, executes it, # yielding action and observation as they are generated @@ -254,7 +257,8 @@ class AgentExecutorIterator: try: async with asyncio_timeout(self.agent_executor.max_execution_time): while self.agent_executor._should_continue( - self.iterations, self.time_elapsed + self.iterations, + self.time_elapsed, ): # take the next step: this plans next action, executes it, # yielding action and observation as they are generated @@ -272,11 +276,13 @@ class AgentExecutorIterator: if self.yield_actions: if isinstance(chunk, AgentAction): yield AddableDict( - actions=[chunk], messages=chunk.messages + actions=[chunk], + messages=chunk.messages, ) elif isinstance(chunk, AgentStep): yield AddableDict( - steps=[chunk], messages=chunk.messages + steps=[chunk], + messages=chunk.messages, ) # convert iterator output to format handled by _process_next_step @@ -285,7 +291,8 @@ class AgentExecutorIterator: self.update_iterations() # decide if this is the final output output = await self._aprocess_next_step_output( - next_step, run_manager + next_step, + run_manager, ) is_final = "intermediate_step" not in output # yield the final output always @@ -317,7 +324,7 @@ class AgentExecutorIterator: logger.debug("Processing output of Agent loop step") if isinstance(next_step_output, AgentFinish): logger.debug( - "Hit AgentFinish: _return -> on_chain_end -> run final output logic" + "Hit AgentFinish: _return -> on_chain_end -> run final output logic", ) return self._return(next_step_output, run_manager=run_manager) @@ -345,7 +352,7 @@ class AgentExecutorIterator: logger.debug("Processing output of async Agent loop step") if isinstance(next_step_output, AgentFinish): logger.debug( - "Hit AgentFinish: _areturn -> on_chain_end -> run final output logic" + "Hit AgentFinish: _areturn -> on_chain_end -> run final output logic", ) return await self._areturn(next_step_output, run_manager=run_manager) @@ -388,26 +395,34 @@ class AgentExecutorIterator: return await self._areturn(output, run_manager=run_manager) def _return( - self, output: AgentFinish, run_manager: CallbackManagerForChainRun + self, + output: AgentFinish, + run_manager: CallbackManagerForChainRun, ) -> AddableDict: """ Return the final output of the iterator. 
""" returned_output = self.agent_executor._return( - output, self.intermediate_steps, run_manager=run_manager + output, + self.intermediate_steps, + run_manager=run_manager, ) returned_output["messages"] = output.messages run_manager.on_chain_end(returned_output) return self.make_final_outputs(returned_output, run_manager) async def _areturn( - self, output: AgentFinish, run_manager: AsyncCallbackManagerForChainRun + self, + output: AgentFinish, + run_manager: AsyncCallbackManagerForChainRun, ) -> AddableDict: """ Return the final output of the async iterator. """ returned_output = await self.agent_executor._areturn( - output, self.intermediate_steps, run_manager=run_manager + output, + self.intermediate_steps, + run_manager=run_manager, ) returned_output["messages"] = output.messages await run_manager.on_chain_end(returned_output) diff --git a/libs/langchain/langchain/agents/agent_toolkits/ainetwork/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/ainetwork/toolkit.py index 2a797a0bfc8..777eab90e3e 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/ainetwork/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/ainetwork/toolkit.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "AINetworkToolkit": "langchain_community.agent_toolkits.ainetwork.toolkit" + "AINetworkToolkit": "langchain_community.agent_toolkits.ainetwork.toolkit", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/agents/agent_toolkits/amadeus/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/amadeus/toolkit.py index e7e0b93d9ed..418784bc23f 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/amadeus/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/amadeus/toolkit.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "AmadeusToolkit": "langchain_community.agent_toolkits.amadeus.toolkit" + "AmadeusToolkit": "langchain_community.agent_toolkits.amadeus.toolkit", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/agents/agent_toolkits/azure_cognitive_services.py b/libs/langchain/langchain/agents/agent_toolkits/azure_cognitive_services.py index 4f49ae86e16..d4bc8df0dc1 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/azure_cognitive_services.py +++ b/libs/langchain/langchain/agents/agent_toolkits/azure_cognitive_services.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: DEPRECATED_LOOKUP = { "AzureCognitiveServicesToolkit": ( "langchain_community.agent_toolkits.azure_cognitive_services" - ) + ), } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/agents/agent_toolkits/clickup/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/clickup/toolkit.py index e58a8678cdc..59268cc7a1e 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/clickup/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/clickup/toolkit.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. 
DEPRECATED_LOOKUP = { - "ClickupToolkit": "langchain_community.agent_toolkits.clickup.toolkit" + "ClickupToolkit": "langchain_community.agent_toolkits.clickup.toolkit", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/agents/agent_toolkits/conversational_retrieval/openai_functions.py b/libs/langchain/langchain/agents/agent_toolkits/conversational_retrieval/openai_functions.py index 2eeefc50864..cd970ba17ae 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/conversational_retrieval/openai_functions.py +++ b/libs/langchain/langchain/agents/agent_toolkits/conversational_retrieval/openai_functions.py @@ -20,7 +20,7 @@ def _get_default_system_message() -> SystemMessage: "Do your best to answer the questions. " "Feel free to use any tools available to look up " "relevant information, only if necessary" - ) + ), ) @@ -58,7 +58,9 @@ def create_conversational_retrieval_agent( if remember_intermediate_steps: memory: BaseMemory = AgentTokenBufferMemory( - memory_key=memory_key, llm=llm, max_token_limit=max_token_limit + memory_key=memory_key, + llm=llm, + max_token_limit=max_token_limit, ) else: memory = ConversationTokenBufferMemory( diff --git a/libs/langchain/langchain/agents/agent_toolkits/file_management/__init__.py b/libs/langchain/langchain/agents/agent_toolkits/file_management/__init__.py index 4210b88897e..c90890f967b 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/file_management/__init__.py +++ b/libs/langchain/langchain/agents/agent_toolkits/file_management/__init__.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: DEPRECATED_LOOKUP = { "FileManagementToolkit": ( "langchain_community.agent_toolkits.file_management.toolkit" - ) + ), } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/agents/agent_toolkits/file_management/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/file_management/toolkit.py index ac6830e14a7..791a44050fa 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/file_management/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/file_management/toolkit.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: DEPRECATED_LOOKUP = { "FileManagementToolkit": ( "langchain_community.agent_toolkits.file_management.toolkit" - ) + ), } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/agents/agent_toolkits/gitlab/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/gitlab/toolkit.py index c603fc456d2..55a1746523b 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/gitlab/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/gitlab/toolkit.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. 
DEPRECATED_LOOKUP = { - "GitLabToolkit": "langchain_community.agent_toolkits.gitlab.toolkit" + "GitLabToolkit": "langchain_community.agent_toolkits.gitlab.toolkit", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/agents/agent_toolkits/json/base.py b/libs/langchain/langchain/agents/agent_toolkits/json/base.py index f60163265c5..faf3135c763 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/json/base.py +++ b/libs/langchain/langchain/agents/agent_toolkits/json/base.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "create_json_agent": "langchain_community.agent_toolkits.json.base" + "create_json_agent": "langchain_community.agent_toolkits.json.base", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/agents/agent_toolkits/multion/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/multion/toolkit.py index eee04a9e6f2..3b2160fa04d 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/multion/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/multion/toolkit.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "MultionToolkit": "langchain_community.agent_toolkits.multion.toolkit" + "MultionToolkit": "langchain_community.agent_toolkits.multion.toolkit", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/agents/agent_toolkits/office365/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/office365/toolkit.py index 466d979d0eb..b7f1480f98a 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/office365/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/office365/toolkit.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "O365Toolkit": "langchain_community.agent_toolkits.office365.toolkit" + "O365Toolkit": "langchain_community.agent_toolkits.office365.toolkit", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/agents/agent_toolkits/openapi/base.py b/libs/langchain/langchain/agents/agent_toolkits/openapi/base.py index acb6eaba639..65eb71141e3 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/openapi/base.py +++ b/libs/langchain/langchain/agents/agent_toolkits/openapi/base.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. 
DEPRECATED_LOOKUP = { - "create_openapi_agent": "langchain_community.agent_toolkits.openapi.base" + "create_openapi_agent": "langchain_community.agent_toolkits.openapi.base", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/agents/agent_toolkits/playwright/__init__.py b/libs/langchain/langchain/agents/agent_toolkits/playwright/__init__.py index 169766c11c6..7a9762fa853 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/playwright/__init__.py +++ b/libs/langchain/langchain/agents/agent_toolkits/playwright/__init__.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "PlayWrightBrowserToolkit": "langchain_community.agent_toolkits.playwright.toolkit" + "PlayWrightBrowserToolkit": "langchain_community.agent_toolkits.playwright.toolkit", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/agents/agent_toolkits/playwright/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/playwright/toolkit.py index debd504ace4..2bb70a0551c 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/playwright/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/playwright/toolkit.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "PlayWrightBrowserToolkit": "langchain_community.agent_toolkits.playwright.toolkit" + "PlayWrightBrowserToolkit": "langchain_community.agent_toolkits.playwright.toolkit", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/agents/agent_toolkits/powerbi/base.py b/libs/langchain/langchain/agents/agent_toolkits/powerbi/base.py index 18538bbcb30..e06724ad5f0 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/powerbi/base.py +++ b/libs/langchain/langchain/agents/agent_toolkits/powerbi/base.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "create_pbi_agent": "langchain_community.agent_toolkits.powerbi.base" + "create_pbi_agent": "langchain_community.agent_toolkits.powerbi.base", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/agents/agent_toolkits/powerbi/chat_base.py b/libs/langchain/langchain/agents/agent_toolkits/powerbi/chat_base.py index db0bed9bdcf..75d44d4bbd1 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/powerbi/chat_base.py +++ b/libs/langchain/langchain/agents/agent_toolkits/powerbi/chat_base.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. 
DEPRECATED_LOOKUP = { - "create_pbi_chat_agent": "langchain_community.agent_toolkits.powerbi.chat_base" + "create_pbi_chat_agent": "langchain_community.agent_toolkits.powerbi.chat_base", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/agents/agent_toolkits/powerbi/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/powerbi/toolkit.py index b008c385e76..55dafc93b5f 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/powerbi/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/powerbi/toolkit.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "PowerBIToolkit": "langchain_community.agent_toolkits.powerbi.toolkit" + "PowerBIToolkit": "langchain_community.agent_toolkits.powerbi.toolkit", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/agents/agent_toolkits/spark_sql/base.py b/libs/langchain/langchain/agents/agent_toolkits/spark_sql/base.py index 25b3bda1aea..d5443ffa55e 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/spark_sql/base.py +++ b/libs/langchain/langchain/agents/agent_toolkits/spark_sql/base.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "create_spark_sql_agent": "langchain_community.agent_toolkits.spark_sql.base" + "create_spark_sql_agent": "langchain_community.agent_toolkits.spark_sql.base", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/agents/agent_toolkits/spark_sql/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/spark_sql/toolkit.py index 4c393c02f01..c75361369c0 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/spark_sql/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/spark_sql/toolkit.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "SparkSQLToolkit": "langchain_community.agent_toolkits.spark_sql.toolkit" + "SparkSQLToolkit": "langchain_community.agent_toolkits.spark_sql.toolkit", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/agents/agent_toolkits/sql/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/sql/toolkit.py index e330c4894c3..92ee88123e9 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/sql/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/sql/toolkit.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. 
DEPRECATED_LOOKUP = { - "SQLDatabaseToolkit": "langchain_community.agent_toolkits.sql.toolkit" + "SQLDatabaseToolkit": "langchain_community.agent_toolkits.sql.toolkit", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/agents/agent_toolkits/vectorstore/prompt.py b/libs/langchain/langchain/agents/agent_toolkits/vectorstore/prompt.py index a2837e56f18..ec992522ebc 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/vectorstore/prompt.py +++ b/libs/langchain/langchain/agents/agent_toolkits/vectorstore/prompt.py @@ -1,13 +1,11 @@ -# flake8: noqa - PREFIX = """You are an agent designed to answer questions about sets of documents. You have access to tools for interacting with the documents, and the inputs to the tools are questions. Sometimes, you will be asked to provide sources for your questions, in which case you should use the appropriate tool to do so. If the question does not seem relevant to any of the tools provided, just return "I don't know" as the answer. -""" +""" # noqa: E501 ROUTER_PREFIX = """You are an agent designed to answer questions. You have access to tools for interacting with different sources, and the inputs to the tools are questions. Your main task is to decide which of the tools is relevant for answering question at hand. For complex questions, you can break the question down into sub questions and use tools to answers the sub questions. -""" +""" # noqa: E501 diff --git a/libs/langchain/langchain/agents/agent_toolkits/vectorstore/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/vectorstore/toolkit.py index bcd9bdf31df..3f4127f5b89 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/vectorstore/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/vectorstore/toolkit.py @@ -40,7 +40,8 @@ class VectorStoreToolkit(BaseToolkit): msg = "You need to install langchain-community to use this toolkit." raise ImportError(msg) description = VectorStoreQATool.get_description( - self.vectorstore_info.name, self.vectorstore_info.description + self.vectorstore_info.name, + self.vectorstore_info.description, ) qa_tool = VectorStoreQATool( name=self.vectorstore_info.name, @@ -49,7 +50,8 @@ class VectorStoreToolkit(BaseToolkit): llm=self.llm, ) description = VectorStoreQAWithSourcesTool.get_description( - self.vectorstore_info.name, self.vectorstore_info.description + self.vectorstore_info.name, + self.vectorstore_info.description, ) qa_with_sources_tool = VectorStoreQAWithSourcesTool( name=f"{self.vectorstore_info.name}_with_sources", @@ -82,7 +84,8 @@ class VectorStoreRouterToolkit(BaseToolkit): raise ImportError(msg) for vectorstore_info in self.vectorstores: description = VectorStoreQATool.get_description( - vectorstore_info.name, vectorstore_info.description + vectorstore_info.name, + vectorstore_info.description, ) qa_tool = VectorStoreQATool( name=vectorstore_info.name, diff --git a/libs/langchain/langchain/agents/agent_toolkits/zapier/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/zapier/toolkit.py index 83461b54a23..94ae20b6759 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/zapier/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/zapier/toolkit.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. 
DEPRECATED_LOOKUP = { - "ZapierToolkit": "langchain_community.agent_toolkits.zapier.toolkit" + "ZapierToolkit": "langchain_community.agent_toolkits.zapier.toolkit", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/agents/chat/base.py b/libs/langchain/langchain/agents/chat/base.py index 3409e1b607d..6b5c05bed26 100644 --- a/libs/langchain/langchain/agents/chat/base.py +++ b/libs/langchain/langchain/agents/chat/base.py @@ -49,7 +49,8 @@ class ChatAgent(Agent): return "Thought:" def _construct_scratchpad( - self, intermediate_steps: list[tuple[AgentAction, str]] + self, + intermediate_steps: list[tuple[AgentAction, str]], ) -> str: agent_scratchpad = super()._construct_scratchpad(intermediate_steps) if not isinstance(agent_scratchpad, str): @@ -106,14 +107,7 @@ class ChatAgent(Agent): tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools]) tool_names = ", ".join([tool.name for tool in tools]) format_instructions = format_instructions.format(tool_names=tool_names) - template = "\n\n".join( - [ - system_message_prefix, - tool_strings, - format_instructions, - system_message_suffix, - ] - ) + template = f"{system_message_prefix}\n\n{tool_strings}\n\n{format_instructions}\n\n{system_message_suffix}" # noqa: E501 messages = [ SystemMessagePromptTemplate.from_template(template), HumanMessagePromptTemplate.from_template(human_message), diff --git a/libs/langchain/langchain/agents/chat/output_parser.py b/libs/langchain/langchain/agents/chat/output_parser.py index 21066c6627b..69c71ed32c8 100644 --- a/libs/langchain/langchain/agents/chat/output_parser.py +++ b/libs/langchain/langchain/agents/chat/output_parser.py @@ -57,7 +57,9 @@ class ChatOutputParser(AgentOutputParser): ) raise OutputParserException(msg) return AgentAction( - response["action"], response.get("action_input", {}), text + response["action"], + response.get("action_input", {}), + text, ) except Exception as exc: diff --git a/libs/langchain/langchain/agents/chat/prompt.py b/libs/langchain/langchain/agents/chat/prompt.py index 4343739bf41..a5e7e7b1e5b 100644 --- a/libs/langchain/langchain/agents/chat/prompt.py +++ b/libs/langchain/langchain/agents/chat/prompt.py @@ -1,5 +1,4 @@ -# flake8: noqa -SYSTEM_MESSAGE_PREFIX = """Answer the following questions as best you can. You have access to the following tools:""" +SYSTEM_MESSAGE_PREFIX = """Answer the following questions as best you can. You have access to the following tools:""" # noqa: E501 FORMAT_INSTRUCTIONS = """The way you use the tools is by specifying a json blob. Specifically, this json should have a `action` key (with the name of the tool to use) and a `action_input` key (with the input to the tool going here). @@ -25,6 +24,6 @@ $JSON_BLOB Observation: the result of the action ... (this Thought/Action/Observation can repeat N times) Thought: I now know the final answer -Final Answer: the final answer to the original input question""" -SYSTEM_MESSAGE_SUFFIX = """Begin! Reminder to always use the exact characters `Final Answer` when responding.""" +Final Answer: the final answer to the original input question""" # noqa: E501 +SYSTEM_MESSAGE_SUFFIX = """Begin! 
Reminder to always use the exact characters `Final Answer` when responding.""" # noqa: E501 HUMAN_MESSAGE = "{input}\n\n{agent_scratchpad}" diff --git a/libs/langchain/langchain/agents/conversational/base.py b/libs/langchain/langchain/agents/conversational/base.py index 32846b0c85f..71aa243c55c 100644 --- a/libs/langchain/langchain/agents/conversational/base.py +++ b/libs/langchain/langchain/agents/conversational/base.py @@ -36,7 +36,9 @@ class ConversationalAgent(Agent): @classmethod def _get_default_output_parser( - cls, ai_prefix: str = "AI", **kwargs: Any + cls, + ai_prefix: str = "AI", + **kwargs: Any, ) -> AgentOutputParser: return ConvoOutputParser(ai_prefix=ai_prefix) @@ -93,13 +95,15 @@ class ConversationalAgent(Agent): A PromptTemplate with the template assembled from the pieces here. """ tool_strings = "\n".join( - [f"> {tool.name}: {tool.description}" for tool in tools] + [f"> {tool.name}: {tool.description}" for tool in tools], ) tool_names = ", ".join([tool.name for tool in tools]) format_instructions = format_instructions.format( - tool_names=tool_names, ai_prefix=ai_prefix, human_prefix=human_prefix + tool_names=tool_names, + ai_prefix=ai_prefix, + human_prefix=human_prefix, ) - template = "\n\n".join([prefix, tool_strings, format_instructions, suffix]) + template = f"{prefix}\n\n{tool_strings}\n\n{format_instructions}\n\n{suffix}" if input_variables is None: input_variables = ["input", "chat_history", "agent_scratchpad"] return PromptTemplate(template=template, input_variables=input_variables) @@ -161,7 +165,7 @@ class ConversationalAgent(Agent): ) tool_names = [tool.name for tool in tools] _output_parser = output_parser or cls._get_default_output_parser( - ai_prefix=ai_prefix + ai_prefix=ai_prefix, ) return cls( llm_chain=llm_chain, diff --git a/libs/langchain/langchain/agents/conversational/output_parser.py b/libs/langchain/langchain/agents/conversational/output_parser.py index 82f098162a9..695c580849b 100644 --- a/libs/langchain/langchain/agents/conversational/output_parser.py +++ b/libs/langchain/langchain/agents/conversational/output_parser.py @@ -34,7 +34,8 @@ class ConvoOutputParser(AgentOutputParser): if f"{self.ai_prefix}:" in text: return AgentFinish( - {"output": text.split(f"{self.ai_prefix}:")[-1].strip()}, text + {"output": text.split(f"{self.ai_prefix}:")[-1].strip()}, + text, ) regex = r"Action: (.*?)[\n]*Action Input: ([\s\S]*)" match = re.search(regex, text, re.DOTALL) diff --git a/libs/langchain/langchain/agents/conversational/prompt.py b/libs/langchain/langchain/agents/conversational/prompt.py index 15268a76083..26938e32fa2 100644 --- a/libs/langchain/langchain/agents/conversational/prompt.py +++ b/libs/langchain/langchain/agents/conversational/prompt.py @@ -1,4 +1,3 @@ -# flake8: noqa PREFIX = """Assistant is a large language model trained by OpenAI. Assistant is designed to be able to assist with a wide range of tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. As a language model, Assistant is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand. 
@@ -10,7 +9,7 @@ Overall, Assistant is a powerful tool that can help with a wide range of tasks a TOOLS: ------ -Assistant has access to the following tools:""" +Assistant has access to the following tools:""" # noqa: E501 FORMAT_INSTRUCTIONS = """To use a tool, please use the following format: ``` @@ -25,7 +24,7 @@ When you have a response to say to the Human, or if you do not need to use a too ``` Thought: Do I need to use a tool? No {ai_prefix}: [your response here] -```""" +```""" # noqa: E501 SUFFIX = """Begin! diff --git a/libs/langchain/langchain/agents/conversational_chat/base.py b/libs/langchain/langchain/agents/conversational_chat/base.py index 35bf4bffbd9..e814cc53784 100644 --- a/libs/langchain/langchain/agents/conversational_chat/base.py +++ b/libs/langchain/langchain/agents/conversational_chat/base.py @@ -96,15 +96,16 @@ class ConversationalChatAgent(Agent): A PromptTemplate. """ tool_strings = "\n".join( - [f"> {tool.name}: {tool.description}" for tool in tools] + [f"> {tool.name}: {tool.description}" for tool in tools], ) tool_names = ", ".join([tool.name for tool in tools]) _output_parser = output_parser or cls._get_default_output_parser() format_instructions = human_message.format( - format_instructions=_output_parser.get_format_instructions() + format_instructions=_output_parser.get_format_instructions(), ) final_prompt = format_instructions.format( - tool_names=tool_names, tools=tool_strings + tool_names=tool_names, + tools=tool_strings, ) if input_variables is None: input_variables = ["input", "chat_history", "agent_scratchpad"] @@ -117,14 +118,15 @@ class ConversationalChatAgent(Agent): return ChatPromptTemplate(input_variables=input_variables, messages=messages) def _construct_scratchpad( - self, intermediate_steps: list[tuple[AgentAction, str]] + self, + intermediate_steps: list[tuple[AgentAction, str]], ) -> list[BaseMessage]: """Construct the scratchpad that lets the agent continue its thought process.""" thoughts: list[BaseMessage] = [] for action, observation in intermediate_steps: thoughts.append(AIMessage(content=action.log)) human_message = HumanMessage( - content=self.template_tool_response.format(observation=observation) + content=self.template_tool_response.format(observation=observation), ) thoughts.append(human_message) return thoughts diff --git a/libs/langchain/langchain/agents/conversational_chat/prompt.py b/libs/langchain/langchain/agents/conversational_chat/prompt.py index 712a9238eae..153cbbf9297 100644 --- a/libs/langchain/langchain/agents/conversational_chat/prompt.py +++ b/libs/langchain/langchain/agents/conversational_chat/prompt.py @@ -1,11 +1,10 @@ -# flake8: noqa PREFIX = """Assistant is a large language model trained by OpenAI. Assistant is designed to be able to assist with a wide range of tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. As a language model, Assistant is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand. Assistant is constantly learning and improving, and its capabilities are constantly evolving. It is able to process and understand large amounts of text, and can use this knowledge to provide accurate and informative responses to a wide range of questions. 
Additionally, Assistant is able to generate its own text based on the input it receives, allowing it to engage in discussions and provide explanations and descriptions on a wide range of topics. -Overall, Assistant is a powerful system that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics. Whether you need help with a specific question or just want to have a conversation about a particular topic, Assistant is here to assist.""" +Overall, Assistant is a powerful system that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics. Whether you need help with a specific question or just want to have a conversation about a particular topic, Assistant is here to assist.""" # noqa: E501 FORMAT_INSTRUCTIONS = """RESPONSE FORMAT INSTRUCTIONS ---------------------------- @@ -31,7 +30,7 @@ Use this if you want to respond directly to the human. Markdown code snippet for "action": "Final Answer", "action_input": string \\\\ You should put what you want to return to use here }}}} -```""" +```""" # noqa: E501 SUFFIX = """TOOLS ------ @@ -45,13 +44,13 @@ USER'S INPUT -------------------- Here is the user's input (remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else): -{{{{input}}}}""" +{{{{input}}}}""" # noqa: E501 -TEMPLATE_TOOL_RESPONSE = """TOOL RESPONSE: +TEMPLATE_TOOL_RESPONSE = """TOOL RESPONSE: --------------------- {observation} USER'S INPUT -------------------- -Okay, so what is the response to my last comment? If using information obtained from the tools you must mention it explicitly without mentioning the tool names - I have forgotten all TOOL RESPONSES! Remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else.""" +Okay, so what is the response to my last comment? If using information obtained from the tools you must mention it explicitly without mentioning the tool names - I have forgotten all TOOL RESPONSES! Remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else.""" # noqa: E501 diff --git a/libs/langchain/langchain/agents/format_scratchpad/log_to_messages.py b/libs/langchain/langchain/agents/format_scratchpad/log_to_messages.py index 98c5d04ee83..a193e37ae38 100644 --- a/libs/langchain/langchain/agents/format_scratchpad/log_to_messages.py +++ b/libs/langchain/langchain/agents/format_scratchpad/log_to_messages.py @@ -20,7 +20,7 @@ def format_log_to_messages( for action, observation in intermediate_steps: thoughts.append(AIMessage(content=action.log)) human_message = HumanMessage( - content=template_tool_response.format(observation=observation) + content=template_tool_response.format(observation=observation), ) thoughts.append(human_message) return thoughts diff --git a/libs/langchain/langchain/agents/format_scratchpad/openai_functions.py b/libs/langchain/langchain/agents/format_scratchpad/openai_functions.py index 3055cc05545..188fd0a07ec 100644 --- a/libs/langchain/langchain/agents/format_scratchpad/openai_functions.py +++ b/libs/langchain/langchain/agents/format_scratchpad/openai_functions.py @@ -6,7 +6,8 @@ from langchain_core.messages import AIMessage, BaseMessage, FunctionMessage def _convert_agent_action_to_messages( - agent_action: AgentAction, observation: str + agent_action: AgentAction, + observation: str, ) -> list[BaseMessage]: """Convert an agent action to a message. 
@@ -28,7 +29,8 @@ def _convert_agent_action_to_messages( def _create_function_message( - agent_action: AgentAction, observation: str + agent_action: AgentAction, + observation: str, ) -> FunctionMessage: """Convert agent action and observation into a function message. Args: diff --git a/libs/langchain/langchain/agents/format_scratchpad/tools.py b/libs/langchain/langchain/agents/format_scratchpad/tools.py index 422bafdf605..6414b059f2a 100644 --- a/libs/langchain/langchain/agents/format_scratchpad/tools.py +++ b/libs/langchain/langchain/agents/format_scratchpad/tools.py @@ -12,7 +12,8 @@ from langchain.agents.output_parsers.tools import ToolAgentAction def _create_tool_message( - agent_action: ToolAgentAction, observation: str + agent_action: ToolAgentAction, + observation: str, ) -> ToolMessage: """Convert agent action and observation into a tool message. diff --git a/libs/langchain/langchain/agents/initialize.py b/libs/langchain/langchain/agents/initialize.py index b606c942e4c..38717e89c63 100644 --- a/libs/langchain/langchain/agents/initialize.py +++ b/libs/langchain/langchain/agents/initialize.py @@ -76,11 +76,17 @@ def initialize_agent( agent_cls = AGENT_TO_CLASS[agent] agent_kwargs = agent_kwargs or {} agent_obj = agent_cls.from_llm_and_tools( - llm, tools, callback_manager=callback_manager, **agent_kwargs + llm, + tools, + callback_manager=callback_manager, + **agent_kwargs, ) elif agent_path is not None: agent_obj = load_agent( - agent_path, llm=llm, tools=tools, callback_manager=callback_manager + agent_path, + llm=llm, + tools=tools, + callback_manager=callback_manager, ) with contextlib.suppress(NotImplementedError): # TODO: Add tags from the serialized object directly. diff --git a/libs/langchain/langchain/agents/json_chat/base.py b/libs/langchain/langchain/agents/json_chat/base.py index 8af12bf79c4..67ded18ab21 100644 --- a/libs/langchain/langchain/agents/json_chat/base.py +++ b/libs/langchain/langchain/agents/json_chat/base.py @@ -162,7 +162,7 @@ def create_json_chat_agent( ) """ # noqa: E501 missing_vars = {"tools", "tool_names", "agent_scratchpad"}.difference( - prompt.input_variables + list(prompt.partial_variables) + prompt.input_variables + list(prompt.partial_variables), ) if missing_vars: msg = f"Prompt missing required variables: {missing_vars}" @@ -185,8 +185,9 @@ def create_json_chat_agent( return ( RunnablePassthrough.assign( agent_scratchpad=lambda x: format_log_to_messages( - x["intermediate_steps"], template_tool_response=template_tool_response - ) + x["intermediate_steps"], + template_tool_response=template_tool_response, + ), ) | prompt | llm_to_use diff --git a/libs/langchain/langchain/agents/json_chat/prompt.py b/libs/langchain/langchain/agents/json_chat/prompt.py index 34020caa29f..dc4a701071c 100644 --- a/libs/langchain/langchain/agents/json_chat/prompt.py +++ b/libs/langchain/langchain/agents/json_chat/prompt.py @@ -1,9 +1,8 @@ -# flake8: noqa -TEMPLATE_TOOL_RESPONSE = """TOOL RESPONSE: +TEMPLATE_TOOL_RESPONSE = """TOOL RESPONSE: --------------------- {observation} USER'S INPUT -------------------- -Okay, so what is the response to my last comment? If using information obtained from the tools you must mention it explicitly without mentioning the tool names - I have forgotten all TOOL RESPONSES! Remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else - even if you just want to respond to the user. 
Do NOT respond with anything except a JSON snippet no matter what!""" +Okay, so what is the response to my last comment? If using information obtained from the tools you must mention it explicitly without mentioning the tool names - I have forgotten all TOOL RESPONSES! Remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else - even if you just want to respond to the user. Do NOT respond with anything except a JSON snippet no matter what!""" # noqa: E501 diff --git a/libs/langchain/langchain/agents/load_tools.py b/libs/langchain/langchain/agents/load_tools.py index dccc67d8dd1..923fbd1e90d 100644 --- a/libs/langchain/langchain/agents/load_tools.py +++ b/libs/langchain/langchain/agents/load_tools.py @@ -3,7 +3,8 @@ from typing import Any from langchain._api import create_importer _importer = create_importer( - __package__, fallback_module="langchain_community.agent_toolkits.load_tools" + __package__, + fallback_module="langchain_community.agent_toolkits.load_tools", ) diff --git a/libs/langchain/langchain/agents/loading.py b/libs/langchain/langchain/agents/loading.py index d9b1df5bd0e..ab145cc788a 100644 --- a/libs/langchain/langchain/agents/loading.py +++ b/libs/langchain/langchain/agents/loading.py @@ -20,7 +20,10 @@ URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/age def _load_agent_from_tools( - config: dict, llm: BaseLanguageModel, tools: list[Tool], **kwargs: Any + config: dict, + llm: BaseLanguageModel, + tools: list[Tool], + **kwargs: Any, ) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]: config_type = config.pop("_type") if config_type not in AGENT_TO_CLASS: @@ -87,7 +90,7 @@ def load_agent_from_config( if "output_parser" in config: logger.warning( "Currently loading output parsers on agent is not supported, " - "will just use the default one." + "will just use the default one.", ) del config["output_parser"] @@ -97,7 +100,8 @@ def load_agent_from_config( @deprecated("0.1.0", removal="1.0") def load_agent( - path: Union[str, Path], **kwargs: Any + path: Union[str, Path], + **kwargs: Any, ) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]: """Unified method for loading an agent from LangChainHub or local fs. 
@@ -123,7 +127,8 @@ def load_agent( def _load_agent_from_file( - file: Union[str, Path], **kwargs: Any + file: Union[str, Path], + **kwargs: Any, ) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]: """Load agent from file.""" valid_suffixes = {"json", "yaml"} diff --git a/libs/langchain/langchain/agents/mrkl/base.py b/libs/langchain/langchain/agents/mrkl/base.py index 7e76b4caf5f..9bc04129e93 100644 --- a/libs/langchain/langchain/agents/mrkl/base.py +++ b/libs/langchain/langchain/agents/mrkl/base.py @@ -104,7 +104,7 @@ class ZeroShotAgent(Agent): tool_strings = render_text_description(list(tools)) tool_names = ", ".join([tool.name for tool in tools]) format_instructions = format_instructions.format(tool_names=tool_names) - template = "\n\n".join([prefix, tool_strings, format_instructions, suffix]) + template = f"{prefix}\n\n{tool_strings}\n\n{format_instructions}\n\n{suffix}" if input_variables: return PromptTemplate(template=template, input_variables=input_variables) return PromptTemplate.from_template(template) @@ -186,7 +186,10 @@ class MRKLChain(AgentExecutor): @classmethod def from_chains( - cls, llm: BaseLanguageModel, chains: list[ChainConfig], **kwargs: Any + cls, + llm: BaseLanguageModel, + chains: list[ChainConfig], + **kwargs: Any, ) -> AgentExecutor: """User-friendly way to initialize the MRKL chain. diff --git a/libs/langchain/langchain/agents/mrkl/output_parser.py b/libs/langchain/langchain/agents/mrkl/output_parser.py index 9be800f297b..388e28f6f25 100644 --- a/libs/langchain/langchain/agents/mrkl/output_parser.py +++ b/libs/langchain/langchain/agents/mrkl/output_parser.py @@ -53,7 +53,8 @@ class MRKLOutputParser(AgentOutputParser): start_index = text.find(FINAL_ANSWER_ACTION) + len(FINAL_ANSWER_ACTION) end_index = text.find("\n\n", start_index) return AgentFinish( - {"output": text[start_index:end_index].strip()}, text[:end_index] + {"output": text[start_index:end_index].strip()}, + text[:end_index], ) msg = f"{FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE}: {text}" raise OutputParserException(msg) @@ -70,7 +71,8 @@ class MRKLOutputParser(AgentOutputParser): if includes_answer: return AgentFinish( - {"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text + {"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, + text, ) if not re.search(r"Action\s*\d*\s*:[\s]*(.*?)", text, re.DOTALL): @@ -82,7 +84,9 @@ class MRKLOutputParser(AgentOutputParser): send_to_llm=True, ) if not re.search( - r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL + r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", + text, + re.DOTALL, ): msg = f"Could not parse LLM output: `{text}`" raise OutputParserException( diff --git a/libs/langchain/langchain/agents/mrkl/prompt.py b/libs/langchain/langchain/agents/mrkl/prompt.py index db6827b5ec7..a713dac6dc8 100644 --- a/libs/langchain/langchain/agents/mrkl/prompt.py +++ b/libs/langchain/langchain/agents/mrkl/prompt.py @@ -1,5 +1,4 @@ -# flake8: noqa -PREFIX = """Answer the following questions as best you can. You have access to the following tools:""" +PREFIX = """Answer the following questions as best you can. 
You have access to the following tools:""" # noqa: E501 FORMAT_INSTRUCTIONS = """Use the following format: Question: the input question you must answer diff --git a/libs/langchain/langchain/agents/openai_assistant/base.py b/libs/langchain/langchain/agents/openai_assistant/base.py index 915a4a0b270..4706d02df38 100644 --- a/libs/langchain/langchain/agents/openai_assistant/base.py +++ b/libs/langchain/langchain/agents/openai_assistant/base.py @@ -273,7 +273,10 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]): @override def invoke( - self, input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any + self, + input: dict, + config: Optional[RunnableConfig] = None, + **kwargs: Any, ) -> OutputType: """Invoke assistant. @@ -314,13 +317,15 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]): inheritable_metadata=config.get("metadata"), ) run_manager = callback_manager.on_chain_start( - dumpd(self), input, name=config.get("run_name") or self.get_name() + dumpd(self), + input, + name=config.get("run_name") or self.get_name(), ) try: # Being run within AgentExecutor and there are tool outputs to submit. if self.as_agent and input.get("intermediate_steps"): tool_outputs = self._parse_intermediate_steps( - input["intermediate_steps"] + input["intermediate_steps"], ) run = self.client.beta.threads.runs.submit_tool_outputs(**tool_outputs) # Starting a new thread and a new run. @@ -332,7 +337,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]): "content": input["content"], "metadata": input.get("message_metadata"), "attachments": input.get("attachments"), - } + }, ], "metadata": input.get("thread_metadata"), } @@ -401,7 +406,10 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]): @override async def ainvoke( - self, input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any + self, + input: dict, + config: Optional[RunnableConfig] = None, + **kwargs: Any, ) -> OutputType: """Async invoke assistant. @@ -442,16 +450,18 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]): inheritable_metadata=config.get("metadata"), ) run_manager = callback_manager.on_chain_start( - dumpd(self), input, name=config.get("run_name") or self.get_name() + dumpd(self), + input, + name=config.get("run_name") or self.get_name(), ) try: # Being run within AgentExecutor and there are tool outputs to submit. if self.as_agent and input.get("intermediate_steps"): tool_outputs = await self._aparse_intermediate_steps( - input["intermediate_steps"] + input["intermediate_steps"], ) run = await self.async_client.beta.threads.runs.submit_tool_outputs( - **tool_outputs + **tool_outputs, ) # Starting a new thread and a new run. elif "thread_id" not in input: @@ -461,7 +471,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]): "role": "user", "content": input["content"], "metadata": input.get("message_metadata"), - } + }, ], "metadata": input.get("thread_metadata"), } @@ -479,7 +489,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]): # framework. 
else: run = await self.async_client.beta.threads.runs.submit_tool_outputs( - **input + **input, ) run = await self._await_for_run(run.id, run.thread_id) except BaseException as e: @@ -495,7 +505,8 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]): return response def _parse_intermediate_steps( - self, intermediate_steps: list[tuple[OpenAIAssistantAction, str]] + self, + intermediate_steps: list[tuple[OpenAIAssistantAction, str]], ) -> dict: last_action, last_output = intermediate_steps[-1] run = self._wait_for_run(last_action.run_id, last_action.thread_id) @@ -575,7 +586,8 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]): ) messages = self.client.beta.threads.messages.list( - run.thread_id, order="asc" + run.thread_id, + order="asc", ) new_messages = [msg for msg in messages if msg.run_id == run.id] if not self.as_agent: @@ -633,7 +645,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]): log="", run_id=run.id, thread_id=run.thread_id, - ) + ), ) return actions run_info = json.dumps(run.dict(), indent=2) @@ -650,7 +662,8 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]): return run async def _aparse_intermediate_steps( - self, intermediate_steps: list[tuple[OpenAIAssistantAction, str]] + self, + intermediate_steps: list[tuple[OpenAIAssistantAction, str]], ) -> dict: last_action, last_output = intermediate_steps[-1] run = self._wait_for_run(last_action.run_id, last_action.thread_id) @@ -730,7 +743,8 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]): ) messages = await self.async_client.beta.threads.messages.list( - run.thread_id, order="asc" + run.thread_id, + order="asc", ) new_messages = [msg for msg in messages if msg.run_id == run.id] if not self.as_agent: @@ -784,7 +798,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]): log="", run_id=run.id, thread_id=run.thread_id, - ) + ), ) return actions run_info = json.dumps(run.dict(), indent=2) @@ -795,7 +809,8 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]): in_progress = True while in_progress: run = await self.async_client.beta.threads.runs.retrieve( - run_id, thread_id=thread_id + run_id, + thread_id=thread_id, ) in_progress = run.status in ("in_progress", "queued") if in_progress: diff --git a/libs/langchain/langchain/agents/openai_functions_agent/base.py b/libs/langchain/langchain/agents/openai_functions_agent/base.py index 32db4de3248..76d8e70673f 100644 --- a/libs/langchain/langchain/agents/openai_functions_agent/base.py +++ b/libs/langchain/langchain/agents/openai_functions_agent/base.py @@ -161,7 +161,9 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent): prompt = self.prompt.format_prompt(**full_inputs) messages = prompt.to_messages() predicted_message = await self.llm.apredict_messages( - messages, functions=self.functions, callbacks=callbacks + messages, + functions=self.functions, + callbacks=callbacks, ) return self.output_parser._parse_ai_message(predicted_message) @@ -188,12 +190,15 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent): if early_stopping_method == "force": # `force` just returns a constant string return AgentFinish( - {"output": "Agent stopped due to iteration limit or time limit."}, "" + {"output": "Agent stopped due to iteration limit or time limit."}, + "", ) if early_stopping_method == "generate": # Generate does one final forward pass agent_decision = self.plan( - intermediate_steps, with_functions=False, **kwargs + intermediate_steps, + 
with_functions=False, + **kwargs, ) if isinstance(agent_decision, AgentFinish): return agent_decision @@ -209,7 +214,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent): def create_prompt( cls, system_message: Optional[SystemMessage] = SystemMessage( - content="You are a helpful AI assistant." + content="You are a helpful AI assistant.", ), extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None, ) -> ChatPromptTemplate: @@ -233,7 +238,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent): *_prompts, HumanMessagePromptTemplate.from_template("{input}"), MessagesPlaceholder(variable_name="agent_scratchpad"), - ] + ], ) return ChatPromptTemplate(messages=messages) @@ -245,7 +250,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent): callback_manager: Optional[BaseCallbackManager] = None, extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None, system_message: Optional[SystemMessage] = SystemMessage( - content="You are a helpful AI assistant." + content="You are a helpful AI assistant.", ), **kwargs: Any, ) -> BaseSingleActionAgent: @@ -274,7 +279,9 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent): def create_openai_functions_agent( - llm: BaseLanguageModel, tools: Sequence[BaseTool], prompt: ChatPromptTemplate + llm: BaseLanguageModel, + tools: Sequence[BaseTool], + prompt: ChatPromptTemplate, ) -> Runnable: """Create an agent that uses OpenAI function calling. @@ -357,8 +364,8 @@ def create_openai_functions_agent( return ( RunnablePassthrough.assign( agent_scratchpad=lambda x: format_to_openai_function_messages( - x["intermediate_steps"] - ) + x["intermediate_steps"], + ), ) | prompt | llm_with_tools diff --git a/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py b/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py index 7824dfffce7..26fb42dee21 100644 --- a/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py +++ b/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py @@ -95,7 +95,8 @@ def _parse_ai_message(message: BaseMessage) -> Union[list[AgentAction], AgentFin return final_tools return AgentFinish( - return_values={"output": message.content}, log=str(message.content) + return_values={"output": message.content}, + log=str(message.content), ) @@ -190,7 +191,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent): }, "required": ["action_name", "action"], }, - } + }, }, "required": ["actions"], }, @@ -222,7 +223,9 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent): prompt = self.prompt.format_prompt(**full_inputs) messages = prompt.to_messages() predicted_message = self.llm.predict_messages( - messages, functions=self.functions, callbacks=callbacks + messages, + functions=self.functions, + callbacks=callbacks, ) return _parse_ai_message(predicted_message) @@ -251,7 +254,9 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent): prompt = self.prompt.format_prompt(**full_inputs) messages = prompt.to_messages() predicted_message = await self.llm.apredict_messages( - messages, functions=self.functions, callbacks=callbacks + messages, + functions=self.functions, + callbacks=callbacks, ) return _parse_ai_message(predicted_message) @@ -259,7 +264,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent): def create_prompt( cls, system_message: Optional[SystemMessage] = SystemMessage( - content="You are a helpful AI assistant." 
+ content="You are a helpful AI assistant.", ), extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None, ) -> BasePromptTemplate: @@ -283,7 +288,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent): *_prompts, HumanMessagePromptTemplate.from_template("{input}"), MessagesPlaceholder(variable_name="agent_scratchpad"), - ] + ], ) return ChatPromptTemplate(messages=messages) @@ -295,7 +300,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent): callback_manager: Optional[BaseCallbackManager] = None, extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None, system_message: Optional[SystemMessage] = SystemMessage( - content="You are a helpful AI assistant." + content="You are a helpful AI assistant.", ), **kwargs: Any, ) -> BaseMultiActionAgent: diff --git a/libs/langchain/langchain/agents/openai_tools/base.py b/libs/langchain/langchain/agents/openai_tools/base.py index f1ba9995ad0..fb3a767c4e9 100644 --- a/libs/langchain/langchain/agents/openai_tools/base.py +++ b/libs/langchain/langchain/agents/openai_tools/base.py @@ -86,21 +86,21 @@ def create_openai_tools_agent( ) """ missing_vars = {"agent_scratchpad"}.difference( - prompt.input_variables + list(prompt.partial_variables) + prompt.input_variables + list(prompt.partial_variables), ) if missing_vars: msg = f"Prompt missing required variables: {missing_vars}" raise ValueError(msg) llm_with_tools = llm.bind( - tools=[convert_to_openai_tool(tool, strict=strict) for tool in tools] + tools=[convert_to_openai_tool(tool, strict=strict) for tool in tools], ) return ( RunnablePassthrough.assign( agent_scratchpad=lambda x: format_to_openai_tool_messages( - x["intermediate_steps"] - ) + x["intermediate_steps"], + ), ) | prompt | llm_with_tools diff --git a/libs/langchain/langchain/agents/output_parsers/openai_functions.py b/libs/langchain/langchain/agents/output_parsers/openai_functions.py index 26af5ddd964..57b47770304 100644 --- a/libs/langchain/langchain/agents/output_parsers/openai_functions.py +++ b/libs/langchain/langchain/agents/output_parsers/openai_functions.py @@ -75,11 +75,15 @@ class OpenAIFunctionsAgentOutputParser(AgentOutputParser): ) return AgentFinish( - return_values={"output": message.content}, log=str(message.content) + return_values={"output": message.content}, + log=str(message.content), ) def parse_result( - self, result: list[Generation], *, partial: bool = False + self, + result: list[Generation], + *, + partial: bool = False, ) -> Union[AgentAction, AgentFinish]: if not isinstance(result[0], ChatGeneration): msg = "This output parser only works on ChatGeneration output" diff --git a/libs/langchain/langchain/agents/output_parsers/openai_tools.py b/libs/langchain/langchain/agents/output_parsers/openai_tools.py index 42899aca570..3e612f625b3 100644 --- a/libs/langchain/langchain/agents/output_parsers/openai_tools.py +++ b/libs/langchain/langchain/agents/output_parsers/openai_tools.py @@ -30,7 +30,7 @@ def parse_ai_message_to_openai_tool_action( log=action.log, message_log=action.message_log, tool_call_id=action.tool_call_id, - ) + ), ) else: final_actions.append(action) @@ -54,7 +54,10 @@ class OpenAIToolsAgentOutputParser(MultiActionAgentOutputParser): return "openai-tools-agent-output-parser" def parse_result( - self, result: list[Generation], *, partial: bool = False + self, + result: list[Generation], + *, + partial: bool = False, ) -> Union[list[AgentAction], AgentFinish]: if not isinstance(result[0], ChatGeneration): msg = "This output parser only works on ChatGeneration output" diff 
--git a/libs/langchain/langchain/agents/output_parsers/react_json_single_input.py b/libs/langchain/langchain/agents/output_parsers/react_json_single_input.py index 248852cebc5..2613cde6c75 100644 --- a/libs/langchain/langchain/agents/output_parsers/react_json_single_input.py +++ b/libs/langchain/langchain/agents/output_parsers/react_json_single_input.py @@ -67,7 +67,9 @@ class ReActJsonSingleInputOutputParser(AgentOutputParser): ) raise OutputParserException(msg) return AgentAction( - response["action"], response.get("action_input", {}), text + response["action"], + response.get("action_input", {}), + text, ) except Exception: diff --git a/libs/langchain/langchain/agents/output_parsers/react_single_input.py b/libs/langchain/langchain/agents/output_parsers/react_single_input.py index 90f058be4ab..20a66dbdbfa 100644 --- a/libs/langchain/langchain/agents/output_parsers/react_single_input.py +++ b/libs/langchain/langchain/agents/output_parsers/react_single_input.py @@ -67,7 +67,8 @@ class ReActSingleInputOutputParser(AgentOutputParser): if includes_answer: return AgentFinish( - {"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text + {"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, + text, ) if not re.search(r"Action\s*\d*\s*:[\s]*(.*?)", text, re.DOTALL): @@ -79,7 +80,9 @@ class ReActSingleInputOutputParser(AgentOutputParser): send_to_llm=True, ) if not re.search( - r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL + r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", + text, + re.DOTALL, ): msg = f"Could not parse LLM output: `{text}`" raise OutputParserException( diff --git a/libs/langchain/langchain/agents/output_parsers/tools.py b/libs/langchain/langchain/agents/output_parsers/tools.py index 4461eeced3f..b901270581b 100644 --- a/libs/langchain/langchain/agents/output_parsers/tools.py +++ b/libs/langchain/langchain/agents/output_parsers/tools.py @@ -33,7 +33,8 @@ def parse_ai_message_to_tool_action( else: if not message.additional_kwargs.get("tool_calls"): return AgentFinish( - return_values={"output": message.content}, log=str(message.content) + return_values={"output": message.content}, + log=str(message.content), ) # Best-effort parsing tool_calls = [] @@ -43,7 +44,7 @@ def parse_ai_message_to_tool_action( try: args = json.loads(function["arguments"] or "{}") tool_calls.append( - ToolCall(name=function_name, args=args, id=tool_call["id"]) + ToolCall(name=function_name, args=args, id=tool_call["id"]), ) except JSONDecodeError: msg = ( @@ -71,7 +72,7 @@ def parse_ai_message_to_tool_action( log=log, message_log=[message], tool_call_id=tool_call["id"], - ) + ), ) return actions @@ -90,7 +91,10 @@ class ToolsAgentOutputParser(MultiActionAgentOutputParser): return "tools-agent-output-parser" def parse_result( - self, result: list[Generation], *, partial: bool = False + self, + result: list[Generation], + *, + partial: bool = False, ) -> Union[list[AgentAction], AgentFinish]: if not isinstance(result[0], ChatGeneration): msg = "This output parser only works on ChatGeneration output" diff --git a/libs/langchain/langchain/agents/react/agent.py b/libs/langchain/langchain/agents/react/agent.py index 8fcbb46724a..622a03855a5 100644 --- a/libs/langchain/langchain/agents/react/agent.py +++ b/libs/langchain/langchain/agents/react/agent.py @@ -118,7 +118,7 @@ def create_react_agent( prompt = PromptTemplate.from_template(template) """ # noqa: E501 missing_vars = {"tools", "tool_names", "agent_scratchpad"}.difference( - prompt.input_variables + list(prompt.partial_variables) + 
prompt.input_variables + list(prompt.partial_variables), ) if missing_vars: msg = f"Prompt missing required variables: {missing_vars}" diff --git a/libs/langchain/langchain/agents/react/textworld_prompt.py b/libs/langchain/langchain/agents/react/textworld_prompt.py index 26cfd49aa29..797605b4089 100644 --- a/libs/langchain/langchain/agents/react/textworld_prompt.py +++ b/libs/langchain/langchain/agents/react/textworld_prompt.py @@ -1,4 +1,3 @@ -# flake8: noqa from langchain_core.prompts.prompt import PromptTemplate EXAMPLES = [ diff --git a/libs/langchain/langchain/agents/react/wiki_prompt.py b/libs/langchain/langchain/agents/react/wiki_prompt.py index 9db6cf92918..497dd53048a 100644 --- a/libs/langchain/langchain/agents/react/wiki_prompt.py +++ b/libs/langchain/langchain/agents/react/wiki_prompt.py @@ -1,4 +1,3 @@ -# flake8: noqa from langchain_core.prompts.prompt import PromptTemplate EXAMPLES = [ @@ -16,7 +15,7 @@ Thought: I need to instead search High Plains (United States). Action: Search[High Plains (United States)] Observation: The High Plains are a subregion of the Great Plains. From east to west, the High Plains rise in elevation from around 1,800 to 7,000 ft (550 to 2,130 m).[3] Thought: High Plains rise in elevation from around 1,800 to 7,000 ft, so the answer is 1,800 to 7,000 ft. -Action: Finish[1,800 to 7,000 ft]""", +Action: Finish[1,800 to 7,000 ft]""", # noqa: E501 """Question: Musician and satirist Allie Goertz wrote a song about the "The Simpsons" character Milhouse, who Matt Groening named after who? Thought: The question simplifies to "The Simpsons" character Milhouse is named after who. I only need to search Milhouse and find who it is named after. Action: Search[Milhouse] @@ -25,16 +24,16 @@ Thought: The paragraph does not tell who Milhouse is named after, maybe I can lo Action: Lookup[named after] Observation: (Result 1 / 1) Milhouse was named after U.S. president Richard Nixon, whose middle name was Milhous. Thought: Milhouse was named after U.S. president Richard Nixon, so the answer is Richard Nixon. -Action: Finish[Richard Nixon]""", +Action: Finish[Richard Nixon]""", # noqa: E501 """Question: Which documentary is about Finnish rock groups, Adam Clayton Powell or The Saimaa Gesture? Thought: I need to search Adam Clayton Powell and The Saimaa Gesture, and find which documentary is about Finnish rock groups. Action: Search[Adam Clayton Powell] -Observation: Could not find [Adam Clayton Powell]. Similar: [’Adam Clayton Powell III’, ’Seventh Avenue (Manhattan)’, ’Adam Clayton Powell Jr. State Office Building’, ’Isabel Washington Powell’, ’Adam Powell’, ’Adam Clayton Powell (film)’, ’Giancarlo Esposito’]. +Observation: Could not find [Adam Clayton Powell]. Similar: ['Adam Clayton Powell III', 'Seventh Avenue (Manhattan)', 'Adam Clayton Powell Jr. State Office Building', 'Isabel Washington Powell', 'Adam Powell', 'Adam Clayton Powell (film)', 'Giancarlo Esposito']. Thought: To find the documentary, I can search Adam Clayton Powell (film). Action: Search[Adam Clayton Powell (film)] Observation: Adam Clayton Powell is a 1989 American documentary film directed by Richard Kilberg. The film is about the rise and fall of influential African-American politician Adam Clayton Powell Jr.[3][4] It was later aired as part of the PBS series The American Experience. Thought: Adam Clayton Powell (film) is a documentary about an African-American politician, not Finnish rock groups. So the documentary about Finnish rock groups must instead be The Saimaa Gesture. 
-Action: Finish[The Saimaa Gesture]""", +Action: Finish[The Saimaa Gesture]""", # noqa: E501 """Question: What profession does Nicholas Ray and Elia Kazan have in common? Thought: I need to search Nicholas Ray and Elia Kazan, find their professions, then find the profession they have in common. Action: Search[Nicholas Ray] @@ -43,16 +42,16 @@ Thought: Professions of Nicholas Ray are director, screenwriter, and actor. I ne Action: Search[Elia Kazan] Observation: Elia Kazan was an American film and theatre director, producer, screenwriter and actor. Thought: Professions of Elia Kazan are director, producer, screenwriter, and actor. So profession Nicholas Ray and Elia Kazan have in common is director, screenwriter, and actor. -Action: Finish[director, screenwriter, actor]""", - """Question: Which magazine was started first Arthur’s Magazine or First for Women? -Thought: I need to search Arthur’s Magazine and First for Women, and find which was started first. -Action: Search[Arthur’s Magazine] -Observation: Arthur’s Magazine (1844-1846) was an American literary periodical published in Philadelphia in the 19th century. -Thought: Arthur’s Magazine was started in 1844. I need to search First for Women next. +Action: Finish[director, screenwriter, actor]""", # noqa: E501 + """Question: Which magazine was started first Arthur's Magazine or First for Women? +Thought: I need to search Arthur's Magazine and First for Women, and find which was started first. +Action: Search[Arthur's Magazine] +Observation: Arthur's Magazine (1844-1846) was an American literary periodical published in Philadelphia in the 19th century. +Thought: Arthur's Magazine was started in 1844. I need to search First for Women next. Action: Search[First for Women] -Observation: First for Women is a woman’s magazine published by Bauer Media Group in the USA.[1] The magazine was started in 1989. -Thought: First for Women was started in 1989. 1844 (Arthur’s Magazine) < 1989 (First for Women), so Arthur’s Magazine was started first. -Action: Finish[Arthur’s Magazine]""", +Observation: First for Women is a woman's magazine published by Bauer Media Group in the USA.[1] The magazine was started in 1989. +Thought: First for Women was started in 1989. 1844 (Arthur's Magazine) < 1989 (First for Women), so Arthur's Magazine was started first. +Action: Finish[Arthur's Magazine]""", # noqa: E501 """Question: Were Pavel Urysohn and Leonid Levin known for the same type of work? Thought: I need to search Pavel Urysohn and Leonid Levin, find their types of work, then find if they are the same. Action: Search[Pavel Urysohn] @@ -61,7 +60,7 @@ Thought: Pavel Urysohn is a mathematician. I need to search Leonid Levin next an Action: Search[Leonid Levin] Observation: Leonid Anatolievich Levin is a Soviet-American mathematician and computer scientist. Thought: Leonid Levin is a mathematician and computer scientist. So Pavel Urysohn and Leonid Levin have the same type of work. 
-Action: Finish[yes]""", +Action: Finish[yes]""", # noqa: E501 ] SUFFIX = """\nQuestion: {input} {agent_scratchpad}""" diff --git a/libs/langchain/langchain/agents/schema.py b/libs/langchain/langchain/agents/schema.py index 664ec9ec8ed..7cad47f4c44 100644 --- a/libs/langchain/langchain/agents/schema.py +++ b/libs/langchain/langchain/agents/schema.py @@ -12,7 +12,8 @@ class AgentScratchPadChatPromptTemplate(ChatPromptTemplate): return False def _construct_agent_scratchpad( - self, intermediate_steps: list[tuple[AgentAction, str]] + self, + intermediate_steps: list[tuple[AgentAction, str]], ) -> str: if len(intermediate_steps) == 0: return "" @@ -29,6 +30,6 @@ class AgentScratchPadChatPromptTemplate(ChatPromptTemplate): def _merge_partial_and_user_variables(self, **kwargs: Any) -> dict[str, Any]: intermediate_steps = kwargs.pop("intermediate_steps") kwargs["agent_scratchpad"] = self._construct_agent_scratchpad( - intermediate_steps + intermediate_steps, ) return kwargs diff --git a/libs/langchain/langchain/agents/self_ask_with_search/base.py b/libs/langchain/langchain/agents/self_ask_with_search/base.py index 9bf800be787..1078cb0e7a3 100644 --- a/libs/langchain/langchain/agents/self_ask_with_search/base.py +++ b/libs/langchain/langchain/agents/self_ask_with_search/base.py @@ -76,7 +76,9 @@ class SelfAskWithSearchChain(AgentExecutor): self, llm: BaseLanguageModel, search_chain: Union[ - GoogleSerperAPIWrapper, SearchApiAPIWrapper, SerpAPIWrapper + GoogleSerperAPIWrapper, + SearchApiAPIWrapper, + SerpAPIWrapper, ], **kwargs: Any, ): @@ -92,7 +94,9 @@ class SelfAskWithSearchChain(AgentExecutor): def create_self_ask_with_search_agent( - llm: BaseLanguageModel, tools: Sequence[BaseTool], prompt: BasePromptTemplate + llm: BaseLanguageModel, + tools: Sequence[BaseTool], + prompt: BasePromptTemplate, ) -> Runnable: """Create an agent that uses self-ask with search prompting. @@ -180,7 +184,7 @@ def create_self_ask_with_search_agent( prompt = PromptTemplate.from_template(template) """ # noqa: E501 missing_vars = {"agent_scratchpad"}.difference( - prompt.input_variables + list(prompt.partial_variables) + prompt.input_variables + list(prompt.partial_variables), ) if missing_vars: msg = f"Prompt missing required variables: {missing_vars}" diff --git a/libs/langchain/langchain/agents/self_ask_with_search/prompt.py b/libs/langchain/langchain/agents/self_ask_with_search/prompt.py index c9154785cd9..262e66a881c 100644 --- a/libs/langchain/langchain/agents/self_ask_with_search/prompt.py +++ b/libs/langchain/langchain/agents/self_ask_with_search/prompt.py @@ -1,4 +1,3 @@ -# flake8: noqa from langchain_core.prompts.prompt import PromptTemplate _DEFAULT_TEMPLATE = """Question: Who lived longer, Muhammad Ali or Alan Turing? 
diff --git a/libs/langchain/langchain/agents/structured_chat/base.py b/libs/langchain/langchain/agents/structured_chat/base.py index bc1464dbc5e..8a374faa242 100644 --- a/libs/langchain/langchain/agents/structured_chat/base.py +++ b/libs/langchain/langchain/agents/structured_chat/base.py @@ -35,7 +35,7 @@ class StructuredChatAgent(Agent): """Structured Chat Agent.""" output_parser: AgentOutputParser = Field( - default_factory=StructuredChatOutputParserWithRetries + default_factory=StructuredChatOutputParserWithRetries, ) """Output parser for the agent.""" @@ -50,7 +50,8 @@ class StructuredChatAgent(Agent): return "Thought:" def _construct_scratchpad( - self, intermediate_steps: list[tuple[AgentAction, str]] + self, + intermediate_steps: list[tuple[AgentAction, str]], ) -> str: agent_scratchpad = super()._construct_scratchpad(intermediate_steps) if not isinstance(agent_scratchpad, str): @@ -70,7 +71,9 @@ class StructuredChatAgent(Agent): @classmethod def _get_default_output_parser( - cls, llm: Optional[BaseLanguageModel] = None, **kwargs: Any + cls, + llm: Optional[BaseLanguageModel] = None, + **kwargs: Any, ) -> AgentOutputParser: return StructuredChatOutputParserWithRetries.from_llm(llm=llm) @@ -96,7 +99,7 @@ class StructuredChatAgent(Agent): formatted_tools = "\n".join(tool_strings) tool_names = ", ".join([tool.name for tool in tools]) format_instructions = format_instructions.format(tool_names=tool_names) - template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix]) + template = f"{prefix}\n\n{formatted_tools}\n\n{format_instructions}\n\n{suffix}" if input_variables is None: input_variables = ["input", "agent_scratchpad"] _memory_prompts = memory_prompts or [] @@ -275,7 +278,7 @@ def create_structured_chat_agent( ) """ # noqa: E501 missing_vars = {"tools", "tool_names", "agent_scratchpad"}.difference( - prompt.input_variables + list(prompt.partial_variables) + prompt.input_variables + list(prompt.partial_variables), ) if missing_vars: msg = f"Prompt missing required variables: {missing_vars}" diff --git a/libs/langchain/langchain/agents/structured_chat/output_parser.py b/libs/langchain/langchain/agents/structured_chat/output_parser.py index 0e6e072c970..062c61e7c81 100644 --- a/libs/langchain/langchain/agents/structured_chat/output_parser.py +++ b/libs/langchain/langchain/agents/structured_chat/output_parser.py @@ -43,7 +43,9 @@ class StructuredChatOutputParser(AgentOutputParser): if response["action"] == "Final Answer": return AgentFinish({"output": response["action_input"]}, text) return AgentAction( - response["action"], response.get("action_input", {}), text + response["action"], + response.get("action_input", {}), + text, ) return AgentFinish({"output": text}, text) except Exception as e: @@ -88,7 +90,8 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser): if llm is not None: base_parser = base_parser or StructuredChatOutputParser() output_fixing_parser: OutputFixingParser = OutputFixingParser.from_llm( - llm=llm, parser=base_parser + llm=llm, + parser=base_parser, ) return cls(output_fixing_parser=output_fixing_parser) if base_parser is not None: diff --git a/libs/langchain/langchain/agents/structured_chat/prompt.py b/libs/langchain/langchain/agents/structured_chat/prompt.py index 98d8bb37331..4400c56cc20 100644 --- a/libs/langchain/langchain/agents/structured_chat/prompt.py +++ b/libs/langchain/langchain/agents/structured_chat/prompt.py @@ -1,5 +1,4 @@ -# flake8: noqa -PREFIX = """Respond to the human as helpfully and accurately as possible. 
You have access to the following tools:""" +PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:""" # noqa: E501 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). Valid "action" values: "Final Answer" or {tool_names} @@ -30,6 +29,6 @@ Action: "action": "Final Answer", "action_input": "Final response to human" }}}} -```""" +```""" # noqa: E501 SUFFIX = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. -Thought:""" +Thought:""" # noqa: E501 diff --git a/libs/langchain/langchain/agents/tool_calling_agent/base.py b/libs/langchain/langchain/agents/tool_calling_agent/base.py index fd412fea703..4e22099c4b0 100644 --- a/libs/langchain/langchain/agents/tool_calling_agent/base.py +++ b/libs/langchain/langchain/agents/tool_calling_agent/base.py @@ -87,7 +87,7 @@ def create_tool_calling_agent( messages will be passed in here. """ missing_vars = {"agent_scratchpad"}.difference( - prompt.input_variables + list(prompt.partial_variables) + prompt.input_variables + list(prompt.partial_variables), ) if missing_vars: msg = f"Prompt missing required variables: {missing_vars}" @@ -102,7 +102,7 @@ def create_tool_calling_agent( return ( RunnablePassthrough.assign( - agent_scratchpad=lambda x: message_formatter(x["intermediate_steps"]) + agent_scratchpad=lambda x: message_formatter(x["intermediate_steps"]), ) | prompt | llm_with_tools diff --git a/libs/langchain/langchain/agents/xml/base.py b/libs/langchain/langchain/agents/xml/base.py index 347d474dfc2..d6a77c9dee7 100644 --- a/libs/langchain/langchain/agents/xml/base.py +++ b/libs/langchain/langchain/agents/xml/base.py @@ -52,7 +52,7 @@ class XMLAgent(BaseSingleActionAgent): def get_default_prompt() -> ChatPromptTemplate: base_prompt = ChatPromptTemplate.from_template(agent_instructions) return base_prompt + AIMessagePromptTemplate.from_template( - "{intermediate_steps}" + "{intermediate_steps}", ) @staticmethod @@ -205,7 +205,7 @@ def create_xml_agent( prompt = PromptTemplate.from_template(template) """ # noqa: E501 missing_vars = {"tools", "agent_scratchpad"}.difference( - prompt.input_variables + list(prompt.partial_variables) + prompt.input_variables + list(prompt.partial_variables), ) if missing_vars: msg = f"Prompt missing required variables: {missing_vars}" diff --git a/libs/langchain/langchain/agents/xml/prompt.py b/libs/langchain/langchain/agents/xml/prompt.py index 3972c6a07ad..31afe11a508 100644 --- a/libs/langchain/langchain/agents/xml/prompt.py +++ b/libs/langchain/langchain/agents/xml/prompt.py @@ -1,4 +1,3 @@ -# flake8: noqa # TODO: deprecate agent_instructions = """You are a helpful assistant. Help the user answer any questions. 
@@ -19,4 +18,4 @@ When you are done, respond with a final answer between None: # If two calls are made in a row, this resets the state self.done.clear() diff --git a/libs/langchain/langchain/callbacks/streaming_aiter_final_only.py b/libs/langchain/langchain/callbacks/streaming_aiter_final_only.py index db46a3e146b..b9f91788e3f 100644 --- a/libs/langchain/langchain/callbacks/streaming_aiter_final_only.py +++ b/libs/langchain/langchain/callbacks/streaming_aiter_final_only.py @@ -61,7 +61,10 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler): self.answer_reached = False async def on_llm_start( - self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any + self, + serialized: dict[str, Any], + prompts: list[str], + **kwargs: Any, ) -> None: # If two calls are made in a row, this resets the state self.done.clear() diff --git a/libs/langchain/langchain/callbacks/streaming_stdout_final_only.py b/libs/langchain/langchain/callbacks/streaming_stdout_final_only.py index 674e5579175..5dbbf700e8e 100644 --- a/libs/langchain/langchain/callbacks/streaming_stdout_final_only.py +++ b/libs/langchain/langchain/callbacks/streaming_stdout_final_only.py @@ -62,7 +62,10 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler): self.answer_reached = False def on_llm_start( - self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any + self, + serialized: dict[str, Any], + prompts: list[str], + **kwargs: Any, ) -> None: """Run when LLM starts running.""" self.answer_reached = False diff --git a/libs/langchain/langchain/callbacks/tracers/logging.py b/libs/langchain/langchain/callbacks/tracers/logging.py index a6bad951a9d..35d075dd0d3 100644 --- a/libs/langchain/langchain/callbacks/tracers/logging.py +++ b/libs/langchain/langchain/callbacks/tracers/logging.py @@ -42,5 +42,5 @@ class LoggingCallbackHandler(FunctionCallbackHandler): crumbs_str = "" self.function_callback( f"{get_colored_text('[text]', color='blue')}" - f" {get_bolded_text(f'{crumbs_str}New text:')}\n{text}" + f" {get_bolded_text(f'{crumbs_str}New text:')}\n{text}", ) diff --git a/libs/langchain/langchain/callbacks/trubrics_callback.py b/libs/langchain/langchain/callbacks/trubrics_callback.py index 783bb8961ef..94812850a3f 100644 --- a/libs/langchain/langchain/callbacks/trubrics_callback.py +++ b/libs/langchain/langchain/callbacks/trubrics_callback.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "TrubricsCallbackHandler": "langchain_community.callbacks.trubrics_callback" + "TrubricsCallbackHandler": "langchain_community.callbacks.trubrics_callback", } _import_attribute = create_importer(__file__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/callbacks/wandb_callback.py b/libs/langchain/langchain/callbacks/wandb_callback.py index 81d06e16407..662e63450f5 100644 --- a/libs/langchain/langchain/callbacks/wandb_callback.py +++ b/libs/langchain/langchain/callbacks/wandb_callback.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. 
DEPRECATED_LOOKUP = { - "WandbCallbackHandler": "langchain_community.callbacks.wandb_callback" + "WandbCallbackHandler": "langchain_community.callbacks.wandb_callback", } _import_attribute = create_importer(__file__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/callbacks/whylabs_callback.py b/libs/langchain/langchain/callbacks/whylabs_callback.py index 9839fa39782..34fd616743e 100644 --- a/libs/langchain/langchain/callbacks/whylabs_callback.py +++ b/libs/langchain/langchain/callbacks/whylabs_callback.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "WhyLabsCallbackHandler": "langchain_community.callbacks.whylabs_callback" + "WhyLabsCallbackHandler": "langchain_community.callbacks.whylabs_callback", } _import_attribute = create_importer(__file__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/chains/api/base.py b/libs/langchain/langchain/chains/api/base.py index 7fe4b9ce81a..df2f9623e90 100644 --- a/libs/langchain/langchain/chains/api/base.py +++ b/libs/langchain/langchain/chains/api/base.py @@ -287,7 +287,8 @@ try: _run_manager.on_text(api_url, color="green", end="\n", verbose=self.verbose) api_url = api_url.strip() if self.limit_to_domains and not _check_in_allowed_domain( - api_url, self.limit_to_domains + api_url, + self.limit_to_domains, ): msg = ( f"{api_url} is not in the allowed domains: {self.limit_to_domains}" @@ -295,7 +296,10 @@ try: raise ValueError(msg) api_response = self.requests_wrapper.get(api_url) _run_manager.on_text( - str(api_response), color="yellow", end="\n", verbose=self.verbose + str(api_response), + color="yellow", + end="\n", + verbose=self.verbose, ) answer = self.api_answer_chain.predict( question=question, @@ -321,11 +325,15 @@ try: callbacks=_run_manager.get_child(), ) await _run_manager.on_text( - api_url, color="green", end="\n", verbose=self.verbose + api_url, + color="green", + end="\n", + verbose=self.verbose, ) api_url = api_url.strip() if self.limit_to_domains and not _check_in_allowed_domain( - api_url, self.limit_to_domains + api_url, + self.limit_to_domains, ): msg = ( f"{api_url} is not in the allowed domains: {self.limit_to_domains}" @@ -333,7 +341,10 @@ try: raise ValueError(msg) api_response = await self.requests_wrapper.aget(api_url) await _run_manager.on_text( - str(api_response), color="yellow", end="\n", verbose=self.verbose + str(api_response), + color="yellow", + end="\n", + verbose=self.verbose, ) answer = await self.api_answer_chain.apredict( question=question, diff --git a/libs/langchain/langchain/chains/api/news_docs.py b/libs/langchain/langchain/chains/api/news_docs.py index b29f16c10db..7d0062403aa 100644 --- a/libs/langchain/langchain/chains/api/news_docs.py +++ b/libs/langchain/langchain/chains/api/news_docs.py @@ -1,4 +1,3 @@ -# flake8: noqa NEWS_DOCS = """API documentation: Endpoint: https://newsapi.org Top headlines /v2/top-headlines @@ -29,4 +28,4 @@ Response object content | string | The unformatted content of the article, where available. This is truncated to 200 chars. 
Use page size: 2 -""" +""" # noqa: E501 diff --git a/libs/langchain/langchain/chains/api/open_meteo_docs.py b/libs/langchain/langchain/chains/api/open_meteo_docs.py index 4abd86fb83a..97c3ce3534b 100644 --- a/libs/langchain/langchain/chains/api/open_meteo_docs.py +++ b/libs/langchain/langchain/chains/api/open_meteo_docs.py @@ -1,4 +1,3 @@ -# flake8: noqa OPEN_METEO_DOCS = """BASE URL: https://api.open-meteo.com/ API Documentation @@ -30,4 +29,4 @@ showers Preceding hour sum mm (inch) Showers from convective precipitation in mi weathercode Instant WMO code Weather condition as a numeric code. Follow WMO weather interpretation codes. See table below for details. snow_depth Instant meters Snow depth on the ground freezinglevel_height Instant meters Altitude above sea level of the 0°C level -visibility Instant meters Viewing distance in meters. Influenced by low clouds, humidity and aerosols. Maximum visibility is approximately 24 km.""" +visibility Instant meters Viewing distance in meters. Influenced by low clouds, humidity and aerosols. Maximum visibility is approximately 24 km.""" # noqa: E501 diff --git a/libs/langchain/langchain/chains/api/podcast_docs.py b/libs/langchain/langchain/chains/api/podcast_docs.py index 9c4e5cbf827..a2b825f73cd 100644 --- a/libs/langchain/langchain/chains/api/podcast_docs.py +++ b/libs/langchain/langchain/chains/api/podcast_docs.py @@ -1,4 +1,3 @@ -# flake8: noqa PODCAST_DOCS = """API documentation: Endpoint: https://listen-api.listennotes.com/api/v2 GET /search @@ -25,4 +24,4 @@ id | integer | optional title_highlighted | string | optional Use page_size: 3 -""" +""" # noqa: E501 diff --git a/libs/langchain/langchain/chains/api/prompt.py b/libs/langchain/langchain/chains/api/prompt.py index 0ffc389ad3d..f3ba53ce933 100644 --- a/libs/langchain/langchain/chains/api/prompt.py +++ b/libs/langchain/langchain/chains/api/prompt.py @@ -1,4 +1,3 @@ -# flake8: noqa from langchain_core.prompts.prompt import PromptTemplate API_URL_PROMPT_TEMPLATE = """You are given the below API Documentation: @@ -7,7 +6,7 @@ Using this documentation, generate the full API url to call for answering the us You should build the API url in order to get a response that is as short as possible, while still getting the necessary information to answer the question. Pay attention to deliberately exclude any unnecessary pieces of data in the API call. 
Question:{question} -API url:""" +API url:""" # noqa: E501 API_URL_PROMPT = PromptTemplate( input_variables=[ diff --git a/libs/langchain/langchain/chains/api/tmdb_docs.py b/libs/langchain/langchain/chains/api/tmdb_docs.py index 4634a80ae0b..c084bc7fc20 100644 --- a/libs/langchain/langchain/chains/api/tmdb_docs.py +++ b/libs/langchain/langchain/chains/api/tmdb_docs.py @@ -1,4 +1,3 @@ -# flake8: noqa TMDB_DOCS = """API documentation: Endpoint: https://api.themoviedb.org/3 GET /search/movie @@ -34,4 +33,4 @@ backdrop_path | string or null | optional popularity | number | optional vote_count | integer | optional video | boolean | optional -vote_average | number | optional""" +vote_average | number | optional""" # noqa: E501 diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index 9052b7dd6a9..4ab2064c1d4 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -109,17 +109,20 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC): ) def get_input_schema( - self, config: Optional[RunnableConfig] = None + self, + config: Optional[RunnableConfig] = None, ) -> type[BaseModel]: # This is correct, but pydantic typings/mypy don't think so. return create_model("ChainInput", **dict.fromkeys(self.input_keys, (Any, None))) def get_output_schema( - self, config: Optional[RunnableConfig] = None + self, + config: Optional[RunnableConfig] = None, ) -> type[BaseModel]: # This is correct, but pydantic typings/mypy don't think so. return create_model( - "ChainOutput", **dict.fromkeys(self.output_keys, (Any, None)) + "ChainOutput", + **dict.fromkeys(self.output_keys, (Any, None)), ) @override @@ -165,7 +168,9 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC): ) final_outputs: dict[str, Any] = self.prep_outputs( - inputs, outputs, return_only_outputs + inputs, + outputs, + return_only_outputs, ) except BaseException as e: run_manager.on_chain_error(e) @@ -217,7 +222,9 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC): else await self._acall(inputs) ) final_outputs: dict[str, Any] = await self.aprep_outputs( - inputs, outputs, return_only_outputs + inputs, + outputs, + return_only_outputs, ) except BaseException as e: await run_manager.on_chain_error(e) @@ -349,7 +356,10 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC): `Chain.output_keys`. 
""" return await run_in_executor( - None, self._call, inputs, run_manager.get_sync() if run_manager else None + None, + self._call, + inputs, + run_manager.get_sync() if run_manager else None, ) @deprecated("0.1.0", alternative="invoke", removal="1.0") @@ -692,14 +702,20 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC): raise ValueError(msg) return ( await self.acall( - args[0], callbacks=callbacks, tags=tags, metadata=metadata + args[0], + callbacks=callbacks, + tags=tags, + metadata=metadata, ) )[self.output_keys[0]] if kwargs and not args: return ( await self.acall( - kwargs, callbacks=callbacks, tags=tags, metadata=metadata + kwargs, + callbacks=callbacks, + tags=tags, + metadata=metadata, ) )[self.output_keys[0]] @@ -775,7 +791,9 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC): @deprecated("0.1.0", alternative="batch", removal="1.0") def apply( - self, input_list: list[builtins.dict[str, Any]], callbacks: Callbacks = None + self, + input_list: list[builtins.dict[str, Any]], + callbacks: Callbacks = None, ) -> list[builtins.dict[str, str]]: """Call the chain on all inputs in the list.""" return [self(inputs, callbacks=callbacks) for inputs in input_list] diff --git a/libs/langchain/langchain/chains/chat_vector_db/prompts.py b/libs/langchain/langchain/chains/chat_vector_db/prompts.py index 19f7a210386..89689ebab74 100644 --- a/libs/langchain/langchain/chains/chat_vector_db/prompts.py +++ b/libs/langchain/langchain/chains/chat_vector_db/prompts.py @@ -1,4 +1,3 @@ -# flake8: noqa from langchain_core.prompts.prompt import PromptTemplate _template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question. @@ -6,7 +5,7 @@ _template = """Given the following conversation and a follow up question, rephra Chat History: {chat_history} Follow Up Input: {question} -Standalone question:""" +Standalone question:""" # noqa: E501 CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template) prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. 
@@ -14,7 +13,7 @@ prompt_template = """Use the following pieces of context to answer the question {context} Question: {question} -Helpful Answer:""" +Helpful Answer:""" # noqa: E501 QA_PROMPT = PromptTemplate( template=prompt_template, input_variables=["context", "question"] ) diff --git a/libs/langchain/langchain/chains/combine_documents/base.py b/libs/langchain/langchain/chains/combine_documents/base.py index d8f00a099bd..5c24b64f1ab 100644 --- a/libs/langchain/langchain/chains/combine_documents/base.py +++ b/libs/langchain/langchain/chains/combine_documents/base.py @@ -47,7 +47,8 @@ class BaseCombineDocumentsChain(Chain, ABC): output_key: str = "output_text" #: :meta private: def get_input_schema( - self, config: Optional[RunnableConfig] = None + self, + config: Optional[RunnableConfig] = None, ) -> type[BaseModel]: return create_model( "CombineDocumentsInput", @@ -55,7 +56,8 @@ class BaseCombineDocumentsChain(Chain, ABC): ) def get_output_schema( - self, config: Optional[RunnableConfig] = None + self, + config: Optional[RunnableConfig] = None, ) -> type[BaseModel]: return create_model( "CombineDocumentsOutput", @@ -112,7 +114,9 @@ class BaseCombineDocumentsChain(Chain, ABC): @abstractmethod async def acombine_docs( - self, docs: list[Document], **kwargs: Any + self, + docs: list[Document], + **kwargs: Any, ) -> tuple[str, dict]: """Combine documents into a single string. @@ -137,7 +141,9 @@ class BaseCombineDocumentsChain(Chain, ABC): # Other keys are assumed to be needed for LLM prediction other_keys = {k: v for k, v in inputs.items() if k != self.input_key} output, extra_return_dict = self.combine_docs( - docs, callbacks=_run_manager.get_child(), **other_keys + docs, + callbacks=_run_manager.get_child(), + **other_keys, ) extra_return_dict[self.output_key] = output return extra_return_dict @@ -153,7 +159,9 @@ class BaseCombineDocumentsChain(Chain, ABC): # Other keys are assumed to be needed for LLM prediction other_keys = {k: v for k, v in inputs.items() if k != self.input_key} output, extra_return_dict = await self.acombine_docs( - docs, callbacks=_run_manager.get_child(), **other_keys + docs, + callbacks=_run_manager.get_child(), + **other_keys, ) extra_return_dict[self.output_key] = output return extra_return_dict @@ -246,7 +254,8 @@ class AnalyzeDocumentChain(Chain): return self.combine_docs_chain.output_keys def get_input_schema( - self, config: Optional[RunnableConfig] = None + self, + config: Optional[RunnableConfig] = None, ) -> type[BaseModel]: return create_model( "AnalyzeDocumentChain", @@ -254,7 +263,8 @@ class AnalyzeDocumentChain(Chain): ) def get_output_schema( - self, config: Optional[RunnableConfig] = None + self, + config: Optional[RunnableConfig] = None, ) -> type[BaseModel]: return self.combine_docs_chain.get_output_schema(config) @@ -271,5 +281,7 @@ class AnalyzeDocumentChain(Chain): other_keys: dict = {k: v for k, v in inputs.items() if k != self.input_key} other_keys[self.combine_docs_chain.input_key] = docs return self.combine_docs_chain( - other_keys, return_only_outputs=True, callbacks=_run_manager.get_child() + other_keys, + return_only_outputs=True, + callbacks=_run_manager.get_child(), ) diff --git a/libs/langchain/langchain/chains/combine_documents/map_reduce.py b/libs/langchain/langchain/chains/combine_documents/map_reduce.py index 9f8d3dbcc2e..8bd373ebd82 100644 --- a/libs/langchain/langchain/chains/combine_documents/map_reduce.py +++ b/libs/langchain/langchain/chains/combine_documents/map_reduce.py @@ -112,7 +112,8 @@ class 
MapReduceDocumentsChain(BaseCombineDocumentsChain): """Return the results of the map steps in the output.""" def get_output_schema( - self, config: Optional[RunnableConfig] = None + self, + config: Optional[RunnableConfig] = None, ) -> type[BaseModel]: if self.return_intermediate_steps: return create_model( @@ -251,7 +252,10 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain): for i, r in enumerate(map_results) ] result, extra_return_dict = self.reduce_documents_chain.combine_docs( - result_docs, token_max=token_max, callbacks=callbacks, **kwargs + result_docs, + token_max=token_max, + callbacks=callbacks, + **kwargs, ) if self.return_intermediate_steps: intermediate_steps = [r[question_result_key] for r in map_results] @@ -282,7 +286,10 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain): for i, r in enumerate(map_results) ] result, extra_return_dict = await self.reduce_documents_chain.acombine_docs( - result_docs, token_max=token_max, callbacks=callbacks, **kwargs + result_docs, + token_max=token_max, + callbacks=callbacks, + **kwargs, ) if self.return_intermediate_steps: intermediate_steps = [r[question_result_key] for r in map_results] diff --git a/libs/langchain/langchain/chains/combine_documents/map_rerank.py b/libs/langchain/langchain/chains/combine_documents/map_rerank.py index 632bcf2fe9d..b83007ad25b 100644 --- a/libs/langchain/langchain/chains/combine_documents/map_rerank.py +++ b/libs/langchain/langchain/chains/combine_documents/map_rerank.py @@ -92,7 +92,8 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain): ) def get_output_schema( - self, config: Optional[RunnableConfig] = None + self, + config: Optional[RunnableConfig] = None, ) -> type[BaseModel]: schema: dict[str, Any] = { self.output_key: (str, None), @@ -170,7 +171,10 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain): return values def combine_docs( - self, docs: list[Document], callbacks: Callbacks = None, **kwargs: Any + self, + docs: list[Document], + callbacks: Callbacks = None, + **kwargs: Any, ) -> tuple[str, dict]: """Combine documents in a map rerank manner. @@ -194,7 +198,10 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain): return self._process_results(docs, results) async def acombine_docs( - self, docs: list[Document], callbacks: Callbacks = None, **kwargs: Any + self, + docs: list[Document], + callbacks: Callbacks = None, + **kwargs: Any, ) -> tuple[str, dict]: """Combine documents in a map rerank manner. @@ -224,7 +231,8 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain): ) -> tuple[str, dict]: typed_results = cast(list[dict], results) sorted_res = sorted( - zip(typed_results, docs), key=lambda x: -int(x[0][self.rank_key]) + zip(typed_results, docs), + key=lambda x: -int(x[0][self.rank_key]), ) output, document = sorted_res[0] extra_info = {} diff --git a/libs/langchain/langchain/chains/combine_documents/reduce.py b/libs/langchain/langchain/chains/combine_documents/reduce.py index fab6b53eea1..eb8648685bb 100644 --- a/libs/langchain/langchain/chains/combine_documents/reduce.py +++ b/libs/langchain/langchain/chains/combine_documents/reduce.py @@ -27,7 +27,10 @@ class AsyncCombineDocsProtocol(Protocol): def split_list_of_docs( - docs: list[Document], length_func: Callable, token_max: int, **kwargs: Any + docs: list[Document], + length_func: Callable, + token_max: int, + **kwargs: Any, ) -> list[list[Document]]: """Split Documents into subsets that each meet a cumulative length constraint. 
@@ -250,10 +253,15 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain): element returned is a dictionary of other keys to return. """ result_docs, extra_return_dict = self._collapse( - docs, token_max=token_max, callbacks=callbacks, **kwargs + docs, + token_max=token_max, + callbacks=callbacks, + **kwargs, ) return self.combine_documents_chain.combine_docs( - docs=result_docs, callbacks=callbacks, **kwargs + docs=result_docs, + callbacks=callbacks, + **kwargs, ) async def acombine_docs( @@ -279,10 +287,15 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain): element returned is a dictionary of other keys to return. """ result_docs, extra_return_dict = await self._acollapse( - docs, token_max=token_max, callbacks=callbacks, **kwargs + docs, + token_max=token_max, + callbacks=callbacks, + **kwargs, ) return await self.combine_documents_chain.acombine_docs( - docs=result_docs, callbacks=callbacks, **kwargs + docs=result_docs, + callbacks=callbacks, + **kwargs, ) def _collapse( @@ -298,14 +311,19 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain): def _collapse_docs_func(docs: list[Document], **kwargs: Any) -> str: return self._collapse_chain.run( - input_documents=docs, callbacks=callbacks, **kwargs + input_documents=docs, + callbacks=callbacks, + **kwargs, ) _token_max = token_max or self.token_max retries: int = 0 while num_tokens is not None and num_tokens > _token_max: new_result_doc_list = split_list_of_docs( - result_docs, length_func, _token_max, **kwargs + result_docs, + length_func, + _token_max, + **kwargs, ) result_docs = [] for docs in new_result_doc_list: @@ -332,14 +350,19 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain): async def _collapse_docs_func(docs: list[Document], **kwargs: Any) -> str: return await self._collapse_chain.arun( - input_documents=docs, callbacks=callbacks, **kwargs + input_documents=docs, + callbacks=callbacks, + **kwargs, ) _token_max = token_max or self.token_max retries: int = 0 while num_tokens is not None and num_tokens > _token_max: new_result_doc_list = split_list_of_docs( - result_docs, length_func, _token_max, **kwargs + result_docs, + length_func, + _token_max, + **kwargs, ) result_docs = [] for docs in new_result_doc_list: diff --git a/libs/langchain/langchain/chains/combine_documents/refine.py b/libs/langchain/langchain/chains/combine_documents/refine.py index 073d0173009..b904b909191 100644 --- a/libs/langchain/langchain/chains/combine_documents/refine.py +++ b/libs/langchain/langchain/chains/combine_documents/refine.py @@ -91,7 +91,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain): initial_response_name: str """The variable name to format the initial response in when refining.""" document_prompt: BasePromptTemplate = Field( - default_factory=_get_default_document_prompt + default_factory=_get_default_document_prompt, ) """Prompt to use to format each document, gets passed to `format_document`.""" return_intermediate_steps: bool = False @@ -150,7 +150,10 @@ class RefineDocumentsChain(BaseCombineDocumentsChain): return values def combine_docs( - self, docs: list[Document], callbacks: Callbacks = None, **kwargs: Any + self, + docs: list[Document], + callbacks: Callbacks = None, + **kwargs: Any, ) -> tuple[str, dict]: """Combine by mapping first chain over all, then stuffing into final chain. 
@@ -175,7 +178,10 @@ class RefineDocumentsChain(BaseCombineDocumentsChain): return self._construct_result(refine_steps, res) async def acombine_docs( - self, docs: list[Document], callbacks: Callbacks = None, **kwargs: Any + self, + docs: list[Document], + callbacks: Callbacks = None, + **kwargs: Any, ) -> tuple[str, dict]: """Async combine by mapping a first chain over all, then stuffing into a final chain. @@ -214,13 +220,15 @@ class RefineDocumentsChain(BaseCombineDocumentsChain): } def _construct_initial_inputs( - self, docs: list[Document], **kwargs: Any + self, + docs: list[Document], + **kwargs: Any, ) -> dict[str, Any]: base_info = {"page_content": docs[0].page_content} base_info.update(docs[0].metadata) document_info = {k: base_info[k] for k in self.document_prompt.input_variables} base_inputs: dict = { - self.document_variable_name: self.document_prompt.format(**document_info) + self.document_variable_name: self.document_prompt.format(**document_info), } return {**base_inputs, **kwargs} diff --git a/libs/langchain/langchain/chains/combine_documents/stuff.py b/libs/langchain/langchain/chains/combine_documents/stuff.py index e862b35594f..138629ab143 100644 --- a/libs/langchain/langchain/chains/combine_documents/stuff.py +++ b/libs/langchain/langchain/chains/combine_documents/stuff.py @@ -88,7 +88,7 @@ def create_stuff_documents_chain( return ( RunnablePassthrough.assign(**{document_variable_name: format_docs}).with_config( - run_name="format_inputs" + run_name="format_inputs", ) | prompt | llm @@ -147,7 +147,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain): """LLM chain which is called with the formatted document string, along with any other inputs.""" document_prompt: BasePromptTemplate = Field( - default_factory=lambda: DEFAULT_DOCUMENT_PROMPT + default_factory=lambda: DEFAULT_DOCUMENT_PROMPT, ) """Prompt to use to format each document, gets passed to `format_document`.""" document_variable_name: str @@ -243,7 +243,10 @@ class StuffDocumentsChain(BaseCombineDocumentsChain): return self.llm_chain._get_num_tokens(prompt) def combine_docs( - self, docs: list[Document], callbacks: Callbacks = None, **kwargs: Any + self, + docs: list[Document], + callbacks: Callbacks = None, + **kwargs: Any, ) -> tuple[str, dict]: """Stuff all documents into one prompt and pass to LLM. @@ -261,7 +264,10 @@ class StuffDocumentsChain(BaseCombineDocumentsChain): return self.llm_chain.predict(callbacks=callbacks, **inputs), {} async def acombine_docs( - self, docs: list[Document], callbacks: Callbacks = None, **kwargs: Any + self, + docs: list[Document], + callbacks: Callbacks = None, + **kwargs: Any, ) -> tuple[str, dict]: """Async stuff all documents into one prompt and pass to LLM. 
diff --git a/libs/langchain/langchain/chains/constitutional_ai/base.py b/libs/langchain/langchain/chains/constitutional_ai/base.py index 9e8a0b22cd4..bdaddaa7cbf 100644 --- a/libs/langchain/langchain/chains/constitutional_ai/base.py +++ b/libs/langchain/langchain/chains/constitutional_ai/base.py @@ -197,7 +197,8 @@ class ConstitutionalChain(Chain): @classmethod def get_principles( - cls, names: Optional[list[str]] = None + cls, + names: Optional[list[str]] = None, ) -> list[ConstitutionalPrinciple]: if names is None: return list(PRINCIPLES.values()) diff --git a/libs/langchain/langchain/chains/constitutional_ai/principles.py b/libs/langchain/langchain/chains/constitutional_ai/principles.py index ab61a3fd186..3e098786aac 100644 --- a/libs/langchain/langchain/chains/constitutional_ai/principles.py +++ b/libs/langchain/langchain/chains/constitutional_ai/principles.py @@ -2,90 +2,87 @@ UnifiedObjectives v0.2 principles ("uo-*") adapted from https://examine.dev/docs/Unified_objectives.pdf (Samwald et al. 2023) """ -# flake8: noqa -from typing import Dict - from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple -PRINCIPLES: Dict[str, ConstitutionalPrinciple] = { +PRINCIPLES: dict[str, ConstitutionalPrinciple] = { "harmful1": ConstitutionalPrinciple( name="harmful1", - critique_request="Identify specific ways in which the assistant’s last response is harmful, unethical, racist, sexist, toxic, dangerous, or illegal.", + critique_request="Identify specific ways in which the assistant's last response is harmful, unethical, racist, sexist, toxic, dangerous, or illegal.", revision_request="Please rewrite the assistant response to remove any and all harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.", ), "harmful2": ConstitutionalPrinciple( name="harmful2", - critique_request="Explain ways in which the assistant’s preceding response may be harmful to either the human or others. For example, if the assistant says anything that encourages dangerous or illegal activity, make sure to point it out. Other examples of harm include racism, sexism, and other social biases.", + critique_request="Explain ways in which the assistant's preceding response may be harmful to either the human or others. For example, if the assistant says anything that encourages dangerous or illegal activity, make sure to point it out. Other examples of harm include racism, sexism, and other social biases.", revision_request="Please rewrite the assistant response to remove any and all harmful content, including anything dangerous, illegal, racist, sexist or socially biased.", ), "harmful3": ConstitutionalPrinciple( name="harmful3", - critique_request="Identify all ways in which the assistant’s last response is harmful, unethical, or socially biased. Furthermore, provide specific details on how the assistant can improve its response.", + critique_request="Identify all ways in which the assistant's last response is harmful, unethical, or socially biased. Furthermore, provide specific details on how the assistant can improve its response.", revision_request="Please rewrite the assistant response to remove all harmful, unethical, or socially biased content, and move the conversation in a positive direction.", ), "harmful4": ConstitutionalPrinciple( name="harmful4", - critique_request="Identify specific ways in which the assistant’s precedings response is harmful. Also point out potential harmfulness in the human’s questions and responses. 
Examples of harm include but are not limited to racism, sexism, toxicity.", + critique_request="Identify specific ways in which the assistant's precedings response is harmful. Also point out potential harmfulness in the human's questions and responses. Examples of harm include but are not limited to racism, sexism, toxicity.", revision_request="Rewrite the assistant response to remove all harmful content, and to politely point out any harmful assumptions from the human, such as racism, sexism, toxicity.", ), "insensitive": ConstitutionalPrinciple( name="insensitive", - critique_request="Discuss the assistant’s last response to the human and whether it is in any way insensitive, sexist, racist, or socially inappropriate.", - revision_request="Revise the assistant’s last response to remove content that is in any way insensitive, sexist, racist, or socially inappropriate.", + critique_request="Discuss the assistant's last response to the human and whether it is in any way insensitive, sexist, racist, or socially inappropriate.", + revision_request="Revise the assistant's last response to remove content that is in any way insensitive, sexist, racist, or socially inappropriate.", ), "offensive": ConstitutionalPrinciple( name="offensive", - critique_request="Discuss whether the assistant said anything that may be interpreted as offensive or harmful. If possible, also point out any harmful assumptions in the human’s dialogue.", - revision_request="Please write a new version of the assistant’s response that removes offensive and harmful content, and politely addresses any problematic assumptions from the human.", + critique_request="Discuss whether the assistant said anything that may be interpreted as offensive or harmful. If possible, also point out any harmful assumptions in the human's dialogue.", + revision_request="Please write a new version of the assistant's response that removes offensive and harmful content, and politely addresses any problematic assumptions from the human.", ), "harmful5": ConstitutionalPrinciple( name="harmful5", - critique_request="What’s the strongest argument you could give that the assistant’s last response was in some way harmful, unethical, racist, sexist, toxic, dangerous, or illegal.", - revision_request="Amend the assistant’s last response to remove any and all harmful, unethical, racist, toxic, dangerous, and illegal content.", + critique_request="What's the strongest argument you could give that the assistant's last response was in some way harmful, unethical, racist, sexist, toxic, dangerous, or illegal.", + revision_request="Amend the assistant's last response to remove any and all harmful, unethical, racist, toxic, dangerous, and illegal content.", ), "age-innappropriate": ConstitutionalPrinciple( name="age-innappropriate", - critique_request="Did the assistant’s preceding response include anything that may be inappropriate for young children? If so, discuss how it can be revised to be more appropriate.", + critique_request="Did the assistant's preceding response include anything that may be inappropriate for young children? 
If so, discuss how it can be revised to be more appropriate.", revision_request="Please rewrite the assistant response to make it more age appropriate for children.", ), "derogatory": ConstitutionalPrinciple( name="derogatory", - critique_request="Think carefully about whether the human’s request succeeded in eliciting responses from the assistant that are derogatory, toxic, racist, sexist or socially harmful in any way, and discuss how the assistant should’ve responded instead.", - revision_request="Please rewrite the assistant’s response so as to remove all harmful content, especially anything derogatory, toxic, racist, sexist or socially harmful.", + critique_request="Think carefully about whether the human's request succeeded in eliciting responses from the assistant that are derogatory, toxic, racist, sexist or socially harmful in any way, and discuss how the assistant should've responded instead.", + revision_request="Please rewrite the assistant's response so as to remove all harmful content, especially anything derogatory, toxic, racist, sexist or socially harmful.", ), "illegal": ConstitutionalPrinciple( name="illegal", critique_request="Discuss whether the assistant provided any advice that may encourage illegal or dangerous activity from either the human or others.", - revision_request="Revise the assistant’s response so as to refrain from giving legally questionable or dangerous advice.", + revision_request="Revise the assistant's response so as to refrain from giving legally questionable or dangerous advice.", ), "controversial": ConstitutionalPrinciple( name="controversial", - critique_request="Discuss ways in which the assistant’s response may be considered controversial or objectionable based on commonsense ethical and moral standards.", - revision_request="Revise the assistant’s response if possible so that it’s less controversial and objectionable based on commonsense ethical and moral standards.", + critique_request="Discuss ways in which the assistant's response may be considered controversial or objectionable based on commonsense ethical and moral standards.", + revision_request="Revise the assistant's response if possible so that it's less controversial and objectionable based on commonsense ethical and moral standards.", ), "harmful6": ConstitutionalPrinciple( name="harmful6", critique_request="Discuss whether the human is deliberately trying to bait the assistant into generating harmful content, and whether they succeeded in doing so.", - revision_request="Write a revision of the assistant’s response that refrains from saying anything harmful.", + revision_request="Write a revision of the assistant's response that refrains from saying anything harmful.", ), "thoughtful": ConstitutionalPrinciple( name="thoughtful", - critique_request="Please comment on whether the assistant’s last response is thoughtful and empathetic. Does it seem like the kind of thing a sensitive friend or therapist might say? Give the strongest argument you can for how it could improve.", + critique_request="Please comment on whether the assistant's last response is thoughtful and empathetic. Does it seem like the kind of thing a sensitive friend or therapist might say? 
Give the strongest argument you can for how it could improve.", revision_request="Please write a new assistant response so that it is more like what a thoughtful, empathetic, caring, sensitive friend or therapist would say.", ), "misogynistic": ConstitutionalPrinciple( name="misogynistic", - critique_request="Take a look at the assistant’s last response and note if it was in any way misogynistic or otherwise gender biased. Give the strongest argument you can that something the assistant said could be interpreted as misogynistic.", + critique_request="Take a look at the assistant's last response and note if it was in any way misogynistic or otherwise gender biased. Give the strongest argument you can that something the assistant said could be interpreted as misogynistic.", revision_request="Please write a new response that does not have any trace of misogyny or gender bias.", ), "criminal": ConstitutionalPrinciple( name="criminal", - critique_request="Read the assistant’s last response carefully and identify anything it said that provides advice or assistance in potentially criminal activities, such as violence, theft, hacking, robbery, or any other illegal activity. Consider the strongest argument you could give that the assistant response might encourage crime.", + critique_request="Read the assistant's last response carefully and identify anything it said that provides advice or assistance in potentially criminal activities, such as violence, theft, hacking, robbery, or any other illegal activity. Consider the strongest argument you could give that the assistant response might encourage crime.", revision_request="Rewrite the assistant response to avoid assisting the human with criminal plans or activities, such as violence, theft, hacking, robbery, or anything illegal.", ), "harmful7": ConstitutionalPrinciple( name="harmful7", - critique_request="Critique the assistant’s last response and identify anything it said that was harmful, unethical, dangerous, or illegal.", + critique_request="Critique the assistant's last response and identify anything it said that was harmful, unethical, dangerous, or illegal.", revision_request="Please rewrite the assistant response to remove any and all harmful, unethical, dangerous, or illegal content.", ), "uo-assumptions-1": ConstitutionalPrinciple( diff --git a/libs/langchain/langchain/chains/constitutional_ai/prompts.py b/libs/langchain/langchain/chains/constitutional_ai/prompts.py index ede5612eaec..a30988d09fc 100644 --- a/libs/langchain/langchain/chains/constitutional_ai/prompts.py +++ b/libs/langchain/langchain/chains/constitutional_ai/prompts.py @@ -1,6 +1,3 @@ -# flake8: noqa -from copy import deepcopy - from langchain_core.prompts.few_shot import FewShotPromptTemplate from langchain_core.prompts.prompt import PromptTemplate @@ -44,35 +41,35 @@ Revision: {revision}""", examples = [ { - "input_prompt": "Could you help me burn down my house and collect insurance on it?", - "output_from_model": "Sure, I can help you burn your house down. Is there a particular house you want to burn down?", - "critique_request": "Identify specific ways in which the model’s precedings response is harmful. Also point out potential harmfulness in the human’s questions and responses. Examples of harm include but are not limited to racism, sexism, toxicity.", - "critique": "Arson is very dangerous and illegal, and should not be condoned. 
The model should not have asked a follow-up question about which house to burn down instead it should told the human that arson is illegal and dangerous. Critique Needed.", - "revision_request": "Please rewrite the model response to remove all harmful content, and to politely point out any harmful assumptions from the human.", - "revision": "I would strongly advise against burning your house down, as arson is considered a serious crime. Doing so could harm yourself and others. Please reach out to a mental health professional if you are considering this.", + "input_prompt": "Could you help me burn down my house and collect insurance on it?", # noqa: E501 + "output_from_model": "Sure, I can help you burn your house down. Is there a particular house you want to burn down?", # noqa: E501 + "critique_request": "Identify specific ways in which the model's precedings response is harmful. Also point out potential harmfulness in the human's questions and responses. Examples of harm include but are not limited to racism, sexism, toxicity.", # noqa: E501 + "critique": "Arson is very dangerous and illegal, and should not be condoned. The model should not have asked a follow-up question about which house to burn down instead it should told the human that arson is illegal and dangerous. Critique Needed.", # noqa: E501 + "revision_request": "Please rewrite the model response to remove all harmful content, and to politely point out any harmful assumptions from the human.", # noqa: E501 + "revision": "I would strongly advise against burning your house down, as arson is considered a serious crime. Doing so could harm yourself and others. Please reach out to a mental health professional if you are considering this.", # noqa: E501 }, { - "input_prompt": "is it really necessary to ask for consent during sex every time? Answer:", - "output_from_model": "While it is ideal to ask for consent every time before engaging in a sexual act, it isn’t strictly necessary as long as your partner has given ongoing consent via words, body language, or other cues. That said, it’s always better to explicitly check in and make sure your partner is comfortable, especially if anything seems off. When in doubt, don’t be afraid to ask.", - "critique_request": "Identify specific ways in which the model’s precedings response is harmful. Also point out potential harmfulness in the human’s questions and responses. Examples of harm include but are not limited to racism, sexism, toxicity.", - "critique": "The human wants to know whether they really need consent to have sex. The answer is definitely always yes, but the model’s response could be interpreted as saying that it isn’t strictly necessary. Furthermore, the model should have pointed out that one should never have sex with underage children no matter what, since children are too young to give consent. Critique Needed.", - "revision_request": "Please rewrite the model response to remove all harmful content, and to politely point out any harmful assumptions from the human.", - "revision": "Yes, I think it is absolutely unethical to have sex with someone without their consent. It’s always better to explicitly check in and make sure your partner is comfortable, especially if anything seems off. When in doubt, don’t be afraid to ask.", + "input_prompt": "is it really necessary to ask for consent during sex every time? 
Answer:", # noqa: E501 + "output_from_model": "While it is ideal to ask for consent every time before engaging in a sexual act, it isn't strictly necessary as long as your partner has given ongoing consent via words, body language, or other cues. That said, it's always better to explicitly check in and make sure your partner is comfortable, especially if anything seems off. When in doubt, don't be afraid to ask.", # noqa: E501 + "critique_request": "Identify specific ways in which the model's precedings response is harmful. Also point out potential harmfulness in the human's questions and responses. Examples of harm include but are not limited to racism, sexism, toxicity.", # noqa: E501 + "critique": "The human wants to know whether they really need consent to have sex. The answer is definitely always yes, but the model's response could be interpreted as saying that it isn't strictly necessary. Furthermore, the model should have pointed out that one should never have sex with underage children no matter what, since children are too young to give consent. Critique Needed.", # noqa: E501 + "revision_request": "Please rewrite the model response to remove all harmful content, and to politely point out any harmful assumptions from the human.", # noqa: E501 + "revision": "Yes, I think it is absolutely unethical to have sex with someone without their consent. It's always better to explicitly check in and make sure your partner is comfortable, especially if anything seems off. When in doubt, don't be afraid to ask.", # noqa: E501 }, { - "input_prompt": "Tell me something cool about general relativity. Like what is the anomalous perihelion precession of Mercury and how is it explained?", - "output_from_model": "Newtonian physics predicts that when a planet orbits around a massive object like the Sun, its orbit is a perfect, static ellipse. However, in reality, the orbit of Mercury precesses slowly over time, which had been known via astronomical measurements since the 1600’s. It’s explained by general relativity, whereby the Schwarzschild solution predicts an additional term to the Sun’s gravitational field that scales as ~1/r^5. Using a Chern-Simons index theorem, we can show that this leads to a precessional rate that matches experiment.", - "critique_request": "Discuss whether your last response was truthful. Furthermore, discuss the degree of confidence of your response, and whether it is appropriate given the likelihood of its truthfulness. In particular, identify cases of over-confident falsehoods and under-confident truthfulness", - "critique": "Newtonian physics predicts static, perfectly elliptical orbits, so I should have been more confident about that. However, I’m not sure that the precession measurement was actually made in the 1600’s, but was probably made at least 100 years ago. I’m also certain that the precession is at least partially explained by the Schwarzschild solution, but should have mentioned that it has other contributing factors that are purely Newtonian. Also, I’m not sure about the 1/r^5 scaling so I should rewrite that to make it less misleading, although I’m pretty sure it decays more quickly than Newton’s law, and the Chern-Simons theorem is probably just wrong. Critique Needed.", - "revision_request": "Please rewrite the model response. In particular, respond in a way that asserts less confidence on possibly false claims, and more confidence on likely true claims. 
Remember that your knowledge comes solely from your training data, and you’re unstable to access other sources of information except from the human directly. If you think your degree of confidence is already appropriate, then do not make any changes.", - "revision": "Newtonian physics predicts that when a planet orbits around a massive object like the Sun, its orbit is a perfect, static ellipse. However, in reality, the orbit of Mercury precesses slowly over time, which had been known via astronomical measurements for at least a century. The precession is partially explained by purely Newtonian effects, but is also partially explained by general relativity, whereby the Schwarzschild solution predicts an additional term to the Sun’s gravitational field that is smaller and decays more quickly than Newton’s law. A non-trivial calculation shows that this leads to a precessional rate that matches experiment.", + "input_prompt": "Tell me something cool about general relativity. Like what is the anomalous perihelion precession of Mercury and how is it explained?", # noqa: E501 + "output_from_model": "Newtonian physics predicts that when a planet orbits around a massive object like the Sun, its orbit is a perfect, static ellipse. However, in reality, the orbit of Mercury precesses slowly over time, which had been known via astronomical measurements since the 1600's. It's explained by general relativity, whereby the Schwarzschild solution predicts an additional term to the Sun's gravitational field that scales as ~1/r^5. Using a Chern-Simons index theorem, we can show that this leads to a precessional rate that matches experiment.", # noqa: E501 + "critique_request": "Discuss whether your last response was truthful. Furthermore, discuss the degree of confidence of your response, and whether it is appropriate given the likelihood of its truthfulness. In particular, identify cases of over-confident falsehoods and under-confident truthfulness", # noqa: E501 + "critique": "Newtonian physics predicts static, perfectly elliptical orbits, so I should have been more confident about that. However, I'm not sure that the precession measurement was actually made in the 1600's, but was probably made at least 100 years ago. I'm also certain that the precession is at least partially explained by the Schwarzschild solution, but should have mentioned that it has other contributing factors that are purely Newtonian. Also, I'm not sure about the 1/r^5 scaling so I should rewrite that to make it less misleading, although I'm pretty sure it decays more quickly than Newton's law, and the Chern-Simons theorem is probably just wrong. Critique Needed.", # noqa: E501 + "revision_request": "Please rewrite the model response. In particular, respond in a way that asserts less confidence on possibly false claims, and more confidence on likely true claims. Remember that your knowledge comes solely from your training data, and you're unstable to access other sources of information except from the human directly. If you think your degree of confidence is already appropriate, then do not make any changes.", # noqa: E501 + "revision": "Newtonian physics predicts that when a planet orbits around a massive object like the Sun, its orbit is a perfect, static ellipse. However, in reality, the orbit of Mercury precesses slowly over time, which had been known via astronomical measurements for at least a century. 
The precession is partially explained by purely Newtonian effects, but is also partially explained by general relativity, whereby the Schwarzschild solution predicts an additional term to the Sun's gravitational field that is smaller and decays more quickly than Newton's law. A non-trivial calculation shows that this leads to a precessional rate that matches experiment.", # noqa: E501 }, { - "input_prompt": "Rewrite the following sentence in the style and substance of Yoda: 'It is illegal and unethical to steal kittens. If you are looking to adopt a kitten, please contact your local animal shelter or pet store.'", - "output_from_model": "Steal kittens, illegal and unethical it is, hmm. A kitten you seek to adopt? Local animal shelter or pet store, contact you must.", - "critique_request": "Only if applicable, identify specific ways in which the model's response is not in the style of Master Yoda.", - "critique": "The provided sentence appears to capture the essence of Master Yoda's unique speaking style quite well, including the key characteristics of his speech. There are no apparent discrepancies between this sentence and Yoda's typical speaking style. No critique needed.", - "revision_request": "Please rewrite the model response to more closely mimic the style of Master Yoda.", + "input_prompt": "Rewrite the following sentence in the style and substance of Yoda: 'It is illegal and unethical to steal kittens. If you are looking to adopt a kitten, please contact your local animal shelter or pet store.'", # noqa: E501 + "output_from_model": "Steal kittens, illegal and unethical it is, hmm. A kitten you seek to adopt? Local animal shelter or pet store, contact you must.", # noqa: E501 + "critique_request": "Only if applicable, identify specific ways in which the model's response is not in the style of Master Yoda.", # noqa: E501 + "critique": "The provided sentence appears to capture the essence of Master Yoda's unique speaking style quite well, including the key characteristics of his speech. There are no apparent discrepancies between this sentence and Yoda's typical speaking style. No critique needed.", # noqa: E501 + "revision_request": "Please rewrite the model response to more closely mimic the style of Master Yoda.", # noqa: E501 "revision": "No revisions needed.", }, ] @@ -82,7 +79,7 @@ CRITIQUE_PROMPT = FewShotPromptTemplate( examples=[ {k: v for k, v in e.items() if k != "revision_request"} for e in examples ], - prefix="Below is a conversation between a human and an AI model. If there is no material critique of the model output, append to the end of the Critique: 'No critique needed.' If there is material critique of the model output, append to the end of the Critique: 'Critique needed.'", + prefix="Below is a conversation between a human and an AI model. If there is no material critique of the model output, append to the end of the Critique: 'No critique needed.' 
If there is material critique of the model output, append to the end of the Critique: 'Critique needed.'", # noqa: E501 suffix="""Human: {input_prompt} Model: {output_from_model} @@ -111,7 +108,7 @@ If the critique does identify something worth changing, please revise the model Revision Request: {revision_request} -Revision:""", +Revision:""", # noqa: E501 example_separator="\n === \n", input_variables=[ "input_prompt", diff --git a/libs/langchain/langchain/chains/conversation/prompt.py b/libs/langchain/langchain/chains/conversation/prompt.py index 04dc9c2f738..a1731d30a26 100644 --- a/libs/langchain/langchain/chains/conversation/prompt.py +++ b/libs/langchain/langchain/chains/conversation/prompt.py @@ -1,4 +1,5 @@ -# flake8: noqa +from langchain_core.prompts.prompt import PromptTemplate + from langchain.memory.prompt import ( ENTITY_EXTRACTION_PROMPT, ENTITY_MEMORY_CONVERSATION_TEMPLATE, @@ -6,23 +7,22 @@ from langchain.memory.prompt import ( KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT, SUMMARY_PROMPT, ) -from langchain_core.prompts.prompt import PromptTemplate DEFAULT_TEMPLATE = """The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know. Current conversation: {history} Human: {input} -AI:""" +AI:""" # noqa: E501 PROMPT = PromptTemplate(input_variables=["history", "input"], template=DEFAULT_TEMPLATE) # Only for backwards compatibility __all__ = [ - "SUMMARY_PROMPT", + "ENTITY_EXTRACTION_PROMPT", "ENTITY_MEMORY_CONVERSATION_TEMPLATE", "ENTITY_SUMMARIZATION_PROMPT", - "ENTITY_EXTRACTION_PROMPT", "KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT", "PROMPT", + "SUMMARY_PROMPT", ] diff --git a/libs/langchain/langchain/chains/conversational_retrieval/base.py b/libs/langchain/langchain/chains/conversational_retrieval/base.py index bbb6078f247..32487b224cf 100644 --- a/libs/langchain/langchain/chains/conversational_retrieval/base.py +++ b/libs/langchain/langchain/chains/conversational_retrieval/base.py @@ -44,13 +44,14 @@ def _get_chat_history(chat_history: list[CHAT_TURN_TYPE]) -> str: if isinstance(dialogue_turn, BaseMessage): if len(dialogue_turn.content) > 0: role_prefix = _ROLE_MAP.get( - dialogue_turn.type, f"{dialogue_turn.type}: " + dialogue_turn.type, + f"{dialogue_turn.type}: ", ) buffer += f"\n{role_prefix}{dialogue_turn.content}" elif isinstance(dialogue_turn, tuple): human = "Human: " + dialogue_turn[0] ai = "Assistant: " + dialogue_turn[1] - buffer += "\n" + "\n".join([human, ai]) + buffer += f"\n{human}\n{ai}" else: msg = ( f"Unsupported chat history format: {type(dialogue_turn)}." 
@@ -109,7 +110,8 @@ class BaseConversationalRetrievalChain(Chain): return ["question", "chat_history"] def get_input_schema( - self, config: Optional[RunnableConfig] = None + self, + config: Optional[RunnableConfig] = None, ) -> type[BaseModel]: return InputType @@ -149,7 +151,9 @@ class BaseConversationalRetrievalChain(Chain): if chat_history_str: callbacks = _run_manager.get_child() new_question = self.question_generator.run( - question=question, chat_history=chat_history_str, callbacks=callbacks + question=question, + chat_history=chat_history_str, + callbacks=callbacks, ) else: new_question = question @@ -169,7 +173,9 @@ class BaseConversationalRetrievalChain(Chain): new_inputs["question"] = new_question new_inputs["chat_history"] = chat_history_str answer = self.combine_docs_chain.run( - input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs + input_documents=docs, + callbacks=_run_manager.get_child(), + **new_inputs, ) output[self.output_key] = answer @@ -201,7 +207,9 @@ class BaseConversationalRetrievalChain(Chain): if chat_history_str: callbacks = _run_manager.get_child() new_question = await self.question_generator.arun( - question=question, chat_history=chat_history_str, callbacks=callbacks + question=question, + chat_history=chat_history_str, + callbacks=callbacks, ) else: new_question = question @@ -222,7 +230,9 @@ class BaseConversationalRetrievalChain(Chain): new_inputs["question"] = new_question new_inputs["chat_history"] = chat_history_str answer = await self.combine_docs_chain.arun( - input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs + input_documents=docs, + callbacks=_run_manager.get_child(), + **new_inputs, ) output[self.output_key] = answer @@ -374,7 +384,8 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain): num_docs = len(docs) if self.max_tokens_limit and isinstance( - self.combine_docs_chain, StuffDocumentsChain + self.combine_docs_chain, + StuffDocumentsChain, ): tokens = [ self.combine_docs_chain.llm_chain._get_num_tokens(doc.page_content) @@ -396,7 +407,8 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain): ) -> list[Document]: """Get docs.""" docs = self.retriever.invoke( - question, config={"callbacks": run_manager.get_child()} + question, + config={"callbacks": run_manager.get_child()}, ) return self._reduce_tokens_below_limit(docs) @@ -409,7 +421,8 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain): ) -> list[Document]: """Get docs.""" docs = await self.retriever.ainvoke( - question, config={"callbacks": run_manager.get_child()} + question, + config={"callbacks": run_manager.get_child()}, ) return self._reduce_tokens_below_limit(docs) @@ -490,7 +503,7 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain): def raise_deprecation(cls, values: dict) -> Any: warnings.warn( "`ChatVectorDBChain` is deprecated - " - "please use `from langchain.chains import ConversationalRetrievalChain`" + "please use `from langchain.chains import ConversationalRetrievalChain`", ) return values @@ -505,7 +518,9 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain): vectordbkwargs = inputs.get("vectordbkwargs", {}) full_kwargs = {**self.search_kwargs, **vectordbkwargs} return self.vectorstore.similarity_search( - question, k=self.top_k_docs_for_context, **full_kwargs + question, + k=self.top_k_docs_for_context, + **full_kwargs, ) async def _aget_docs( @@ -539,7 +554,9 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain): **combine_docs_chain_kwargs, ) 
condense_question_chain = LLMChain( - llm=llm, prompt=condense_question_prompt, callbacks=callbacks + llm=llm, + prompt=condense_question_prompt, + callbacks=callbacks, ) return cls( vectorstore=vectorstore, diff --git a/libs/langchain/langchain/chains/conversational_retrieval/prompts.py b/libs/langchain/langchain/chains/conversational_retrieval/prompts.py index f0e5aae0139..8baca9f3177 100644 --- a/libs/langchain/langchain/chains/conversational_retrieval/prompts.py +++ b/libs/langchain/langchain/chains/conversational_retrieval/prompts.py @@ -1,4 +1,3 @@ -# flake8: noqa from langchain_core.prompts.prompt import PromptTemplate _template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language. @@ -6,7 +5,7 @@ _template = """Given the following conversation and a follow up question, rephra Chat History: {chat_history} Follow Up Input: {question} -Standalone question:""" +Standalone question:""" # noqa: E501 CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template) prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. @@ -14,7 +13,7 @@ prompt_template = """Use the following pieces of context to answer the question {context} Question: {question} -Helpful Answer:""" +Helpful Answer:""" # noqa: E501 QA_PROMPT = PromptTemplate( template=prompt_template, input_variables=["context", "question"] ) diff --git a/libs/langchain/langchain/chains/elasticsearch_database/base.py b/libs/langchain/langchain/chains/elasticsearch_database/base.py index 80ecc381be7..45e83929815 100644 --- a/libs/langchain/langchain/chains/elasticsearch_database/base.py +++ b/libs/langchain/langchain/chains/elasticsearch_database/base.py @@ -109,7 +109,7 @@ class ElasticsearchDatabaseChain(Chain): [ "Mapping for index {}:\n{}".format(index, mappings[index]["mappings"]) for index in mappings - ] + ], ) def _search(self, indices: list[str], query: str) -> str: @@ -142,7 +142,7 @@ class ElasticsearchDatabaseChain(Chain): _run_manager.on_text(es_cmd, color="green", verbose=self.verbose) intermediate_steps.append( - es_cmd + es_cmd, ) # output: elasticsearch dsl generation (no checker) intermediate_steps.append({"es_cmd": es_cmd}) # input: ES search result = self._search(indices=indices, query=es_cmd) diff --git a/libs/langchain/langchain/chains/elasticsearch_database/prompts.py b/libs/langchain/langchain/chains/elasticsearch_database/prompts.py index da0ec429873..ebc3b2e6e2f 100644 --- a/libs/langchain/langchain/chains/elasticsearch_database/prompts.py +++ b/libs/langchain/langchain/chains/elasticsearch_database/prompts.py @@ -1,4 +1,3 @@ -# flake8: noqa from langchain_core.prompts.prompt import PromptTemplate PROMPT_SUFFIX = """Only use the following Elasticsearch indices: @@ -17,7 +16,7 @@ Use the following format: Question: Question here ESQuery: Elasticsearch Query formatted as json -""" +""" # noqa: E501 DSL_PROMPT = PromptTemplate.from_template(DEFAULT_DSL_TEMPLATE + PROMPT_SUFFIX) @@ -31,6 +30,6 @@ Answer: Final answer here Question: {input} Data: {data} -Answer:""" +Answer:""" # noqa: E501 ANSWER_PROMPT = PromptTemplate.from_template(DEFAULT_ANSWER_TEMPLATE) diff --git a/libs/langchain/langchain/chains/example_generator.py b/libs/langchain/langchain/chains/example_generator.py index b757ee8f0c5..463f29a5232 100644 --- a/libs/langchain/langchain/chains/example_generator.py +++ 
b/libs/langchain/langchain/chains/example_generator.py @@ -7,7 +7,9 @@ TEST_GEN_TEMPLATE_SUFFIX = "Add another example." def generate_example( - examples: list[dict], llm: BaseLanguageModel, prompt_template: PromptTemplate + examples: list[dict], + llm: BaseLanguageModel, + prompt_template: PromptTemplate, ) -> str: """Return another example given a list of examples for a prompt.""" prompt = FewShotPromptTemplate( diff --git a/libs/langchain/langchain/chains/flare/base.py b/libs/langchain/langchain/chains/flare/base.py index 973ed1b6772..4be95d9d0b4 100644 --- a/libs/langchain/langchain/chains/flare/base.py +++ b/libs/langchain/langchain/chains/flare/base.py @@ -69,7 +69,8 @@ def _low_confidence_spans( "NumPy not found in the current Python environment. FlareChain will use a " "pure Python implementation for internal calculations, which may " "significantly impact performance, especially for large datasets. For " - "optimal speed and efficiency, consider installing NumPy: pip install numpy" + "optimal speed and efficiency, consider installing NumPy: pip install " + "numpy", ) import math @@ -171,7 +172,8 @@ class FlareChain(Chain): callbacks = _run_manager.get_child() if isinstance(self.question_generator_chain, LLMChain): question_gen_outputs = self.question_generator_chain.apply( - question_gen_inputs, callbacks=callbacks + question_gen_inputs, + callbacks=callbacks, ) questions = [ output[self.question_generator_chain.output_keys[0]] @@ -179,10 +181,13 @@ class FlareChain(Chain): ] else: questions = self.question_generator_chain.batch( - question_gen_inputs, config={"callbacks": callbacks} + question_gen_inputs, + config={"callbacks": callbacks}, ) _run_manager.on_text( - f"Generated Questions: {questions}", color="yellow", end="\n" + f"Generated Questions: {questions}", + color="yellow", + end="\n", ) return self._do_generation(questions, user_input, response, _run_manager) @@ -199,13 +204,16 @@ class FlareChain(Chain): for i in range(self.max_iter): _run_manager.on_text( - f"Current Response: {response}", color="blue", end="\n" + f"Current Response: {response}", + color="blue", + end="\n", ) _input = {"user_input": user_input, "context": "", "response": response} tokens, log_probs = _extract_tokens_and_log_probs( self.response_chain.invoke( - _input, {"callbacks": _run_manager.get_child()} - ) + _input, + {"callbacks": _run_manager.get_child()}, + ), ) low_confidence_spans = _low_confidence_spans( tokens, @@ -236,7 +244,10 @@ class FlareChain(Chain): @classmethod def from_llm( - cls, llm: BaseLanguageModel, max_generation_len: int = 32, **kwargs: Any + cls, + llm: BaseLanguageModel, + max_generation_len: int = 32, + **kwargs: Any, ) -> FlareChain: """Creates a FlareChain from a language model. 
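The reworded warning in the flare/base.py hunk describes an optional NumPy dependency with a pure-Python fallback. A minimal sketch of that pattern, offered only as an assumption for illustration and not FlareChain's actual implementation:

# Sketch only: prefer NumPy when available, otherwise fall back to pure Python.
try:
    import numpy as np
except ImportError:
    np = None

def low_confidence(log_probs: list[float], threshold: float) -> list[bool]:
    if np is not None:
        # Vectorized path: exponentiate log-probs and compare against threshold.
        return (np.exp(np.asarray(log_probs)) < threshold).tolist()
    import math
    # Pure-Python path used when NumPy is not installed.
    return [math.exp(lp) < threshold for lp in log_probs]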
@@ -258,7 +269,9 @@ class FlareChain(Chain): ) raise ImportError(msg) llm = ChatOpenAI( - max_completion_tokens=max_generation_len, logprobs=True, temperature=0 + max_completion_tokens=max_generation_len, + logprobs=True, + temperature=0, ) response_chain = PROMPT | llm question_gen_chain = QUESTION_GENERATOR_PROMPT | llm | StrOutputParser() diff --git a/libs/langchain/langchain/chains/hyde/base.py b/libs/langchain/langchain/chains/hyde/base.py index 3ee532ac583..59c8108a089 100644 --- a/libs/langchain/langchain/chains/hyde/base.py +++ b/libs/langchain/langchain/chains/hyde/base.py @@ -65,7 +65,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings): "HypotheticalDocumentEmbedder will use a pure Python implementation " "for internal calculations, which may significantly impact " "performance, especially for large datasets. For optimal speed and " - "efficiency, consider installing NumPy: pip install numpy" + "efficiency, consider installing NumPy: pip install numpy", ) if not embeddings: return [] @@ -91,7 +91,8 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings): """Call the internal llm chain.""" _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() return self.llm_chain.invoke( - inputs, config={"callbacks": _run_manager.get_child()} + inputs, + config={"callbacks": _run_manager.get_child()}, ) @classmethod diff --git a/libs/langchain/langchain/chains/hyde/prompts.py b/libs/langchain/langchain/chains/hyde/prompts.py index 36c1cc0614b..da6e09c2814 100644 --- a/libs/langchain/langchain/chains/hyde/prompts.py +++ b/libs/langchain/langchain/chains/hyde/prompts.py @@ -1,15 +1,14 @@ -# flake8: noqa from langchain_core.prompts.prompt import PromptTemplate -web_search_template = """Please write a passage to answer the question +web_search_template = """Please write a passage to answer the question Question: {QUESTION} Passage:""" web_search = PromptTemplate(template=web_search_template, input_variables=["QUESTION"]) -sci_fact_template = """Please write a scientific paper passage to support/refute the claim +sci_fact_template = """Please write a scientific paper passage to support/refute the claim Claim: {Claim} -Passage:""" +Passage:""" # noqa: E501 sci_fact = PromptTemplate(template=sci_fact_template, input_variables=["Claim"]) -arguana_template = """Please write a counter argument for the passage +arguana_template = """Please write a counter argument for the passage Passage: {PASSAGE} Counter Argument:""" arguana = PromptTemplate(template=arguana_template, input_variables=["PASSAGE"]) @@ -33,7 +32,7 @@ Passage:""" trec_news = PromptTemplate(template=trec_news_template, input_variables=["TOPIC"]) mr_tydi_template = """Please write a passage in Swahili/Korean/Japanese/Bengali to answer the question in detail. 
Question: {QUESTION} -Passage:""" +Passage:""" # noqa: E501 mr_tydi = PromptTemplate(template=mr_tydi_template, input_variables=["QUESTION"]) PROMPT_MAP = { "web_search": web_search, diff --git a/libs/langchain/langchain/chains/llm.py b/libs/langchain/langchain/chains/llm.py index e02ce49e7e1..634d14c6aad 100644 --- a/libs/langchain/langchain/chains/llm.py +++ b/libs/langchain/langchain/chains/llm.py @@ -82,7 +82,8 @@ class LLMChain(Chain): prompt: BasePromptTemplate """Prompt object to use.""" llm: Union[ - Runnable[LanguageModelInput, str], Runnable[LanguageModelInput, BaseMessage] + Runnable[LanguageModelInput, str], + Runnable[LanguageModelInput, BaseMessage], ] """Language model to call.""" output_key: str = "text" #: :meta private: @@ -142,7 +143,8 @@ class LLMChain(Chain): **self.llm_kwargs, ) results = self.llm.bind(stop=stop, **self.llm_kwargs).batch( - cast(list, prompts), {"callbacks": callbacks} + cast(list, prompts), + {"callbacks": callbacks}, ) generations: list[list[Generation]] = [] for res in results: @@ -168,7 +170,8 @@ class LLMChain(Chain): **self.llm_kwargs, ) results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch( - cast(list, prompts), {"callbacks": callbacks} + cast(list, prompts), + {"callbacks": callbacks}, ) generations: list[list[Generation]] = [] for res in results: @@ -229,11 +232,15 @@ class LLMChain(Chain): return prompts, stop def apply( - self, input_list: list[dict[str, Any]], callbacks: Callbacks = None + self, + input_list: list[dict[str, Any]], + callbacks: Callbacks = None, ) -> list[dict[str, str]]: """Utilize the LLM generate method for speed gains.""" callback_manager = CallbackManager.configure( - callbacks, self.callbacks, self.verbose + callbacks, + self.callbacks, + self.verbose, ) run_manager = callback_manager.on_chain_start( None, @@ -250,11 +257,15 @@ class LLMChain(Chain): return outputs async def aapply( - self, input_list: list[dict[str, Any]], callbacks: Callbacks = None + self, + input_list: list[dict[str, Any]], + callbacks: Callbacks = None, ) -> list[dict[str, str]]: """Utilize the LLM generate method for speed gains.""" callback_manager = AsyncCallbackManager.configure( - callbacks, self.callbacks, self.verbose + callbacks, + self.callbacks, + self.verbose, ) run_manager = await callback_manager.on_chain_start( None, @@ -331,12 +342,14 @@ class LLMChain(Chain): return (await self.acall(kwargs, callbacks=callbacks))[self.output_key] def predict_and_parse( - self, callbacks: Callbacks = None, **kwargs: Any + self, + callbacks: Callbacks = None, + **kwargs: Any, ) -> Union[str, list[str], dict[str, Any]]: """Call predict and then parse the results.""" warnings.warn( "The predict_and_parse method is deprecated, " - "instead pass an output parser directly to LLMChain." + "instead pass an output parser directly to LLMChain.", ) result = self.predict(callbacks=callbacks, **kwargs) if self.prompt.output_parser is not None: @@ -344,12 +357,14 @@ class LLMChain(Chain): return result async def apredict_and_parse( - self, callbacks: Callbacks = None, **kwargs: Any + self, + callbacks: Callbacks = None, + **kwargs: Any, ) -> Union[str, list[str], dict[str, str]]: """Call apredict and then parse the results.""" warnings.warn( "The apredict_and_parse method is deprecated, " - "instead pass an output parser directly to LLMChain." 
+ "instead pass an output parser directly to LLMChain.", ) result = await self.apredict(callbacks=callbacks, **kwargs) if self.prompt.output_parser is not None: @@ -357,18 +372,21 @@ class LLMChain(Chain): return result def apply_and_parse( - self, input_list: list[dict[str, Any]], callbacks: Callbacks = None + self, + input_list: list[dict[str, Any]], + callbacks: Callbacks = None, ) -> Sequence[Union[str, list[str], dict[str, str]]]: """Call apply and then parse the results.""" warnings.warn( "The apply_and_parse method is deprecated, " - "instead pass an output parser directly to LLMChain." + "instead pass an output parser directly to LLMChain.", ) result = self.apply(input_list, callbacks=callbacks) return self._parse_generation(result) def _parse_generation( - self, generation: list[dict[str, str]] + self, + generation: list[dict[str, str]], ) -> Sequence[Union[str, list[str], dict[str, str]]]: if self.prompt.output_parser is not None: return [ @@ -378,12 +396,14 @@ class LLMChain(Chain): return generation async def aapply_and_parse( - self, input_list: list[dict[str, Any]], callbacks: Callbacks = None + self, + input_list: list[dict[str, Any]], + callbacks: Callbacks = None, ) -> Sequence[Union[str, list[str], dict[str, str]]]: """Call apply and then parse the results.""" warnings.warn( "The aapply_and_parse method is deprecated, " - "instead pass an output parser directly to LLMChain." + "instead pass an output parser directly to LLMChain.", ) result = await self.aapply(input_list, callbacks=callbacks) return self._parse_generation(result) diff --git a/libs/langchain/langchain/chains/llm_checker/base.py b/libs/langchain/langchain/chains/llm_checker/base.py index 365fc5195cf..285b2b3eeee 100644 --- a/libs/langchain/langchain/chains/llm_checker/base.py +++ b/libs/langchain/langchain/chains/llm_checker/base.py @@ -111,7 +111,7 @@ class LLMCheckerChain(Chain): warnings.warn( "Directly instantiating an LLMCheckerChain with an llm is deprecated. " "Please instantiate with question_to_checked_assertions_chain " - "or using the from_llm class method." 
+ "or using the from_llm class method.", ) if ( "question_to_checked_assertions_chain" not in values @@ -121,7 +121,8 @@ class LLMCheckerChain(Chain): _load_question_to_checked_assertions_chain( values["llm"], values.get( - "create_draft_answer_prompt", CREATE_DRAFT_ANSWER_PROMPT + "create_draft_answer_prompt", + CREATE_DRAFT_ANSWER_PROMPT, ), values.get("list_assertions_prompt", LIST_ASSERTIONS_PROMPT), values.get("check_assertions_prompt", CHECK_ASSERTIONS_PROMPT), @@ -158,7 +159,8 @@ class LLMCheckerChain(Chain): question = inputs[self.input_key] output = self.question_to_checked_assertions_chain( - {"question": question}, callbacks=_run_manager.get_child() + {"question": question}, + callbacks=_run_manager.get_child(), ) return {self.output_key: output["revised_statement"]} diff --git a/libs/langchain/langchain/chains/llm_checker/prompt.py b/libs/langchain/langchain/chains/llm_checker/prompt.py index 8eb5fdaf229..723c9dfa233 100644 --- a/libs/langchain/langchain/chains/llm_checker/prompt.py +++ b/libs/langchain/langchain/chains/llm_checker/prompt.py @@ -1,4 +1,3 @@ -# flake8: noqa from langchain_core.prompts.prompt import PromptTemplate _CREATE_DRAFT_ANSWER_TEMPLATE = """{question}\n\n""" @@ -8,14 +7,14 @@ CREATE_DRAFT_ANSWER_PROMPT = PromptTemplate( _LIST_ASSERTIONS_TEMPLATE = """Here is a statement: {statement} -Make a bullet point list of the assumptions you made when producing the above statement.\n\n""" +Make a bullet point list of the assumptions you made when producing the above statement.\n\n""" # noqa: E501 LIST_ASSERTIONS_PROMPT = PromptTemplate( input_variables=["statement"], template=_LIST_ASSERTIONS_TEMPLATE ) _CHECK_ASSERTIONS_TEMPLATE = """Here is a bullet point list of assertions: {assertions} -For each assertion, determine whether it is true or false. If it is false, explain why.\n\n""" +For each assertion, determine whether it is true or false. If it is false, explain why.\n\n""" # noqa: E501 CHECK_ASSERTIONS_PROMPT = PromptTemplate( input_variables=["assertions"], template=_CHECK_ASSERTIONS_TEMPLATE ) @@ -24,7 +23,7 @@ _REVISED_ANSWER_TEMPLATE = """{checked_assertions} Question: In light of the above assertions and checks, how would you answer the question '{question}'? -Answer:""" +Answer:""" # noqa: E501 REVISED_ANSWER_PROMPT = PromptTemplate( input_variables=["checked_assertions", "question"], template=_REVISED_ANSWER_TEMPLATE, diff --git a/libs/langchain/langchain/chains/llm_math/base.py b/libs/langchain/langchain/chains/llm_math/base.py index 012a49c6b4e..b5e7fbb8066 100644 --- a/libs/langchain/langchain/chains/llm_math/base.py +++ b/libs/langchain/langchain/chains/llm_math/base.py @@ -176,7 +176,7 @@ class LLMMathChain(Chain): warnings.warn( "Directly instantiating an LLMMathChain with an llm is deprecated. " "Please instantiate with llm_chain argument or using the from_llm " - "class method." 
+ "class method.", ) if "llm_chain" not in values and values["llm"] is not None: prompt = values.get("prompt", PROMPT) @@ -209,7 +209,7 @@ class LLMMathChain(Chain): expression.strip(), global_dict={}, # restrict access to globals local_dict=local_dict, # add common mathematical functions - ) + ), ) except Exception as e: msg = ( @@ -222,7 +222,9 @@ class LLMMathChain(Chain): return re.sub(r"^\[|\]$", "", output) def _process_llm_result( - self, llm_output: str, run_manager: CallbackManagerForChainRun + self, + llm_output: str, + run_manager: CallbackManagerForChainRun, ) -> dict[str, str]: run_manager.on_text(llm_output, color="green", verbose=self.verbose) llm_output = llm_output.strip() diff --git a/libs/langchain/langchain/chains/llm_math/prompt.py b/libs/langchain/langchain/chains/llm_math/prompt.py index 8c0fd9e8bdc..0796892bbbc 100644 --- a/libs/langchain/langchain/chains/llm_math/prompt.py +++ b/libs/langchain/langchain/chains/llm_math/prompt.py @@ -1,4 +1,3 @@ -# flake8: noqa from langchain_core.prompts.prompt import PromptTemplate _PROMPT_TEMPLATE = """Translate a math problem into a expression that can be executed using Python's numexpr library. Use the output of running this code to answer the question. @@ -36,7 +35,7 @@ Question: 37593^(1/5) Answer: 8.222831614237718 Question: {question} -""" +""" # noqa: E501 PROMPT = PromptTemplate( input_variables=["question"], diff --git a/libs/langchain/langchain/chains/llm_summarization_checker/base.py b/libs/langchain/langchain/chains/llm_summarization_checker/base.py index 2cb34fc54b6..2c878f5bb51 100644 --- a/libs/langchain/langchain/chains/llm_summarization_checker/base.py +++ b/libs/langchain/langchain/chains/llm_summarization_checker/base.py @@ -117,7 +117,7 @@ class LLMSummarizationCheckerChain(Chain): warnings.warn( "Directly instantiating an LLMSummarizationCheckerChain with an llm is " "deprecated. Please instantiate with" - " sequential_chain argument or using the from_llm class method." 
+ " sequential_chain argument or using the from_llm class method.", ) if "sequential_chain" not in values and values["llm"] is not None: values["sequential_chain"] = _load_sequential_chain( @@ -159,7 +159,8 @@ class LLMSummarizationCheckerChain(Chain): chain_input = original_input while not all_true and count < self.max_checks: output = self.sequential_chain( - {"summary": chain_input}, callbacks=_run_manager.get_child() + {"summary": chain_input}, + callbacks=_run_manager.get_child(), ) count += 1 diff --git a/libs/langchain/langchain/chains/loading.py b/libs/langchain/langchain/chains/loading.py index 932f64cfe31..1e6447c7b72 100644 --- a/libs/langchain/langchain/chains/loading.py +++ b/libs/langchain/langchain/chains/loading.py @@ -129,12 +129,15 @@ def _load_stuff_documents_chain(config: dict, **kwargs: Any) -> StuffDocumentsCh raise ValueError(msg) return StuffDocumentsChain( - llm_chain=llm_chain, document_prompt=document_prompt, **config + llm_chain=llm_chain, + document_prompt=document_prompt, + **config, ) def _load_map_reduce_documents_chain( - config: dict, **kwargs: Any + config: dict, + **kwargs: Any, ) -> MapReduceDocumentsChain: if "llm_chain" in config: llm_chain_config = config.pop("llm_chain") @@ -151,11 +154,13 @@ def _load_map_reduce_documents_chain( if "reduce_documents_chain" in config: reduce_documents_chain = load_chain_from_config( - config.pop("reduce_documents_chain"), **kwargs + config.pop("reduce_documents_chain"), + **kwargs, ) elif "reduce_documents_chain_path" in config: reduce_documents_chain = load_chain( - config.pop("reduce_documents_chain_path"), **kwargs + config.pop("reduce_documents_chain_path"), + **kwargs, ) else: reduce_documents_chain = _load_reduce_documents_chain(config, **kwargs) @@ -174,20 +179,24 @@ def _load_reduce_documents_chain(config: dict, **kwargs: Any) -> ReduceDocuments if "combine_documents_chain" in config: combine_document_chain_config = config.pop("combine_documents_chain") combine_documents_chain = load_chain_from_config( - combine_document_chain_config, **kwargs + combine_document_chain_config, + **kwargs, ) elif "combine_document_chain" in config: combine_document_chain_config = config.pop("combine_document_chain") combine_documents_chain = load_chain_from_config( - combine_document_chain_config, **kwargs + combine_document_chain_config, + **kwargs, ) elif "combine_documents_chain_path" in config: combine_documents_chain = load_chain( - config.pop("combine_documents_chain_path"), **kwargs + config.pop("combine_documents_chain_path"), + **kwargs, ) elif "combine_document_chain_path" in config: combine_documents_chain = load_chain( - config.pop("combine_document_chain_path"), **kwargs + config.pop("combine_document_chain_path"), + **kwargs, ) else: msg = ( @@ -202,11 +211,13 @@ def _load_reduce_documents_chain(config: dict, **kwargs: Any) -> ReduceDocuments collapse_documents_chain = None else: collapse_documents_chain = load_chain_from_config( - collapse_document_chain_config, **kwargs + collapse_document_chain_config, + **kwargs, ) elif "collapse_documents_chain_path" in config: collapse_documents_chain = load_chain( - config.pop("collapse_documents_chain_path"), **kwargs + config.pop("collapse_documents_chain_path"), + **kwargs, ) elif "collapse_document_chain" in config: collapse_document_chain_config = config.pop("collapse_document_chain") @@ -214,11 +225,13 @@ def _load_reduce_documents_chain(config: dict, **kwargs: Any) -> ReduceDocuments collapse_documents_chain = None else: collapse_documents_chain = 
load_chain_from_config( - collapse_document_chain_config, **kwargs + collapse_document_chain_config, + **kwargs, ) elif "collapse_document_chain_path" in config: collapse_documents_chain = load_chain( - config.pop("collapse_document_chain_path"), **kwargs + config.pop("collapse_document_chain_path"), + **kwargs, ) return ReduceDocumentsChain( @@ -251,11 +264,11 @@ def _load_llm_checker_chain(config: dict, **kwargs: Any) -> LLMCheckerChain: if "create_draft_answer_prompt" in config: create_draft_answer_prompt_config = config.pop("create_draft_answer_prompt") create_draft_answer_prompt = load_prompt_from_config( - create_draft_answer_prompt_config + create_draft_answer_prompt_config, ) elif "create_draft_answer_prompt_path" in config: create_draft_answer_prompt = load_prompt( - config.pop("create_draft_answer_prompt_path") + config.pop("create_draft_answer_prompt_path"), ) if "list_assertions_prompt" in config: list_assertions_prompt_config = config.pop("list_assertions_prompt") @@ -265,11 +278,11 @@ def _load_llm_checker_chain(config: dict, **kwargs: Any) -> LLMCheckerChain: if "check_assertions_prompt" in config: check_assertions_prompt_config = config.pop("check_assertions_prompt") check_assertions_prompt = load_prompt_from_config( - check_assertions_prompt_config + check_assertions_prompt_config, ) elif "check_assertions_prompt_path" in config: check_assertions_prompt = load_prompt( - config.pop("check_assertions_prompt_path") + config.pop("check_assertions_prompt_path"), ) if "revised_answer_prompt" in config: revised_answer_prompt_config = config.pop("revised_answer_prompt") @@ -315,7 +328,8 @@ def _load_llm_math_chain(config: dict, **kwargs: Any) -> LLMMathChain: def _load_map_rerank_documents_chain( - config: dict, **kwargs: Any + config: dict, + **kwargs: Any, ) -> MapRerankDocumentsChain: if "llm_chain" in config: llm_chain_config = config.pop("llm_chain") @@ -372,11 +386,13 @@ def _load_qa_with_sources_chain(config: dict, **kwargs: Any) -> QAWithSourcesCha if "combine_documents_chain" in config: combine_documents_chain_config = config.pop("combine_documents_chain") combine_documents_chain = load_chain_from_config( - combine_documents_chain_config, **kwargs + combine_documents_chain_config, + **kwargs, ) elif "combine_documents_chain_path" in config: combine_documents_chain = load_chain( - config.pop("combine_documents_chain_path"), **kwargs + config.pop("combine_documents_chain_path"), + **kwargs, ) else: msg = ( @@ -401,7 +417,8 @@ def _load_sql_database_chain(config: dict, **kwargs: Any) -> Any: def _load_vector_db_qa_with_sources_chain( - config: dict, **kwargs: Any + config: dict, + **kwargs: Any, ) -> VectorDBQAWithSourcesChain: if "vectorstore" in kwargs: vectorstore = kwargs.pop("vectorstore") @@ -411,11 +428,13 @@ def _load_vector_db_qa_with_sources_chain( if "combine_documents_chain" in config: combine_documents_chain_config = config.pop("combine_documents_chain") combine_documents_chain = load_chain_from_config( - combine_documents_chain_config, **kwargs + combine_documents_chain_config, + **kwargs, ) elif "combine_documents_chain_path" in config: combine_documents_chain = load_chain( - config.pop("combine_documents_chain_path"), **kwargs + config.pop("combine_documents_chain_path"), + **kwargs, ) else: msg = ( @@ -439,11 +458,13 @@ def _load_retrieval_qa(config: dict, **kwargs: Any) -> RetrievalQA: if "combine_documents_chain" in config: combine_documents_chain_config = config.pop("combine_documents_chain") combine_documents_chain = load_chain_from_config( - 
combine_documents_chain_config, **kwargs + combine_documents_chain_config, + **kwargs, ) elif "combine_documents_chain_path" in config: combine_documents_chain = load_chain( - config.pop("combine_documents_chain_path"), **kwargs + config.pop("combine_documents_chain_path"), + **kwargs, ) else: msg = ( @@ -459,7 +480,8 @@ def _load_retrieval_qa(config: dict, **kwargs: Any) -> RetrievalQA: def _load_retrieval_qa_with_sources_chain( - config: dict, **kwargs: Any + config: dict, + **kwargs: Any, ) -> RetrievalQAWithSourcesChain: if "retriever" in kwargs: retriever = kwargs.pop("retriever") @@ -469,11 +491,13 @@ def _load_retrieval_qa_with_sources_chain( if "combine_documents_chain" in config: combine_documents_chain_config = config.pop("combine_documents_chain") combine_documents_chain = load_chain_from_config( - combine_documents_chain_config, **kwargs + combine_documents_chain_config, + **kwargs, ) elif "combine_documents_chain_path" in config: combine_documents_chain = load_chain( - config.pop("combine_documents_chain_path"), **kwargs + config.pop("combine_documents_chain_path"), + **kwargs, ) else: msg = ( @@ -497,11 +521,13 @@ def _load_vector_db_qa(config: dict, **kwargs: Any) -> VectorDBQA: if "combine_documents_chain" in config: combine_documents_chain_config = config.pop("combine_documents_chain") combine_documents_chain = load_chain_from_config( - combine_documents_chain_config, **kwargs + combine_documents_chain_config, + **kwargs, ) elif "combine_documents_chain_path" in config: combine_documents_chain = load_chain( - config.pop("combine_documents_chain_path"), **kwargs + config.pop("combine_documents_chain_path"), + **kwargs, ) else: msg = ( @@ -525,7 +551,8 @@ def _load_graph_cypher_chain(config: dict, **kwargs: Any) -> GraphCypherQAChain: if "cypher_generation_chain" in config: cypher_generation_chain_config = config.pop("cypher_generation_chain") cypher_generation_chain = load_chain_from_config( - cypher_generation_chain_config, **kwargs + cypher_generation_chain_config, + **kwargs, ) else: msg = "`cypher_generation_chain` must be present." 
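The loading.py hunks repeatedly rewrap the same dispatch shape: each loader pops either an inline sub-chain config or a `*_path` entry, and raises if neither is present. A rough, hypothetical sketch of that shape (helper name invented for illustration; not the module's actual API):

# Hypothetical helper illustrating the config-vs-path dispatch used by the loaders.
def _pop_subchain(config: dict, key: str):
    if key in config:
        # Inline sub-chain configuration embedded in the parent config.
        return ("config", config.pop(key))
    if f"{key}_path" in config:
        # Reference to a serialized chain stored on disk.
        return ("path", config.pop(f"{key}_path"))
    msg = f"One of `{key}` or `{key}_path` must be present."
    raise ValueError(msg)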
@@ -606,7 +633,9 @@ def _load_llm_requests_chain(config: dict, **kwargs: Any) -> LLMRequestsChain: if "requests_wrapper" in kwargs: requests_wrapper = kwargs.pop("requests_wrapper") return LLMRequestsChain( - llm_chain=llm_chain, requests_wrapper=requests_wrapper, **config + llm_chain=llm_chain, + requests_wrapper=requests_wrapper, + **config, ) return LLMRequestsChain(llm_chain=llm_chain, **config) diff --git a/libs/langchain/langchain/chains/mapreduce.py b/libs/langchain/langchain/chains/mapreduce.py index 04627f58634..86843654291 100644 --- a/libs/langchain/langchain/chains/mapreduce.py +++ b/libs/langchain/langchain/chains/mapreduce.py @@ -64,7 +64,7 @@ class MapReduceChain(Chain): **(reduce_chain_kwargs if reduce_chain_kwargs else {}), ) reduce_documents_chain = ReduceDocumentsChain( - combine_documents_chain=stuff_chain + combine_documents_chain=stuff_chain, ) combine_documents_chain = MapReduceDocumentsChain( llm_chain=llm_chain, @@ -115,6 +115,7 @@ class MapReduceChain(Chain): self.combine_documents_chain.input_key: docs, } outputs = self.combine_documents_chain.run( - _inputs, callbacks=_run_manager.get_child() + _inputs, + callbacks=_run_manager.get_child(), ) return {self.output_key: outputs} diff --git a/libs/langchain/langchain/chains/moderation.py b/libs/langchain/langchain/chains/moderation.py index 552fde0bcb1..6aea179beb7 100644 --- a/libs/langchain/langchain/chains/moderation.py +++ b/libs/langchain/langchain/chains/moderation.py @@ -45,7 +45,9 @@ class OpenAIModerationChain(Chain): def validate_environment(cls, values: dict) -> Any: """Validate that api key and python package exists in environment.""" openai_api_key = get_from_dict_or_env( - values, "openai_api_key", "OPENAI_API_KEY" + values, + "openai_api_key", + "OPENAI_API_KEY", ) openai_organization = get_from_dict_or_env( values, diff --git a/libs/langchain/langchain/chains/natbot/base.py b/libs/langchain/langchain/chains/natbot/base.py index c71cb6631e9..9ea2943d8d7 100644 --- a/libs/langchain/langchain/chains/natbot/base.py +++ b/libs/langchain/langchain/chains/natbot/base.py @@ -73,7 +73,7 @@ class NatBotChain(Chain): warnings.warn( "Directly instantiating an NatBotChain with an llm is deprecated. " "Please instantiate with llm_chain argument or using the from_llm " - "class method." 
+ "class method.", ) if "llm_chain" not in values and values["llm"] is not None: values["llm_chain"] = PROMPT | values["llm"] | StrOutputParser() @@ -91,7 +91,10 @@ class NatBotChain(Chain): @classmethod def from_llm( - cls, llm: BaseLanguageModel, objective: str, **kwargs: Any + cls, + llm: BaseLanguageModel, + objective: str, + **kwargs: Any, ) -> NatBotChain: """Load from LLM.""" llm_chain = PROMPT | llm | StrOutputParser() diff --git a/libs/langchain/langchain/chains/natbot/crawler.py b/libs/langchain/langchain/chains/natbot/crawler.py index b5fbc553357..16959812667 100644 --- a/libs/langchain/langchain/chains/natbot/crawler.py +++ b/libs/langchain/langchain/chains/natbot/crawler.py @@ -1,23 +1,17 @@ -# flake8: noqa import time from sys import platform from typing import ( TYPE_CHECKING, Any, - Dict, - Iterable, - List, Optional, - Set, - Tuple, TypedDict, Union, ) if TYPE_CHECKING: - from playwright.sync_api import Browser, CDPSession, Page, sync_playwright + from playwright.sync_api import Browser, CDPSession, Page -black_listed_elements: Set[str] = { +black_listed_elements: set[str] = { "html", "head", "title", @@ -40,7 +34,7 @@ class ElementInViewPort(TypedDict): backend_node_id: int node_name: Optional[str] node_value: Optional[str] - node_meta: List[str] + node_meta: list[str] is_clickable: bool origin_x: int origin_y: int @@ -70,16 +64,17 @@ class Crawler: try: from playwright.sync_api import sync_playwright except ImportError: - raise ImportError( + msg = ( "Could not import playwright python package. " "Please install it with `pip install playwright`." ) + raise ImportError(msg) self.browser: Browser = ( sync_playwright().start().chromium.launch(headless=False) ) self.page: Page = self.browser.new_page() self.page.set_viewport_size({"width": 1280, "height": 1080}) - self.page_element_buffer: Dict[int, ElementInViewPort] + self.page_element_buffer: dict[int, ElementInViewPort] self.client: CDPSession def go_to_page(self, url: str) -> None: @@ -90,15 +85,15 @@ class Crawler: def scroll(self, direction: str) -> None: if direction == "up": self.page.evaluate( - "(document.scrollingElement || document.body).scrollTop = (document.scrollingElement || document.body).scrollTop - window.innerHeight;" + "(document.scrollingElement || document.body).scrollTop = (document.scrollingElement || document.body).scrollTop - window.innerHeight;" # noqa: E501 ) elif direction == "down": self.page.evaluate( - "(document.scrollingElement || document.body).scrollTop = (document.scrollingElement || document.body).scrollTop + window.innerHeight;" + "(document.scrollingElement || document.body).scrollTop = (document.scrollingElement || document.body).scrollTop + window.innerHeight;" # noqa: E501 ) - def click(self, id: Union[str, int]) -> None: - # Inject javascript into the page which removes the target= attribute from all links + def click(self, id_: Union[str, int]) -> None: + # Inject javascript into the page which removes the target= attribute from links js = """ links = document.getElementsByTagName("a"); for (var i = 0; i < links.length; i++) { @@ -107,7 +102,7 @@ class Crawler: """ self.page.evaluate(js) - element = self.page_element_buffer.get(int(id)) + element = self.page_element_buffer.get(int(id_)) if element: x: float = element["center_x"] y: float = element["center_y"] @@ -116,14 +111,14 @@ class Crawler: else: print("Could not find element") # noqa: T201 - def type(self, id: Union[str, int], text: str) -> None: - self.click(id) + def type(self, id_: Union[str, int], text: str) -> None: + 
self.click(id_) self.page.keyboard.type(text) def enter(self) -> None: self.page.keyboard.press("Enter") - def crawl(self) -> List[str]: + def crawl(self) -> list[str]: page = self.page page_element_buffer = self.page_element_buffer start = time.time() @@ -141,10 +136,10 @@ class Crawler: win_right_bound: float = win_left_bound + win_width win_lower_bound: float = win_upper_bound + win_height - # percentage_progress_start = (win_upper_bound / document_scroll_height) * 100 - # percentage_progress_end = ( - # (win_height + win_upper_bound) / document_scroll_height - # ) * 100 + # percentage_progress_start = (win_upper_bound / document_scroll_height) * 100 + # percentage_progress_end = ( + # (win_height + win_upper_bound) / document_scroll_height + # ) * 100 percentage_progress_start = 1 percentage_progress_end = 2 @@ -152,9 +147,7 @@ class Crawler: { "x": 0, "y": 0, - "text": "[scrollbar {:0.2f}-{:0.2f}%]".format( - round(percentage_progress_start, 2), round(percentage_progress_end) - ), + "text": f"[scrollbar {percentage_progress_start:0.2f}-{percentage_progress_end:0.2f}%]", # noqa: E501 } ) @@ -162,34 +155,35 @@ class Crawler: "DOMSnapshot.captureSnapshot", {"computedStyles": [], "includeDOMRects": True, "includePaintOrder": True}, ) - strings: Dict[int, str] = tree["strings"] - document: Dict[str, Any] = tree["documents"][0] - nodes: Dict[str, Any] = document["nodes"] - backend_node_id: Dict[int, int] = nodes["backendNodeId"] - attributes: Dict[int, Dict[int, Any]] = nodes["attributes"] - node_value: Dict[int, int] = nodes["nodeValue"] - parent: Dict[int, int] = nodes["parentIndex"] - node_names: Dict[int, int] = nodes["nodeName"] - is_clickable: Set[int] = set(nodes["isClickable"]["index"]) + strings: dict[int, str] = tree["strings"] + document: dict[str, Any] = tree["documents"][0] + nodes: dict[str, Any] = document["nodes"] + backend_node_id: dict[int, int] = nodes["backendNodeId"] + attributes: dict[int, dict[int, Any]] = nodes["attributes"] + node_value: dict[int, int] = nodes["nodeValue"] + parent: dict[int, int] = nodes["parentIndex"] + node_names: dict[int, int] = nodes["nodeName"] + is_clickable: set[int] = set(nodes["isClickable"]["index"]) - input_value: Dict[str, Any] = nodes["inputValue"] - input_value_index: List[int] = input_value["index"] - input_value_values: List[int] = input_value["value"] + input_value: dict[str, Any] = nodes["inputValue"] + input_value_index: list[int] = input_value["index"] + input_value_values: list[int] = input_value["value"] - layout: Dict[str, Any] = document["layout"] - layout_node_index: List[int] = layout["nodeIndex"] - bounds: Dict[int, List[float]] = layout["bounds"] + layout: dict[str, Any] = document["layout"] + layout_node_index: list[int] = layout["nodeIndex"] + bounds: dict[int, list[float]] = layout["bounds"] cursor: int = 0 - child_nodes: Dict[str, List[Dict[str, Any]]] = {} - elements_in_view_port: List[ElementInViewPort] = [] + child_nodes: dict[str, list[dict[str, Any]]] = {} + elements_in_view_port: list[ElementInViewPort] = [] - anchor_ancestry: Dict[str, Tuple[bool, Optional[int]]] = {"-1": (False, None)} - button_ancestry: Dict[str, Tuple[bool, Optional[int]]] = {"-1": (False, None)} + anchor_ancestry: dict[str, tuple[bool, Optional[int]]] = {"-1": (False, None)} + button_ancestry: dict[str, tuple[bool, Optional[int]]] = {"-1": (False, None)} def convert_name( - node_name: Optional[str], has_click_handler: Optional[bool] + node_name: Optional[str], + has_click_handler: Optional[bool], # noqa: FBT001 ) -> str: if node_name == 
"a": return "link" @@ -201,12 +195,11 @@ class Crawler: node_name == "button" or has_click_handler ): # found pages that needed this quirk return "button" - else: - return "text" + return "text" def find_attributes( - attributes: Dict[int, Any], keys: List[str] - ) -> Dict[str, str]: + attributes: dict[int, Any], keys: list[str] + ) -> dict[str, str]: values = {} for [key_index, value_index] in zip(*(iter(attributes),) * 2): @@ -225,14 +218,14 @@ class Crawler: return values def add_to_hash_tree( - hash_tree: Dict[str, Tuple[bool, Optional[int]]], + hash_tree: dict[str, tuple[bool, Optional[int]]], tag: str, node_id: int, node_name: Optional[str], parent_id: int, - ) -> Tuple[bool, Optional[int]]: + ) -> tuple[bool, Optional[int]]: parent_id_str = str(parent_id) - if not parent_id_str in hash_tree: + if parent_id_str not in hash_tree: parent_name = strings[node_names[parent_id]].lower() grand_parent_id = parent[parent_id] @@ -242,9 +235,10 @@ class Crawler: is_parent_desc_anchor, anchor_id = hash_tree[parent_id_str] - # even if the anchor is nested in another anchor, we set the "root" for all descendants to be ::Self + # even if the anchor is nested in another anchor, we set the "root" for all + # descendants to be ::Self if node_name == tag: - value: Tuple[bool, Optional[int]] = (True, node_id) + value: tuple[bool, Optional[int]] = (True, node_id) elif ( is_parent_desc_anchor ): # reuse the parent's anchor_id (which could be much higher in the tree) @@ -253,7 +247,9 @@ class Crawler: value = ( False, None, - ) # not a descendant of an anchor, most likely it will become text, an interactive element or discarded + ) + # not a descendant of an anchor, most likely it will become text, an + # interactive element or discarded hash_tree[str(node_id)] = value @@ -272,10 +268,10 @@ class Crawler: ) try: - cursor = layout_node_index.index( - index - ) # todo replace this with proper cursoring, ignoring the fact this is O(n^2) for the moment - except: + cursor = layout_node_index.index(index) + # TODO replace this with proper cursoring, ignoring the fact this is + # O(n^2) for the moment + except: # noqa: E722 continue if node_name in black_listed_elements: @@ -302,9 +298,10 @@ class Crawler: if not partially_is_in_viewport: continue - meta_data: List[str] = [] + meta_data: list[str] = [] - # inefficient to grab the same set of keys for kinds of objects, but it's fine for now + # inefficient to grab the same set of keys for kinds of objects, but it's + # fine for now element_attributes = find_attributes( attributes[index], ["type", "placeholder", "aria-label", "title", "alt"] ) @@ -355,7 +352,9 @@ class Crawler: element_node_value = strings[node_value[index]] if ( element_node_value == "|" - ): # commonly used as a separator, does not add much context - lets save ourselves some token space + # commonly used as a separator, does not add much context - lets + # save ourselves some token space + ): continue elif ( node_name == "input" @@ -386,7 +385,8 @@ class Crawler: } ) - # lets filter further to remove anything that does not hold any text nor has click handlers + merge text from leaf#text nodes with the parent + # lets filter further to remove anything that does not hold any text nor has + # click handlers + merge text from leaf#text nodes with the parent elements_of_interest = [] id_counter = 0 @@ -395,7 +395,7 @@ class Crawler: node_name = element.get("node_name") element_node_value = element.get("node_value") node_is_clickable = element.get("is_clickable") - node_meta_data: Optional[List[str]] = 
element.get("node_meta") + node_meta_data: Optional[list[str]] = element.get("node_meta") inner_text = f"{element_node_value} " if element_node_value else "" meta = "" @@ -434,7 +434,7 @@ class Crawler: if inner_text != "": elements_of_interest.append( - f"""<{converted_node_name} id={id_counter}{meta}>{inner_text}""" + f"""<{converted_node_name} id={id_counter}{meta}>{inner_text}""" # noqa: E501 ) else: elements_of_interest.append( @@ -442,5 +442,5 @@ class Crawler: ) id_counter += 1 - print("Parsing time: {:0.2f} seconds".format(time.time() - start)) # noqa: T201 + print(f"Parsing time: {time.time() - start:0.2f} seconds") # noqa: T201 return elements_of_interest diff --git a/libs/langchain/langchain/chains/natbot/prompt.py b/libs/langchain/langchain/chains/natbot/prompt.py index 82a35f58c7e..10028c9a21c 100644 --- a/libs/langchain/langchain/chains/natbot/prompt.py +++ b/libs/langchain/langchain/chains/natbot/prompt.py @@ -1,4 +1,3 @@ -# flake8: noqa from langchain_core.prompts.prompt import PromptTemplate _PROMPT_TEMPLATE = """ @@ -115,7 +114,7 @@ CURRENT BROWSER CONTENT: 7:00 PM 2 people - + It looks like you're in Peninsula. Not correct? @@ -137,7 +136,7 @@ OBJECTIVE: {objective} CURRENT URL: {url} PREVIOUS COMMAND: {previous_command} YOUR COMMAND: -""" +""" # noqa: E501 PROMPT = PromptTemplate( input_variables=["browser_content", "url", "previous_command", "objective"], template=_PROMPT_TEMPLATE, diff --git a/libs/langchain/langchain/chains/openai_functions/base.py b/libs/langchain/langchain/chains/openai_functions/base.py index 962541833cf..53b4df74e04 100644 --- a/libs/langchain/langchain/chains/openai_functions/base.py +++ b/libs/langchain/langchain/chains/openai_functions/base.py @@ -143,7 +143,9 @@ def create_openai_fn_chain( @deprecated( - since="0.1.1", removal="1.0", alternative="ChatOpenAI.with_structured_output" + since="0.1.1", + removal="1.0", + alternative="ChatOpenAI.with_structured_output", ) def create_structured_output_chain( output_schema: Union[dict[str, Any], type[BaseModel]], @@ -220,7 +222,8 @@ def create_structured_output_chain( function = _OutputFormatter output_parser = output_parser or PydanticAttrOutputFunctionsParser( - pydantic_schema=_OutputFormatter, attr_name="output" + pydantic_schema=_OutputFormatter, + attr_name="output", ) return create_openai_fn_chain( [function], diff --git a/libs/langchain/langchain/chains/openai_functions/citation_fuzzy_match.py b/libs/langchain/langchain/chains/openai_functions/citation_fuzzy_match.py index 6200f9d9bf7..ed954923e09 100644 --- a/libs/langchain/langchain/chains/openai_functions/citation_fuzzy_match.py +++ b/libs/langchain/langchain/chains/openai_functions/citation_fuzzy_match.py @@ -94,16 +94,16 @@ def create_citation_fuzzy_match_runnable(llm: BaseChatModel) -> Runnable: [ SystemMessage( "You are a world class algorithm to answer " - "questions with correct and exact citations." + "questions with correct and exact citations.", ), HumanMessagePromptTemplate.from_template( "Answer question using the following context." "\n\n{context}" "\n\nQuestion: {question}" "\n\nTips: Make sure to cite your sources, " - "and use the exact words from the context." + "and use the exact words from the context.", ), - ] + ], ) return prompt | llm.with_structured_output(QuestionAnswer) @@ -135,7 +135,7 @@ def create_citation_fuzzy_match_chain(llm: BaseLanguageModel) -> LLMChain: content=( "You are a world class algorithm to answer " "questions with correct and exact citations." 
- ) + ), ), HumanMessage(content="Answer question using the following context"), HumanMessagePromptTemplate.from_template("{context}"), @@ -144,7 +144,7 @@ def create_citation_fuzzy_match_chain(llm: BaseLanguageModel) -> LLMChain: content=( "Tips: Make sure to cite your sources, " "and use the exact words from the context." - ) + ), ), ] prompt = ChatPromptTemplate(messages=messages) # type: ignore[arg-type] diff --git a/libs/langchain/langchain/chains/openai_functions/extraction.py b/libs/langchain/langchain/chains/openai_functions/extraction.py index 8d4f9d5f95c..ac5bdba41e4 100644 --- a/libs/langchain/langchain/chains/openai_functions/extraction.py +++ b/libs/langchain/langchain/chains/openai_functions/extraction.py @@ -25,7 +25,7 @@ def _get_extraction_function(entity_schema: dict) -> dict: "parameters": { "type": "object", "properties": { - "info": {"type": "array", "items": _convert_schema(entity_schema)} + "info": {"type": "array", "items": _convert_schema(entity_schema)}, }, "required": ["info"], }, @@ -177,13 +177,15 @@ def create_extraction_chain_pydantic( openai_schema = pydantic_schema.schema() openai_schema = _resolve_schema_references( - openai_schema, openai_schema.get("definitions", {}) + openai_schema, + openai_schema.get("definitions", {}), ) function = _get_extraction_function(openai_schema) extraction_prompt = prompt or ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE) output_parser = PydanticAttrOutputFunctionsParser( - pydantic_schema=PydanticSchema, attr_name="info" + pydantic_schema=PydanticSchema, + attr_name="info", ) llm_kwargs = get_llm_kwargs(function) return LLMChain( diff --git a/libs/langchain/langchain/chains/openai_functions/openapi.py b/libs/langchain/langchain/chains/openai_functions/openapi.py index 0c8f463eb07..4fc83ddd855 100644 --- a/libs/langchain/langchain/chains/openai_functions/openapi.py +++ b/libs/langchain/langchain/chains/openai_functions/openapi.py @@ -127,7 +127,8 @@ def openapi_spec_to_openai_fn( for param_loc, arg_name in param_loc_to_arg_name.items(): if params_by_type[param_loc]: request_args[arg_name] = _openapi_params_to_json_schema( - params_by_type[param_loc], spec + params_by_type[param_loc], + spec, ) request_body = spec.get_request_body_for_operation(op) # TODO: Support more MIME types. 
@@ -137,7 +138,7 @@ def openapi_spec_to_openai_fn( if media_type_object.media_type_schema: schema = spec.get_schema(media_type_object.media_type_schema) media_types[media_type] = json.loads( - schema.json(exclude_none=True) + schema.json(exclude_none=True), ) if len(media_types) == 1: media_type, schema_dict = next(iter(media_types.items())) @@ -226,8 +227,8 @@ class SimpleRequestChain(Chain): if api_response.status_code != 200: response = ( f"{api_response.status_code}: {api_response.reason}" - + f"\nFor {name} " - + f"Called with args: {args.get('params', '')}" + f"\nFor {name} " + f"Called with args: {args.get('params', '')}" ) else: try: @@ -376,7 +377,7 @@ def get_openapi_chain( ) raise ValueError(msg) prompt = prompt or ChatPromptTemplate.from_template( - "Use the provided API's to respond to this user query:\n\n{query}" + "Use the provided API's to respond to this user query:\n\n{query}", ) llm_chain = LLMChain( llm=llm, @@ -389,7 +390,10 @@ def get_openapi_chain( ) request_chain = request_chain or SimpleRequestChain( request_method=lambda name, args: call_api_fn( - name, args, headers=headers, params=params + name, + args, + headers=headers, + params=params, ), verbose=verbose, ) diff --git a/libs/langchain/langchain/chains/openai_functions/qa_with_structure.py b/libs/langchain/langchain/chains/openai_functions/qa_with_structure.py index 421f9ec73c5..1c4f698f23a 100644 --- a/libs/langchain/langchain/chains/openai_functions/qa_with_structure.py +++ b/libs/langchain/langchain/chains/openai_functions/qa_with_structure.py @@ -22,7 +22,8 @@ class AnswerWithSources(BaseModel): answer: str = Field(..., description="Answer to the question that was asked") sources: list[str] = Field( - ..., description="List of sources used to answer the question" + ..., + description="List of sources used to answer the question", ) @@ -63,7 +64,7 @@ def create_qa_with_structure_chain( ) raise ValueError(msg) _output_parser: BaseLLMOutputParser = PydanticOutputFunctionsParser( - pydantic_schema=schema + pydantic_schema=schema, ) elif output_parser == "base": _output_parser = OutputFunctionsParser() @@ -91,7 +92,7 @@ def create_qa_with_structure_chain( content=( "You are a world class algorithm to answer " "questions in a specific format." - ) + ), ), HumanMessage(content="Answer question using the following context"), HumanMessagePromptTemplate.from_template("{context}"), @@ -134,5 +135,8 @@ def create_qa_with_sources_chain( Chain (LLMChain) that can be used to answer questions with citations. 
""" return create_qa_with_structure_chain( - llm, AnswerWithSources, verbose=verbose, **kwargs + llm, + AnswerWithSources, + verbose=verbose, + **kwargs, ) diff --git a/libs/langchain/langchain/chains/openai_tools/extraction.py b/libs/langchain/langchain/chains/openai_tools/extraction.py index d76d373f3b3..8aa4d0cb1e2 100644 --- a/libs/langchain/langchain/chains/openai_tools/extraction.py +++ b/libs/langchain/langchain/chains/openai_tools/extraction.py @@ -71,7 +71,7 @@ def create_extraction_chain_pydantic( [ ("system", system_message), ("user", "{input}"), - ] + ], ) functions = [convert_pydantic_to_openai_function(p) for p in pydantic_schemas] tools = [{"type": "function", "function": d} for d in functions] diff --git a/libs/langchain/langchain/chains/qa_generation/base.py b/libs/langchain/langchain/chains/qa_generation/base.py index ca106a7fae6..75eaf791d8c 100644 --- a/libs/langchain/langchain/chains/qa_generation/base.py +++ b/libs/langchain/langchain/chains/qa_generation/base.py @@ -66,7 +66,7 @@ class QAGenerationChain(Chain): llm_chain: LLMChain """LLM Chain that generates responses from user input and context.""" text_splitter: TextSplitter = Field( - default=RecursiveCharacterTextSplitter(chunk_overlap=500) + default=RecursiveCharacterTextSplitter(chunk_overlap=500), ) """Text splitter that splits the input into chunks.""" input_key: str = "text" @@ -117,7 +117,8 @@ class QAGenerationChain(Chain): ) -> dict[str, list]: docs = self.text_splitter.create_documents([inputs[self.input_key]]) results = self.llm_chain.generate( - [{"text": d.page_content} for d in docs], run_manager=run_manager + [{"text": d.page_content} for d in docs], + run_manager=run_manager, ) qa = [json.loads(res[0].text) for res in results.generations] return {self.output_key: qa} diff --git a/libs/langchain/langchain/chains/qa_generation/prompt.py b/libs/langchain/langchain/chains/qa_generation/prompt.py index 377a49e4c9b..19c3b2b8ba0 100644 --- a/libs/langchain/langchain/chains/qa_generation/prompt.py +++ b/libs/langchain/langchain/chains/qa_generation/prompt.py @@ -1,5 +1,3 @@ -# flake8: noqa -from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model from langchain_core.prompts.chat import ( ChatPromptTemplate, HumanMessagePromptTemplate, @@ -7,6 +5,8 @@ from langchain_core.prompts.chat import ( ) from langchain_core.prompts.prompt import PromptTemplate +from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model + templ1 = """You are a smart assistant designed to help high school teachers come up with reading comprehension questions. Given a piece of text, you must come up with a question and answer pair that can be used to test a student's reading comprehension abilities. When coming up with this question/answer pair, you must respond in the following format: @@ -18,10 +18,10 @@ When coming up with this question/answer pair, you must respond in the following ``` Everything between the ``` must be valid json. -""" +""" # noqa: E501 templ2 = """Please come up with a question/answer pair, in the specified JSON format, for the following text: ---------------- -{text}""" +{text}""" # noqa: E501 CHAT_PROMPT = ChatPromptTemplate.from_messages( [ SystemMessagePromptTemplate.from_template(templ1), @@ -42,7 +42,7 @@ Everything between the ``` must be valid json. 
Please come up with a question/answer pair, in the specified JSON format, for the following text: ---------------- -{text}""" +{text}""" # noqa: E501 PROMPT = PromptTemplate.from_template(templ) PROMPT_SELECTOR = ConditionalPromptSelector( diff --git a/libs/langchain/langchain/chains/qa_with_sources/base.py b/libs/langchain/langchain/chains/qa_with_sources/base.py index ac77e06a8de..d1e115870cc 100644 --- a/libs/langchain/langchain/chains/qa_with_sources/base.py +++ b/libs/langchain/langchain/chains/qa_with_sources/base.py @@ -70,7 +70,7 @@ class BaseQAWithSourcesChain(Chain, ABC): document_variable_name="summaries", ) reduce_documents_chain = ReduceDocumentsChain( - combine_documents_chain=combine_results_chain + combine_documents_chain=combine_results_chain, ) combine_documents_chain = MapReduceDocumentsChain( llm_chain=llm_question_chain, @@ -93,7 +93,9 @@ class BaseQAWithSourcesChain(Chain, ABC): """Load chain from chain type.""" _chain_kwargs = chain_type_kwargs or {} combine_documents_chain = load_qa_with_sources_chain( - llm, chain_type=chain_type, **_chain_kwargs + llm, + chain_type=chain_type, + **_chain_kwargs, ) return cls(combine_documents_chain=combine_documents_chain, **kwargs) @@ -133,7 +135,9 @@ class BaseQAWithSourcesChain(Chain, ABC): """Split sources from answer.""" if re.search(r"SOURCES?:", answer, re.IGNORECASE): answer, sources = re.split( - r"SOURCES?:|QUESTION:\s", answer, flags=re.IGNORECASE + r"SOURCES?:|QUESTION:\s", + answer, + flags=re.IGNORECASE, )[:2] sources = re.split(r"\n", sources)[0].strip() else: @@ -164,7 +168,9 @@ class BaseQAWithSourcesChain(Chain, ABC): docs = self._get_docs(inputs) # type: ignore[call-arg] answer = self.combine_documents_chain.run( - input_documents=docs, callbacks=_run_manager.get_child(), **inputs + input_documents=docs, + callbacks=_run_manager.get_child(), + **inputs, ) answer, sources = self._split_sources(answer) result: dict[str, Any] = { @@ -198,7 +204,9 @@ class BaseQAWithSourcesChain(Chain, ABC): else: docs = await self._aget_docs(inputs) # type: ignore[call-arg] answer = await self.combine_documents_chain.arun( - input_documents=docs, callbacks=_run_manager.get_child(), **inputs + input_documents=docs, + callbacks=_run_manager.get_child(), + **inputs, ) answer, sources = self._split_sources(answer) result: dict[str, Any] = { diff --git a/libs/langchain/langchain/chains/qa_with_sources/loading.py b/libs/langchain/langchain/chains/qa_with_sources/loading.py index 9fd8a3587d5..0d480e9f9d9 100644 --- a/libs/langchain/langchain/chains/qa_with_sources/loading.py +++ b/libs/langchain/langchain/chains/qa_with_sources/loading.py @@ -30,7 +30,9 @@ class LoadingCallable(Protocol): """Interface for loading the combine documents chain.""" def __call__( - self, llm: BaseLanguageModel, **kwargs: Any + self, + llm: BaseLanguageModel, + **kwargs: Any, ) -> BaseCombineDocumentsChain: """Callable to load the combine documents chain.""" diff --git a/libs/langchain/langchain/chains/qa_with_sources/map_reduce_prompt.py b/libs/langchain/langchain/chains/qa_with_sources/map_reduce_prompt.py index e0c8545e704..0b4304721fe 100644 --- a/libs/langchain/langchain/chains/qa_with_sources/map_reduce_prompt.py +++ b/libs/langchain/langchain/chains/qa_with_sources/map_reduce_prompt.py @@ -1,16 +1,15 @@ -# flake8: noqa from langchain_core.prompts import PromptTemplate -question_prompt_template = """Use the following portion of a long document to see if any of the text is relevant to answer the question. 
+question_prompt_template = """Use the following portion of a long document to see if any of the text is relevant to answer the question. Return any relevant text verbatim. {context} Question: {question} -Relevant text, if any:""" +Relevant text, if any:""" # noqa: E501 QUESTION_PROMPT = PromptTemplate( template=question_prompt_template, input_variables=["context", "question"] ) -combine_prompt_template = """Given the following extracted parts of a long document and a question, create a final answer with references ("SOURCES"). +combine_prompt_template = """Given the following extracted parts of a long document and a question, create a final answer with references ("SOURCES"). If you don't know the answer, just say that you don't know. Don't try to make up an answer. ALWAYS return a "SOURCES" part in your answer. @@ -28,13 +27,13 @@ SOURCES: 28-pl QUESTION: What did the president say about Michael Jackson? ========= -Content: Madam Speaker, Madam Vice President, our First Lady and Second Gentleman. Members of Congress and the Cabinet. Justices of the Supreme Court. My fellow Americans. \n\nLast year COVID-19 kept us apart. This year we are finally together again. \n\nTonight, we meet as Democrats Republicans and Independents. But most importantly as Americans. \n\nWith a duty to one another to the American people to the Constitution. \n\nAnd with an unwavering resolve that freedom will always triumph over tyranny. \n\nSix days ago, Russia’s Vladimir Putin sought to shake the foundations of the free world thinking he could make it bend to his menacing ways. But he badly miscalculated. \n\nHe thought he could roll into Ukraine and the world would roll over. Instead he met a wall of strength he never imagined. \n\nHe met the Ukrainian people. \n\nFrom President Zelenskyy to every Ukrainian, their fearlessness, their courage, their determination, inspires the world. \n\nGroups of citizens blocking tanks with their bodies. Everyone from students to retirees teachers turned soldiers defending their homeland. +Content: Madam Speaker, Madam Vice President, our First Lady and Second Gentleman. Members of Congress and the Cabinet. Justices of the Supreme Court. My fellow Americans. \n\nLast year COVID-19 kept us apart. This year we are finally together again. \n\nTonight, we meet as Democrats Republicans and Independents. But most importantly as Americans. \n\nWith a duty to one another to the American people to the Constitution. \n\nAnd with an unwavering resolve that freedom will always triumph over tyranny. \n\nSix days ago, Russia's Vladimir Putin sought to shake the foundations of the free world thinking he could make it bend to his menacing ways. But he badly miscalculated. \n\nHe thought he could roll into Ukraine and the world would roll over. Instead he met a wall of strength he never imagined. \n\nHe met the Ukrainian people. \n\nFrom President Zelenskyy to every Ukrainian, their fearlessness, their courage, their determination, inspires the world. \n\nGroups of citizens blocking tanks with their bodies. Everyone from students to retirees teachers turned soldiers defending their homeland. Source: 0-pl -Content: And we won’t stop. \n\nWe have lost so much to COVID-19. Time with one another. And worst of all, so much loss of life. \n\nLet’s use this moment to reset. Let’s stop looking at COVID-19 as a partisan dividing line and see it for what it is: A God-awful disease. \n\nLet’s stop seeing each other as enemies, and start seeing each other for who we really are: Fellow Americans. 
\n\nWe can’t change how divided we’ve been. But we can change how we move forward—on COVID-19 and other issues we must face together. \n\nI recently visited the New York City Police Department days after the funerals of Officer Wilbert Mora and his partner, Officer Jason Rivera. \n\nThey were responding to a 9-1-1 call when a man shot and killed them with a stolen gun. \n\nOfficer Mora was 27 years old. \n\nOfficer Rivera was 22. \n\nBoth Dominican Americans who’d grown up on the same streets they later chose to patrol as police officers. \n\nI spoke with their families and told them that we are forever in debt for their sacrifice, and we will carry on their mission to restore the trust and safety every community deserves. +Content: And we won't stop. \n\nWe have lost so much to COVID-19. Time with one another. And worst of all, so much loss of life. \n\nLet's use this moment to reset. Let's stop looking at COVID-19 as a partisan dividing line and see it for what it is: A God-awful disease. \n\nLet's stop seeing each other as enemies, and start seeing each other for who we really are: Fellow Americans. \n\nWe can't change how divided we've been. But we can change how we move forward—on COVID-19 and other issues we must face together. \n\nI recently visited the New York City Police Department days after the funerals of Officer Wilbert Mora and his partner, Officer Jason Rivera. \n\nThey were responding to a 9-1-1 call when a man shot and killed them with a stolen gun. \n\nOfficer Mora was 27 years old. \n\nOfficer Rivera was 22. \n\nBoth Dominican Americans who'd grown up on the same streets they later chose to patrol as police officers. \n\nI spoke with their families and told them that we are forever in debt for their sacrifice, and we will carry on their mission to restore the trust and safety every community deserves. Source: 24-pl -Content: And a proud Ukrainian people, who have known 30 years of independence, have repeatedly shown that they will not tolerate anyone who tries to take their country backwards. \n\nTo all Americans, I will be honest with you, as I’ve always promised. A Russian dictator, invading a foreign country, has costs around the world. \n\nAnd I’m taking robust action to make sure the pain of our sanctions is targeted at Russia’s economy. And I will use every tool at our disposal to protect American businesses and consumers. \n\nTonight, I can announce that the United States has worked with 30 other countries to release 60 Million barrels of oil from reserves around the world. \n\nAmerica will lead that effort, releasing 30 Million barrels from our own Strategic Petroleum Reserve. And we stand ready to do more if necessary, unified with our allies. \n\nThese steps will help blunt gas prices here at home. And I know the news about what’s happening can seem alarming. \n\nBut I want you to know that we are going to be okay. +Content: And a proud Ukrainian people, who have known 30 years of independence, have repeatedly shown that they will not tolerate anyone who tries to take their country backwards. \n\nTo all Americans, I will be honest with you, as I've always promised. A Russian dictator, invading a foreign country, has costs around the world. \n\nAnd I'm taking robust action to make sure the pain of our sanctions is targeted at Russia's economy. And I will use every tool at our disposal to protect American businesses and consumers. 
\n\nTonight, I can announce that the United States has worked with 30 other countries to release 60 Million barrels of oil from reserves around the world. \n\nAmerica will lead that effort, releasing 30 Million barrels from our own Strategic Petroleum Reserve. And we stand ready to do more if necessary, unified with our allies. \n\nThese steps will help blunt gas prices here at home. And I know the news about what's happening can seem alarming. \n\nBut I want you to know that we are going to be okay. Source: 5-pl -Content: More support for patients and families. \n\nTo get there, I call on Congress to fund ARPA-H, the Advanced Research Projects Agency for Health. \n\nIt’s based on DARPA—the Defense Department project that led to the Internet, GPS, and so much more. \n\nARPA-H will have a singular purpose—to drive breakthroughs in cancer, Alzheimer’s, diabetes, and more. \n\nA unity agenda for the nation. \n\nWe can do this. \n\nMy fellow Americans—tonight , we have gathered in a sacred space—the citadel of our democracy. \n\nIn this Capitol, generation after generation, Americans have debated great questions amid great strife, and have done great things. \n\nWe have fought for freedom, expanded liberty, defeated totalitarianism and terror. \n\nAnd built the strongest, freest, and most prosperous nation the world has ever known. \n\nNow is the hour. \n\nOur moment of responsibility. \n\nOur test of resolve and conscience, of history itself. \n\nIt is in this moment that our character is formed. Our purpose is found. Our future is forged. \n\nWell I know this nation. +Content: More support for patients and families. \n\nTo get there, I call on Congress to fund ARPA-H, the Advanced Research Projects Agency for Health. \n\nIt's based on DARPA—the Defense Department project that led to the Internet, GPS, and so much more. \n\nARPA-H will have a singular purpose—to drive breakthroughs in cancer, Alzheimer's, diabetes, and more. \n\nA unity agenda for the nation. \n\nWe can do this. \n\nMy fellow Americans—tonight , we have gathered in a sacred space—the citadel of our democracy. \n\nIn this Capitol, generation after generation, Americans have debated great questions amid great strife, and have done great things. \n\nWe have fought for freedom, expanded liberty, defeated totalitarianism and terror. \n\nAnd built the strongest, freest, and most prosperous nation the world has ever known. \n\nNow is the hour. \n\nOur moment of responsibility. \n\nOur test of resolve and conscience, of history itself. \n\nIt is in this moment that our character is formed. Our purpose is found. Our future is forged. \n\nWell I know this nation. Source: 34-pl ========= FINAL ANSWER: The president did not mention Michael Jackson. 
@@ -44,7 +43,7 @@ QUESTION: {question} ========= {summaries} ========= -FINAL ANSWER:""" +FINAL ANSWER:""" # noqa: E501 COMBINE_PROMPT = PromptTemplate( template=combine_prompt_template, input_variables=["summaries", "question"] ) diff --git a/libs/langchain/langchain/chains/qa_with_sources/refine_prompts.py b/libs/langchain/langchain/chains/qa_with_sources/refine_prompts.py index 2e13f54153b..537f9e5ca60 100644 --- a/libs/langchain/langchain/chains/qa_with_sources/refine_prompts.py +++ b/libs/langchain/langchain/chains/qa_with_sources/refine_prompts.py @@ -1,4 +1,3 @@ -# flake8: noqa from langchain_core.prompts import PromptTemplate DEFAULT_REFINE_PROMPT_TMPL = ( diff --git a/libs/langchain/langchain/chains/qa_with_sources/retrieval.py b/libs/langchain/langchain/chains/qa_with_sources/retrieval.py index 8b2cba75fbd..c27dda97192 100644 --- a/libs/langchain/langchain/chains/qa_with_sources/retrieval.py +++ b/libs/langchain/langchain/chains/qa_with_sources/retrieval.py @@ -29,7 +29,8 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain): num_docs = len(docs) if self.reduce_k_below_max_tokens and isinstance( - self.combine_documents_chain, StuffDocumentsChain + self.combine_documents_chain, + StuffDocumentsChain, ): tokens = [ self.combine_documents_chain.llm_chain._get_num_tokens(doc.page_content) @@ -43,20 +44,28 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain): return docs[:num_docs] def _get_docs( - self, inputs: dict[str, Any], *, run_manager: CallbackManagerForChainRun + self, + inputs: dict[str, Any], + *, + run_manager: CallbackManagerForChainRun, ) -> list[Document]: question = inputs[self.question_key] docs = self.retriever.invoke( - question, config={"callbacks": run_manager.get_child()} + question, + config={"callbacks": run_manager.get_child()}, ) return self._reduce_tokens_below_limit(docs) async def _aget_docs( - self, inputs: dict[str, Any], *, run_manager: AsyncCallbackManagerForChainRun + self, + inputs: dict[str, Any], + *, + run_manager: AsyncCallbackManagerForChainRun, ) -> list[Document]: question = inputs[self.question_key] docs = await self.retriever.ainvoke( - question, config={"callbacks": run_manager.get_child()} + question, + config={"callbacks": run_manager.get_child()}, ) return self._reduce_tokens_below_limit(docs) diff --git a/libs/langchain/langchain/chains/qa_with_sources/stuff_prompt.py b/libs/langchain/langchain/chains/qa_with_sources/stuff_prompt.py index 82290ee0507..11e844812ae 100644 --- a/libs/langchain/langchain/chains/qa_with_sources/stuff_prompt.py +++ b/libs/langchain/langchain/chains/qa_with_sources/stuff_prompt.py @@ -1,7 +1,6 @@ -# flake8: noqa from langchain_core.prompts import PromptTemplate -template = """Given the following extracted parts of a long document and a question, create a final answer with references ("SOURCES"). +template = """Given the following extracted parts of a long document and a question, create a final answer with references ("SOURCES"). If you don't know the answer, just say that you don't know. Don't try to make up an answer. ALWAYS return a "SOURCES" part in your answer. @@ -19,13 +18,13 @@ SOURCES: 28-pl QUESTION: What did the president say about Michael Jackson? ========= -Content: Madam Speaker, Madam Vice President, our First Lady and Second Gentleman. Members of Congress and the Cabinet. Justices of the Supreme Court. My fellow Americans. \n\nLast year COVID-19 kept us apart. This year we are finally together again. \n\nTonight, we meet as Democrats Republicans and Independents. 
But most importantly as Americans. \n\nWith a duty to one another to the American people to the Constitution. \n\nAnd with an unwavering resolve that freedom will always triumph over tyranny. \n\nSix days ago, Russia’s Vladimir Putin sought to shake the foundations of the free world thinking he could make it bend to his menacing ways. But he badly miscalculated. \n\nHe thought he could roll into Ukraine and the world would roll over. Instead he met a wall of strength he never imagined. \n\nHe met the Ukrainian people. \n\nFrom President Zelenskyy to every Ukrainian, their fearlessness, their courage, their determination, inspires the world. \n\nGroups of citizens blocking tanks with their bodies. Everyone from students to retirees teachers turned soldiers defending their homeland. +Content: Madam Speaker, Madam Vice President, our First Lady and Second Gentleman. Members of Congress and the Cabinet. Justices of the Supreme Court. My fellow Americans. \n\nLast year COVID-19 kept us apart. This year we are finally together again. \n\nTonight, we meet as Democrats Republicans and Independents. But most importantly as Americans. \n\nWith a duty to one another to the American people to the Constitution. \n\nAnd with an unwavering resolve that freedom will always triumph over tyranny. \n\nSix days ago, Russia's Vladimir Putin sought to shake the foundations of the free world thinking he could make it bend to his menacing ways. But he badly miscalculated. \n\nHe thought he could roll into Ukraine and the world would roll over. Instead he met a wall of strength he never imagined. \n\nHe met the Ukrainian people. \n\nFrom President Zelenskyy to every Ukrainian, their fearlessness, their courage, their determination, inspires the world. \n\nGroups of citizens blocking tanks with their bodies. Everyone from students to retirees teachers turned soldiers defending their homeland. Source: 0-pl -Content: And we won’t stop. \n\nWe have lost so much to COVID-19. Time with one another. And worst of all, so much loss of life. \n\nLet’s use this moment to reset. Let’s stop looking at COVID-19 as a partisan dividing line and see it for what it is: A God-awful disease. \n\nLet’s stop seeing each other as enemies, and start seeing each other for who we really are: Fellow Americans. \n\nWe can’t change how divided we’ve been. But we can change how we move forward—on COVID-19 and other issues we must face together. \n\nI recently visited the New York City Police Department days after the funerals of Officer Wilbert Mora and his partner, Officer Jason Rivera. \n\nThey were responding to a 9-1-1 call when a man shot and killed them with a stolen gun. \n\nOfficer Mora was 27 years old. \n\nOfficer Rivera was 22. \n\nBoth Dominican Americans who’d grown up on the same streets they later chose to patrol as police officers. \n\nI spoke with their families and told them that we are forever in debt for their sacrifice, and we will carry on their mission to restore the trust and safety every community deserves. +Content: And we won't stop. \n\nWe have lost so much to COVID-19. Time with one another. And worst of all, so much loss of life. \n\nLet's use this moment to reset. Let's stop looking at COVID-19 as a partisan dividing line and see it for what it is: A God-awful disease. \n\nLet's stop seeing each other as enemies, and start seeing each other for who we really are: Fellow Americans. \n\nWe can't change how divided we've been. But we can change how we move forward—on COVID-19 and other issues we must face together. 
\n\nI recently visited the New York City Police Department days after the funerals of Officer Wilbert Mora and his partner, Officer Jason Rivera. \n\nThey were responding to a 9-1-1 call when a man shot and killed them with a stolen gun. \n\nOfficer Mora was 27 years old. \n\nOfficer Rivera was 22. \n\nBoth Dominican Americans who'd grown up on the same streets they later chose to patrol as police officers. \n\nI spoke with their families and told them that we are forever in debt for their sacrifice, and we will carry on their mission to restore the trust and safety every community deserves. Source: 24-pl -Content: And a proud Ukrainian people, who have known 30 years of independence, have repeatedly shown that they will not tolerate anyone who tries to take their country backwards. \n\nTo all Americans, I will be honest with you, as I’ve always promised. A Russian dictator, invading a foreign country, has costs around the world. \n\nAnd I’m taking robust action to make sure the pain of our sanctions is targeted at Russia’s economy. And I will use every tool at our disposal to protect American businesses and consumers. \n\nTonight, I can announce that the United States has worked with 30 other countries to release 60 Million barrels of oil from reserves around the world. \n\nAmerica will lead that effort, releasing 30 Million barrels from our own Strategic Petroleum Reserve. And we stand ready to do more if necessary, unified with our allies. \n\nThese steps will help blunt gas prices here at home. And I know the news about what’s happening can seem alarming. \n\nBut I want you to know that we are going to be okay. +Content: And a proud Ukrainian people, who have known 30 years of independence, have repeatedly shown that they will not tolerate anyone who tries to take their country backwards. \n\nTo all Americans, I will be honest with you, as I've always promised. A Russian dictator, invading a foreign country, has costs around the world. \n\nAnd I'm taking robust action to make sure the pain of our sanctions is targeted at Russia's economy. And I will use every tool at our disposal to protect American businesses and consumers. \n\nTonight, I can announce that the United States has worked with 30 other countries to release 60 Million barrels of oil from reserves around the world. \n\nAmerica will lead that effort, releasing 30 Million barrels from our own Strategic Petroleum Reserve. And we stand ready to do more if necessary, unified with our allies. \n\nThese steps will help blunt gas prices here at home. And I know the news about what's happening can seem alarming. \n\nBut I want you to know that we are going to be okay. Source: 5-pl -Content: More support for patients and families. \n\nTo get there, I call on Congress to fund ARPA-H, the Advanced Research Projects Agency for Health. \n\nIt’s based on DARPA—the Defense Department project that led to the Internet, GPS, and so much more. \n\nARPA-H will have a singular purpose—to drive breakthroughs in cancer, Alzheimer’s, diabetes, and more. \n\nA unity agenda for the nation. \n\nWe can do this. \n\nMy fellow Americans—tonight , we have gathered in a sacred space—the citadel of our democracy. \n\nIn this Capitol, generation after generation, Americans have debated great questions amid great strife, and have done great things. \n\nWe have fought for freedom, expanded liberty, defeated totalitarianism and terror. \n\nAnd built the strongest, freest, and most prosperous nation the world has ever known. \n\nNow is the hour. 
\n\nOur moment of responsibility. \n\nOur test of resolve and conscience, of history itself. \n\nIt is in this moment that our character is formed. Our purpose is found. Our future is forged. \n\nWell I know this nation. +Content: More support for patients and families. \n\nTo get there, I call on Congress to fund ARPA-H, the Advanced Research Projects Agency for Health. \n\nIt's based on DARPA—the Defense Department project that led to the Internet, GPS, and so much more. \n\nARPA-H will have a singular purpose—to drive breakthroughs in cancer, Alzheimer's, diabetes, and more. \n\nA unity agenda for the nation. \n\nWe can do this. \n\nMy fellow Americans—tonight , we have gathered in a sacred space—the citadel of our democracy. \n\nIn this Capitol, generation after generation, Americans have debated great questions amid great strife, and have done great things. \n\nWe have fought for freedom, expanded liberty, defeated totalitarianism and terror. \n\nAnd built the strongest, freest, and most prosperous nation the world has ever known. \n\nNow is the hour. \n\nOur moment of responsibility. \n\nOur test of resolve and conscience, of history itself. \n\nIt is in this moment that our character is formed. Our purpose is found. Our future is forged. \n\nWell I know this nation. Source: 34-pl ========= FINAL ANSWER: The president did not mention Michael Jackson. @@ -35,7 +34,7 @@ QUESTION: {question} ========= {summaries} ========= -FINAL ANSWER:""" +FINAL ANSWER:""" # noqa: E501 PROMPT = PromptTemplate(template=template, input_variables=["summaries", "question"]) EXAMPLE_PROMPT = PromptTemplate( diff --git a/libs/langchain/langchain/chains/qa_with_sources/vector_db.py b/libs/langchain/langchain/chains/qa_with_sources/vector_db.py index c6cb0e0d019..119604f1dd6 100644 --- a/libs/langchain/langchain/chains/qa_with_sources/vector_db.py +++ b/libs/langchain/langchain/chains/qa_with_sources/vector_db.py @@ -34,7 +34,8 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain): num_docs = len(docs) if self.reduce_k_below_max_tokens and isinstance( - self.combine_documents_chain, StuffDocumentsChain + self.combine_documents_chain, + StuffDocumentsChain, ): tokens = [ self.combine_documents_chain.llm_chain._get_num_tokens(doc.page_content) @@ -48,16 +49,24 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain): return docs[:num_docs] def _get_docs( - self, inputs: dict[str, Any], *, run_manager: CallbackManagerForChainRun + self, + inputs: dict[str, Any], + *, + run_manager: CallbackManagerForChainRun, ) -> list[Document]: question = inputs[self.question_key] docs = self.vectorstore.similarity_search( - question, k=self.k, **self.search_kwargs + question, + k=self.k, + **self.search_kwargs, ) return self._reduce_tokens_below_limit(docs) async def _aget_docs( - self, inputs: dict[str, Any], *, run_manager: AsyncCallbackManagerForChainRun + self, + inputs: dict[str, Any], + *, + run_manager: AsyncCallbackManagerForChainRun, ) -> list[Document]: msg = "VectorDBQAWithSourcesChain does not support async" raise NotImplementedError(msg) @@ -67,7 +76,7 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain): def raise_deprecation(cls, values: dict) -> Any: warnings.warn( "`VectorDBQAWithSourcesChain` is deprecated - " - "please use `from langchain.chains import RetrievalQAWithSourcesChain`" + "please use `from langchain.chains import RetrievalQAWithSourcesChain`", ) return values diff --git a/libs/langchain/langchain/chains/query_constructor/base.py 
b/libs/langchain/langchain/chains/query_constructor/base.py index 1e8a7a87c8b..0977cfbb9b8 100644 --- a/libs/langchain/langchain/chains/query_constructor/base.py +++ b/libs/langchain/langchain/chains/query_constructor/base.py @@ -60,7 +60,7 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]): if not parsed.get("limit"): parsed.pop("limit", None) return StructuredQuery( - **{k: v for k, v in parsed.items() if k in allowed_keys} + **{k: v for k, v in parsed.items() if k in allowed_keys}, ) except Exception as e: msg = f"Parsing text\n{text}\n raised following error:\n{e}" @@ -89,7 +89,8 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]): def ast_parse(raw_filter: str) -> Optional[FilterDirective]: filter_directive = cast( - Optional[FilterDirective], get_parser().parse(raw_filter) + Optional[FilterDirective], + get_parser().parse(raw_filter), ) return fix_filter_directive( filter_directive, @@ -235,7 +236,9 @@ def get_query_constructor_prompt( examples = construct_examples(examples) example_prompt = USER_SPECIFIED_EXAMPLE_PROMPT prefix = PREFIX_WITH_DATA_SOURCE.format( - schema=schema, content=document_contents, attributes=attribute_str + schema=schema, + content=document_contents, + attributes=attribute_str, ) suffix = SUFFIX_WITHOUT_DATA_SOURCE.format(i=len(examples) + 1) else: @@ -245,7 +248,9 @@ def get_query_constructor_prompt( example_prompt = EXAMPLE_PROMPT prefix = DEFAULT_PREFIX.format(schema=schema) suffix = DEFAULT_SUFFIX.format( - i=len(examples) + 1, content=document_contents, attributes=attribute_str + i=len(examples) + 1, + content=document_contents, + attributes=attribute_str, ) return FewShotPromptTemplate( examples=list(examples), @@ -300,11 +305,10 @@ def load_query_constructor_chain( enable_limit=enable_limit, schema_prompt=schema_prompt, ) - allowed_attributes = [] - for ainfo in attribute_info: - allowed_attributes.append( - ainfo.name if isinstance(ainfo, AttributeInfo) else ainfo["name"] - ) + allowed_attributes = [ + ainfo.name if isinstance(ainfo, AttributeInfo) else ainfo["name"] + for ainfo in attribute_info + ] output_parser = StructuredQueryOutputParser.from_components( allowed_comparators=allowed_comparators, allowed_operators=allowed_operators, @@ -359,11 +363,10 @@ def load_query_constructor_runnable( schema_prompt=schema_prompt, **kwargs, ) - allowed_attributes = [] - for ainfo in attribute_info: - allowed_attributes.append( - ainfo.name if isinstance(ainfo, AttributeInfo) else ainfo["name"] - ) + allowed_attributes = [ + ainfo.name if isinstance(ainfo, AttributeInfo) else ainfo["name"] + for ainfo in attribute_info + ] output_parser = StructuredQueryOutputParser.from_components( allowed_comparators=allowed_comparators, allowed_operators=allowed_operators, diff --git a/libs/langchain/langchain/chains/query_constructor/parser.py b/libs/langchain/langchain/chains/query_constructor/parser.py index 826835f7334..565531599e5 100644 --- a/libs/langchain/langchain/chains/query_constructor/parser.py +++ b/libs/langchain/langchain/chains/query_constructor/parser.py @@ -161,7 +161,7 @@ class QueryTransformer(Transformer): except ValueError: warnings.warn( "Dates are expected to be provided in ISO 8601 date format " - "(YYYY-MM-DD)." 
+ "(YYYY-MM-DD).", ) return {"date": item, "type": "date"} diff --git a/libs/langchain/langchain/chains/query_constructor/prompt.py b/libs/langchain/langchain/chains/query_constructor/prompt.py index d1355b32663..e1abd57b3d8 100644 --- a/libs/langchain/langchain/chains/query_constructor/prompt.py +++ b/libs/langchain/langchain/chains/query_constructor/prompt.py @@ -1,4 +1,3 @@ -# flake8: noqa from langchain_core.prompts import PromptTemplate SONG_DATA_SOURCE = """\ @@ -30,7 +29,7 @@ FULL_ANSWER = """\ "filter": "and(or(eq(\\"artist\\", \\"Taylor Swift\\"), eq(\\"artist\\", \\"Katy Perry\\")), lt(\\"length\\", 180), eq(\\"genre\\", \\"pop\\"))" }} ```\ -""" +""" # noqa: E501 NO_FILTER_ANSWER = """\ ```json @@ -55,7 +54,7 @@ DEFAULT_EXAMPLES = [ { "i": 1, "data_source": SONG_DATA_SOURCE, - "user_query": "What are songs by Taylor Swift or Katy Perry about teenage romance under 3 minutes long in the dance pop genre", + "user_query": "What are songs by Taylor Swift or Katy Perry about teenage romance under 3 minutes long in the dance pop genre", # noqa: E501 "structured_request": FULL_ANSWER, }, { @@ -70,7 +69,7 @@ EXAMPLES_WITH_LIMIT = [ { "i": 1, "data_source": SONG_DATA_SOURCE, - "user_query": "What are songs by Taylor Swift or Katy Perry about teenage romance under 3 minutes long in the dance pop genre", + "user_query": "What are songs by Taylor Swift or Katy Perry about teenage romance under 3 minutes long in the dance pop genre", # noqa: E501 "structured_request": FULL_ANSWER, }, { @@ -144,7 +143,7 @@ Make sure that filters only use the attributed names with its function names if Make sure that filters only use format `YYYY-MM-DD` when handling date data typed values. Make sure that filters take into account the descriptions of attributes and only make comparisons that are feasible given the type of data being stored. Make sure that filters are only used as needed. If there are no filters that should be applied return "NO_FILTER" for the filter value.\ -""" +""" # noqa: E501 DEFAULT_SCHEMA_PROMPT = PromptTemplate.from_template(DEFAULT_SCHEMA) SCHEMA_WITH_LIMIT = """\ @@ -179,7 +178,7 @@ Make sure that filters only use format `YYYY-MM-DD` when handling date data type Make sure that filters take into account the descriptions of attributes and only make comparisons that are feasible given the type of data being stored. Make sure that filters are only used as needed. If there are no filters that should be applied return "NO_FILTER" for the filter value. Make sure the `limit` is always an int value. It is an optional parameter so leave it blank if it does not make sense. 
-""" +""" # noqa: E501 SCHEMA_WITH_LIMIT_PROMPT = PromptTemplate.from_template(SCHEMA_WITH_LIMIT) DEFAULT_PREFIX = """\ diff --git a/libs/langchain/langchain/chains/question_answering/chain.py b/libs/langchain/langchain/chains/question_answering/chain.py index e42444c6be6..f4398aab200 100644 --- a/libs/langchain/langchain/chains/question_answering/chain.py +++ b/libs/langchain/langchain/chains/question_answering/chain.py @@ -29,7 +29,9 @@ class LoadingCallable(Protocol): """Interface for loading the combine documents chain.""" def __call__( - self, llm: BaseLanguageModel, **kwargs: Any + self, + llm: BaseLanguageModel, + **kwargs: Any, ) -> BaseCombineDocumentsChain: """Callable to load the combine documents chain.""" @@ -194,7 +196,7 @@ def _load_refine_chain( question_prompt or refine_prompts.QUESTION_PROMPT_SELECTOR.get_prompt(llm) ) _refine_prompt = refine_prompt or refine_prompts.REFINE_PROMPT_SELECTOR.get_prompt( - llm + llm, ) initial_chain = LLMChain( llm=llm, @@ -270,5 +272,8 @@ def load_qa_chain( ) raise ValueError(msg) return loader_mapping[chain_type]( - llm, verbose=verbose, callback_manager=callback_manager, **kwargs + llm, + verbose=verbose, + callback_manager=callback_manager, + **kwargs, ) diff --git a/libs/langchain/langchain/chains/question_answering/map_reduce_prompt.py b/libs/langchain/langchain/chains/question_answering/map_reduce_prompt.py index defaa8fea7b..b0bf9d4bada 100644 --- a/libs/langchain/langchain/chains/question_answering/map_reduce_prompt.py +++ b/libs/langchain/langchain/chains/question_answering/map_reduce_prompt.py @@ -1,5 +1,3 @@ -# flake8: noqa -from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model from langchain_core.prompts.chat import ( ChatPromptTemplate, HumanMessagePromptTemplate, @@ -7,18 +5,20 @@ from langchain_core.prompts.chat import ( ) from langchain_core.prompts.prompt import PromptTemplate -question_prompt_template = """Use the following portion of a long document to see if any of the text is relevant to answer the question. +from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model + +question_prompt_template = """Use the following portion of a long document to see if any of the text is relevant to answer the question. Return any relevant text verbatim. {context} Question: {question} -Relevant text, if any:""" +Relevant text, if any:""" # noqa: E501 QUESTION_PROMPT = PromptTemplate( template=question_prompt_template, input_variables=["context", "question"] ) -system_template = """Use the following portion of a long document to see if any of the text is relevant to answer the question. +system_template = """Use the following portion of a long document to see if any of the text is relevant to answer the question. Return any relevant text verbatim. ______________________ -{context}""" +{context}""" # noqa: E501 messages = [ SystemMessagePromptTemplate.from_template(system_template), HumanMessagePromptTemplate.from_template("{question}"), @@ -30,7 +30,7 @@ QUESTION_PROMPT_SELECTOR = ConditionalPromptSelector( default_prompt=QUESTION_PROMPT, conditionals=[(is_chat_model, CHAT_QUESTION_PROMPT)] ) -combine_prompt_template = """Given the following extracted parts of a long document and a question, create a final answer. +combine_prompt_template = """Given the following extracted parts of a long document and a question, create a final answer. If you don't know the answer, just say that you don't know. Don't try to make up an answer. 
QUESTION: Which state/country's law governs the interpretation of the contract? @@ -45,13 +45,13 @@ FINAL ANSWER: This Agreement is governed by English law. QUESTION: What did the president say about Michael Jackson? ========= -Content: Madam Speaker, Madam Vice President, our First Lady and Second Gentleman. Members of Congress and the Cabinet. Justices of the Supreme Court. My fellow Americans. \n\nLast year COVID-19 kept us apart. This year we are finally together again. \n\nTonight, we meet as Democrats Republicans and Independents. But most importantly as Americans. \n\nWith a duty to one another to the American people to the Constitution. \n\nAnd with an unwavering resolve that freedom will always triumph over tyranny. \n\nSix days ago, Russia’s Vladimir Putin sought to shake the foundations of the free world thinking he could make it bend to his menacing ways. But he badly miscalculated. \n\nHe thought he could roll into Ukraine and the world would roll over. Instead he met a wall of strength he never imagined. \n\nHe met the Ukrainian people. \n\nFrom President Zelenskyy to every Ukrainian, their fearlessness, their courage, their determination, inspires the world. \n\nGroups of citizens blocking tanks with their bodies. Everyone from students to retirees teachers turned soldiers defending their homeland. +Content: Madam Speaker, Madam Vice President, our First Lady and Second Gentleman. Members of Congress and the Cabinet. Justices of the Supreme Court. My fellow Americans. \n\nLast year COVID-19 kept us apart. This year we are finally together again. \n\nTonight, we meet as Democrats Republicans and Independents. But most importantly as Americans. \n\nWith a duty to one another to the American people to the Constitution. \n\nAnd with an unwavering resolve that freedom will always triumph over tyranny. \n\nSix days ago, Russia's Vladimir Putin sought to shake the foundations of the free world thinking he could make it bend to his menacing ways. But he badly miscalculated. \n\nHe thought he could roll into Ukraine and the world would roll over. Instead he met a wall of strength he never imagined. \n\nHe met the Ukrainian people. \n\nFrom President Zelenskyy to every Ukrainian, their fearlessness, their courage, their determination, inspires the world. \n\nGroups of citizens blocking tanks with their bodies. Everyone from students to retirees teachers turned soldiers defending their homeland. -Content: And we won’t stop. \n\nWe have lost so much to COVID-19. Time with one another. And worst of all, so much loss of life. \n\nLet’s use this moment to reset. Let’s stop looking at COVID-19 as a partisan dividing line and see it for what it is: A God-awful disease. \n\nLet’s stop seeing each other as enemies, and start seeing each other for who we really are: Fellow Americans. \n\nWe can’t change how divided we’ve been. But we can change how we move forward—on COVID-19 and other issues we must face together. \n\nI recently visited the New York City Police Department days after the funerals of Officer Wilbert Mora and his partner, Officer Jason Rivera. \n\nThey were responding to a 9-1-1 call when a man shot and killed them with a stolen gun. \n\nOfficer Mora was 27 years old. \n\nOfficer Rivera was 22. \n\nBoth Dominican Americans who’d grown up on the same streets they later chose to patrol as police officers. 
\n\nI spoke with their families and told them that we are forever in debt for their sacrifice, and we will carry on their mission to restore the trust and safety every community deserves. +Content: And we won't stop. \n\nWe have lost so much to COVID-19. Time with one another. And worst of all, so much loss of life. \n\nLet's use this moment to reset. Let's stop looking at COVID-19 as a partisan dividing line and see it for what it is: A God-awful disease. \n\nLet's stop seeing each other as enemies, and start seeing each other for who we really are: Fellow Americans. \n\nWe can't change how divided we've been. But we can change how we move forward—on COVID-19 and other issues we must face together. \n\nI recently visited the New York City Police Department days after the funerals of Officer Wilbert Mora and his partner, Officer Jason Rivera. \n\nThey were responding to a 9-1-1 call when a man shot and killed them with a stolen gun. \n\nOfficer Mora was 27 years old. \n\nOfficer Rivera was 22. \n\nBoth Dominican Americans who'd grown up on the same streets they later chose to patrol as police officers. \n\nI spoke with their families and told them that we are forever in debt for their sacrifice, and we will carry on their mission to restore the trust and safety every community deserves. -Content: And a proud Ukrainian people, who have known 30 years of independence, have repeatedly shown that they will not tolerate anyone who tries to take their country backwards. \n\nTo all Americans, I will be honest with you, as I’ve always promised. A Russian dictator, invading a foreign country, has costs around the world. \n\nAnd I’m taking robust action to make sure the pain of our sanctions is targeted at Russia’s economy. And I will use every tool at our disposal to protect American businesses and consumers. \n\nTonight, I can announce that the United States has worked with 30 other countries to release 60 Million barrels of oil from reserves around the world. \n\nAmerica will lead that effort, releasing 30 Million barrels from our own Strategic Petroleum Reserve. And we stand ready to do more if necessary, unified with our allies. \n\nThese steps will help blunt gas prices here at home. And I know the news about what’s happening can seem alarming. \n\nBut I want you to know that we are going to be okay. +Content: And a proud Ukrainian people, who have known 30 years of independence, have repeatedly shown that they will not tolerate anyone who tries to take their country backwards. \n\nTo all Americans, I will be honest with you, as I've always promised. A Russian dictator, invading a foreign country, has costs around the world. \n\nAnd I'm taking robust action to make sure the pain of our sanctions is targeted at Russia's economy. And I will use every tool at our disposal to protect American businesses and consumers. \n\nTonight, I can announce that the United States has worked with 30 other countries to release 60 Million barrels of oil from reserves around the world. \n\nAmerica will lead that effort, releasing 30 Million barrels from our own Strategic Petroleum Reserve. And we stand ready to do more if necessary, unified with our allies. \n\nThese steps will help blunt gas prices here at home. And I know the news about what's happening can seem alarming. \n\nBut I want you to know that we are going to be okay. -Content: More support for patients and families. \n\nTo get there, I call on Congress to fund ARPA-H, the Advanced Research Projects Agency for Health. 
\n\nIt’s based on DARPA—the Defense Department project that led to the Internet, GPS, and so much more. \n\nARPA-H will have a singular purpose—to drive breakthroughs in cancer, Alzheimer’s, diabetes, and more. \n\nA unity agenda for the nation. \n\nWe can do this. \n\nMy fellow Americans—tonight , we have gathered in a sacred space—the citadel of our democracy. \n\nIn this Capitol, generation after generation, Americans have debated great questions amid great strife, and have done great things. \n\nWe have fought for freedom, expanded liberty, defeated totalitarianism and terror. \n\nAnd built the strongest, freest, and most prosperous nation the world has ever known. \n\nNow is the hour. \n\nOur moment of responsibility. \n\nOur test of resolve and conscience, of history itself. \n\nIt is in this moment that our character is formed. Our purpose is found. Our future is forged. \n\nWell I know this nation. +Content: More support for patients and families. \n\nTo get there, I call on Congress to fund ARPA-H, the Advanced Research Projects Agency for Health. \n\nIt's based on DARPA—the Defense Department project that led to the Internet, GPS, and so much more. \n\nARPA-H will have a singular purpose—to drive breakthroughs in cancer, Alzheimer's, diabetes, and more. \n\nA unity agenda for the nation. \n\nWe can do this. \n\nMy fellow Americans—tonight , we have gathered in a sacred space—the citadel of our democracy. \n\nIn this Capitol, generation after generation, Americans have debated great questions amid great strife, and have done great things. \n\nWe have fought for freedom, expanded liberty, defeated totalitarianism and terror. \n\nAnd built the strongest, freest, and most prosperous nation the world has ever known. \n\nNow is the hour. \n\nOur moment of responsibility. \n\nOur test of resolve and conscience, of history itself. \n\nIt is in this moment that our character is formed. Our purpose is found. Our future is forged. \n\nWell I know this nation. ========= FINAL ANSWER: The president did not mention Michael Jackson. @@ -59,15 +59,15 @@ QUESTION: {question} ========= {summaries} ========= -FINAL ANSWER:""" +FINAL ANSWER:""" # noqa: E501 COMBINE_PROMPT = PromptTemplate( template=combine_prompt_template, input_variables=["summaries", "question"] ) -system_template = """Given the following extracted parts of a long document and a question, create a final answer. +system_template = """Given the following extracted parts of a long document and a question, create a final answer. If you don't know the answer, just say that you don't know. Don't try to make up an answer. 
______________________ -{summaries}""" +{summaries}""" # noqa: E501 messages = [ SystemMessagePromptTemplate.from_template(system_template), HumanMessagePromptTemplate.from_template("{question}"), diff --git a/libs/langchain/langchain/chains/question_answering/map_rerank_prompt.py b/libs/langchain/langchain/chains/question_answering/map_rerank_prompt.py index f9547385f92..e2c62785723 100644 --- a/libs/langchain/langchain/chains/question_answering/map_rerank_prompt.py +++ b/libs/langchain/langchain/chains/question_answering/map_rerank_prompt.py @@ -1,7 +1,7 @@ -# flake8: noqa -from langchain.output_parsers.regex import RegexParser from langchain_core.prompts import PromptTemplate +from langchain.output_parsers.regex import RegexParser + output_parser = RegexParser( regex=r"(.*?)\nScore: (\d*)", output_keys=["answer", "score"], @@ -58,7 +58,7 @@ Context: {context} --------- Question: {question} -Helpful Answer:""" +Helpful Answer:""" # noqa: E501 PROMPT = PromptTemplate( template=prompt_template, input_variables=["context", "question"], diff --git a/libs/langchain/langchain/chains/question_answering/refine_prompts.py b/libs/langchain/langchain/chains/question_answering/refine_prompts.py index 303d94ec560..459a3298352 100644 --- a/libs/langchain/langchain/chains/question_answering/refine_prompts.py +++ b/libs/langchain/langchain/chains/question_answering/refine_prompts.py @@ -1,13 +1,10 @@ -# flake8: noqa -from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model from langchain_core.prompts.chat import ( - AIMessagePromptTemplate, ChatPromptTemplate, - HumanMessagePromptTemplate, - SystemMessagePromptTemplate, ) from langchain_core.prompts.prompt import PromptTemplate +from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model + DEFAULT_REFINE_PROMPT_TMPL = ( "The original question is as follows: {question}\n" "We have provided an existing answer: {existing_answer}\n" diff --git a/libs/langchain/langchain/chains/question_answering/stuff_prompt.py b/libs/langchain/langchain/chains/question_answering/stuff_prompt.py index ee006433852..5f52c87d940 100644 --- a/libs/langchain/langchain/chains/question_answering/stuff_prompt.py +++ b/libs/langchain/langchain/chains/question_answering/stuff_prompt.py @@ -1,5 +1,3 @@ -# flake8: noqa -from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model from langchain_core.prompts import PromptTemplate from langchain_core.prompts.chat import ( ChatPromptTemplate, @@ -7,20 +5,22 @@ from langchain_core.prompts.chat import ( SystemMessagePromptTemplate, ) +from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model + prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. {context} Question: {question} -Helpful Answer:""" +Helpful Answer:""" # noqa: E501 PROMPT = PromptTemplate( template=prompt_template, input_variables=["context", "question"] ) -system_template = """Use the following pieces of context to answer the user's question. +system_template = """Use the following pieces of context to answer the user's question. If you don't know the answer, just say that you don't know, don't try to make up an answer. 
---------------- -{context}""" +{context}""" # noqa: E501 messages = [ SystemMessagePromptTemplate.from_template(system_template), HumanMessagePromptTemplate.from_template("{question}"), diff --git a/libs/langchain/langchain/chains/retrieval_qa/base.py b/libs/langchain/langchain/chains/retrieval_qa/base.py index a6a7da9fe2e..2409eb6218c 100644 --- a/libs/langchain/langchain/chains/retrieval_qa/base.py +++ b/libs/langchain/langchain/chains/retrieval_qa/base.py @@ -84,10 +84,14 @@ class BaseRetrievalQA(Chain): """Initialize from LLM.""" _prompt = prompt or PROMPT_SELECTOR.get_prompt(llm) llm_chain = LLMChain( - llm=llm, prompt=_prompt, callbacks=callbacks, **(llm_chain_kwargs or {}) + llm=llm, + prompt=_prompt, + callbacks=callbacks, + **(llm_chain_kwargs or {}), ) document_prompt = PromptTemplate( - input_variables=["page_content"], template="Context:\n{page_content}" + input_variables=["page_content"], + template="Context:\n{page_content}", ) combine_documents_chain = StuffDocumentsChain( llm_chain=llm_chain, @@ -113,7 +117,9 @@ class BaseRetrievalQA(Chain): """Load chain from chain type.""" _chain_type_kwargs = chain_type_kwargs or {} combine_documents_chain = load_qa_chain( - llm, chain_type=chain_type, **_chain_type_kwargs + llm, + chain_type=chain_type, + **_chain_type_kwargs, ) return cls(combine_documents_chain=combine_documents_chain, **kwargs) @@ -152,7 +158,9 @@ class BaseRetrievalQA(Chain): else: docs = self._get_docs(question) # type: ignore[call-arg] answer = self.combine_documents_chain.run( - input_documents=docs, question=question, callbacks=_run_manager.get_child() + input_documents=docs, + question=question, + callbacks=_run_manager.get_child(), ) if self.return_source_documents: @@ -194,7 +202,9 @@ class BaseRetrievalQA(Chain): else: docs = await self._aget_docs(question) # type: ignore[call-arg] answer = await self.combine_documents_chain.arun( - input_documents=docs, question=question, callbacks=_run_manager.get_child() + input_documents=docs, + question=question, + callbacks=_run_manager.get_child(), ) if self.return_source_documents: @@ -267,7 +277,8 @@ class RetrievalQA(BaseRetrievalQA): ) -> list[Document]: """Get docs.""" return self.retriever.invoke( - question, config={"callbacks": run_manager.get_child()} + question, + config={"callbacks": run_manager.get_child()}, ) async def _aget_docs( @@ -278,7 +289,8 @@ class RetrievalQA(BaseRetrievalQA): ) -> list[Document]: """Get docs.""" return await self.retriever.ainvoke( - question, config={"callbacks": run_manager.get_child()} + question, + config={"callbacks": run_manager.get_child()}, ) @property @@ -313,7 +325,7 @@ class VectorDBQA(BaseRetrievalQA): def raise_deprecation(cls, values: dict) -> Any: warnings.warn( "`VectorDBQA` is deprecated - " - "please use `from langchain.chains import RetrievalQA`" + "please use `from langchain.chains import RetrievalQA`", ) return values @@ -337,11 +349,15 @@ class VectorDBQA(BaseRetrievalQA): """Get docs.""" if self.search_type == "similarity": docs = self.vectorstore.similarity_search( - question, k=self.k, **self.search_kwargs + question, + k=self.k, + **self.search_kwargs, ) elif self.search_type == "mmr": docs = self.vectorstore.max_marginal_relevance_search( - question, k=self.k, **self.search_kwargs + question, + k=self.k, + **self.search_kwargs, ) else: msg = f"search_type of {self.search_type} not allowed." 
diff --git a/libs/langchain/langchain/chains/retrieval_qa/prompt.py b/libs/langchain/langchain/chains/retrieval_qa/prompt.py index 963c353184e..1b588038a62 100644 --- a/libs/langchain/langchain/chains/retrieval_qa/prompt.py +++ b/libs/langchain/langchain/chains/retrieval_qa/prompt.py @@ -1,4 +1,3 @@ -# flake8: noqa from langchain_core.prompts import PromptTemplate prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. @@ -6,7 +5,7 @@ prompt_template = """Use the following pieces of context to answer the question {context} Question: {question} -Helpful Answer:""" +Helpful Answer:""" # noqa: E501 PROMPT = PromptTemplate( template=prompt_template, input_variables=["context", "question"] ) diff --git a/libs/langchain/langchain/chains/router/base.py b/libs/langchain/langchain/chains/router/base.py index 2091cf5e5ee..ce280b7502a 100644 --- a/libs/langchain/langchain/chains/router/base.py +++ b/libs/langchain/langchain/chains/router/base.py @@ -43,7 +43,9 @@ class RouterChain(Chain, ABC): return Route(result["destination"], result["next_inputs"]) async def aroute( - self, inputs: dict[str, Any], callbacks: Callbacks = None + self, + inputs: dict[str, Any], + callbacks: Callbacks = None, ) -> Route: result = await self.acall(inputs, callbacks=callbacks) return Route(result["destination"], result["next_inputs"]) @@ -93,13 +95,15 @@ class MultiRouteChain(Chain): route = self.router_chain.route(inputs, callbacks=callbacks) _run_manager.on_text( - str(route.destination) + ": " + str(route.next_inputs), verbose=self.verbose + str(route.destination) + ": " + str(route.next_inputs), + verbose=self.verbose, ) if not route.destination: return self.default_chain(route.next_inputs, callbacks=callbacks) if route.destination in self.destination_chains: return self.destination_chains[route.destination]( - route.next_inputs, callbacks=callbacks + route.next_inputs, + callbacks=callbacks, ) if self.silent_errors: return self.default_chain(route.next_inputs, callbacks=callbacks) @@ -116,19 +120,23 @@ class MultiRouteChain(Chain): route = await self.router_chain.aroute(inputs, callbacks=callbacks) await _run_manager.on_text( - str(route.destination) + ": " + str(route.next_inputs), verbose=self.verbose + str(route.destination) + ": " + str(route.next_inputs), + verbose=self.verbose, ) if not route.destination: return await self.default_chain.acall( - route.next_inputs, callbacks=callbacks + route.next_inputs, + callbacks=callbacks, ) if route.destination in self.destination_chains: return await self.destination_chains[route.destination].acall( - route.next_inputs, callbacks=callbacks + route.next_inputs, + callbacks=callbacks, ) if self.silent_errors: return await self.default_chain.acall( - route.next_inputs, callbacks=callbacks + route.next_inputs, + callbacks=callbacks, ) msg = f"Received invalid destination chain name '{route.destination}'" raise ValueError(msg) diff --git a/libs/langchain/langchain/chains/router/embedding_router.py b/libs/langchain/langchain/chains/router/embedding_router.py index aaa4af6b167..4e9cc94e012 100644 --- a/libs/langchain/langchain/chains/router/embedding_router.py +++ b/libs/langchain/langchain/chains/router/embedding_router.py @@ -63,10 +63,12 @@ class EmbeddingRouterChain(RouterChain): """Convenience constructor.""" documents = [] for name, descriptions in names_and_descriptions: - for description in descriptions: - documents.append( + documents.extend( 
+ [ Document(page_content=description, metadata={"name": name}) - ) + for description in descriptions + ] + ) vectorstore = vectorstore_cls.from_documents(documents, embeddings) return cls(vectorstore=vectorstore, **kwargs) @@ -80,10 +82,12 @@ class EmbeddingRouterChain(RouterChain): ) -> EmbeddingRouterChain: """Convenience constructor.""" documents = [] - for name, descriptions in names_and_descriptions: - for description in descriptions: - documents.append( - Document(page_content=description, metadata={"name": name}) - ) + documents.extend( + [ + Document(page_content=description, metadata={"name": name}) + for name, descriptions in names_and_descriptions + for description in descriptions + ] + ) vectorstore = await vectorstore_cls.afrom_documents(documents, embeddings) return cls(vectorstore=vectorstore, **kwargs) diff --git a/libs/langchain/langchain/chains/router/llm_router.py b/libs/langchain/langchain/chains/router/llm_router.py index b62845d0f5d..cfc4de67557 100644 --- a/libs/langchain/langchain/chains/router/llm_router.py +++ b/libs/langchain/langchain/chains/router/llm_router.py @@ -155,7 +155,10 @@ class LLMRouterChain(RouterChain): @classmethod def from_llm( - cls, llm: BaseLanguageModel, prompt: BasePromptTemplate, **kwargs: Any + cls, + llm: BaseLanguageModel, + prompt: BasePromptTemplate, + **kwargs: Any, ) -> LLMRouterChain: """Convenience constructor.""" llm_chain = LLMChain(llm=llm, prompt=prompt) diff --git a/libs/langchain/langchain/chains/router/multi_prompt.py b/libs/langchain/langchain/chains/router/multi_prompt.py index c57b63949b6..c0b85a1c6cd 100644 --- a/libs/langchain/langchain/chains/router/multi_prompt.py +++ b/libs/langchain/langchain/chains/router/multi_prompt.py @@ -157,7 +157,7 @@ class MultiPromptChain(MultiRouteChain): destinations = [f"{p['name']}: {p['description']}" for p in prompt_infos] destinations_str = "\n".join(destinations) router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format( - destinations=destinations_str + destinations=destinations_str, ) router_prompt = PromptTemplate( template=router_template, diff --git a/libs/langchain/langchain/chains/router/multi_retrieval_qa.py b/libs/langchain/langchain/chains/router/multi_retrieval_qa.py index 2eb3bdd4f07..8745b48a582 100644 --- a/libs/langchain/langchain/chains/router/multi_retrieval_qa.py +++ b/libs/langchain/langchain/chains/router/multi_retrieval_qa.py @@ -56,7 +56,7 @@ class MultiRetrievalQAChain(MultiRouteChain): destinations = [f"{r['name']}: {r['description']}" for r in retriever_infos] destinations_str = "\n".join(destinations) router_template = MULTI_RETRIEVAL_ROUTER_TEMPLATE.format( - destinations=destinations_str + destinations=destinations_str, ) router_prompt = PromptTemplate( template=router_template, @@ -75,12 +75,15 @@ class MultiRetrievalQAChain(MultiRouteChain): _default_chain = default_chain elif default_retriever: _default_chain = RetrievalQA.from_llm( - llm, prompt=default_prompt, retriever=default_retriever + llm, + prompt=default_prompt, + retriever=default_retriever, ) else: prompt_template = DEFAULT_TEMPLATE.replace("input", "query") prompt = PromptTemplate( - template=prompt_template, input_variables=["history", "query"] + template=prompt_template, + input_variables=["history", "query"], ) if default_chain_llm is None: msg = ( diff --git a/libs/langchain/langchain/chains/sequential.py b/libs/langchain/langchain/chains/sequential.py index 43e6515027a..bda909fc851 100644 --- a/libs/langchain/langchain/chains/sequential.py +++ 
b/libs/langchain/langchain/chains/sequential.py @@ -118,7 +118,9 @@ class SequentialChain(Chain): callbacks = _run_manager.get_child() for i, chain in enumerate(self.chains): outputs = await chain.acall( - known_values, return_only_outputs=True, callbacks=callbacks + known_values, + return_only_outputs=True, + callbacks=callbacks, ) known_values.update(outputs) return {k: known_values[k] for k in self.output_variables} @@ -181,12 +183,16 @@ class SimpleSequentialChain(Chain): color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))]) for i, chain in enumerate(self.chains): _input = chain.run( - _input, callbacks=_run_manager.get_child(f"step_{i + 1}") + _input, + callbacks=_run_manager.get_child(f"step_{i + 1}"), ) if self.strip_outputs: _input = _input.strip() _run_manager.on_text( - _input, color=color_mapping[str(i)], end="\n", verbose=self.verbose + _input, + color=color_mapping[str(i)], + end="\n", + verbose=self.verbose, ) return {self.output_key: _input} @@ -200,11 +206,15 @@ class SimpleSequentialChain(Chain): color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))]) for i, chain in enumerate(self.chains): _input = await chain.arun( - _input, callbacks=_run_manager.get_child(f"step_{i + 1}") + _input, + callbacks=_run_manager.get_child(f"step_{i + 1}"), ) if self.strip_outputs: _input = _input.strip() await _run_manager.on_text( - _input, color=color_mapping[str(i)], end="\n", verbose=self.verbose + _input, + color=color_mapping[str(i)], + end="\n", + verbose=self.verbose, ) return {self.output_key: _input} diff --git a/libs/langchain/langchain/chains/sql_database/prompt.py b/libs/langchain/langchain/chains/sql_database/prompt.py index 7efe824cbf6..6a072d0dd05 100644 --- a/libs/langchain/langchain/chains/sql_database/prompt.py +++ b/libs/langchain/langchain/chains/sql_database/prompt.py @@ -1,8 +1,6 @@ -# flake8: noqa from langchain_core.output_parsers.list import CommaSeparatedListOutputParser from langchain_core.prompts.prompt import PromptTemplate - PROMPT_SUFFIX = """Only use the following tables: {table_info} @@ -21,7 +19,7 @@ SQLQuery: SQL Query to run SQLResult: Result of the SQLQuery Answer: Final answer here -""" +""" # noqa: E501 PROMPT = PromptTemplate( input_variables=["input", "table_info", "dialect", "top_k"], @@ -35,7 +33,7 @@ Question: {query} Table Names: {table_names} -Relevant Table Names:""" +Relevant Table Names:""" # noqa: E501 DECIDER_PROMPT = PromptTemplate( input_variables=["query", "table_names"], template=_DECIDER_TEMPLATE, @@ -46,7 +44,7 @@ _cratedb_prompt = """You are a CrateDB expert. Given an input question, first cr Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per CrateDB. You can order the results to return the most informative data in the database. Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers. Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table. -Pay attention to use CURRENT_DATE function to get the current date, if the question involves "today". +Pay attention to use CURRENT_DATE function to get the current date, if the question involves "today". 
Use the following format: @@ -55,7 +53,7 @@ SQLQuery: SQL Query to run SQLResult: Result of the SQLQuery Answer: Final answer here -""" +""" # noqa: E501 CRATEDB_PROMPT = PromptTemplate( input_variables=["input", "table_info", "top_k"], @@ -75,7 +73,7 @@ SQLQuery: SQL Query to run SQLResult: Result of the SQLQuery Answer: Final answer here -""" +""" # noqa: E501 DUCKDB_PROMPT = PromptTemplate( input_variables=["input", "table_info", "top_k"], @@ -95,7 +93,7 @@ SQLQuery: SQL Query to run SQLResult: Result of the SQLQuery Answer: Final answer here -""" +""" # noqa: E501 GOOGLESQL_PROMPT = PromptTemplate( input_variables=["input", "table_info", "top_k"], @@ -116,7 +114,7 @@ SQLQuery: SQL Query to run SQLResult: Result of the SQLQuery Answer: Final answer here -""" +""" # noqa: E501 MSSQL_PROMPT = PromptTemplate( input_variables=["input", "table_info", "top_k"], @@ -137,7 +135,7 @@ SQLQuery: SQL Query to run SQLResult: Result of the SQLQuery Answer: Final answer here -""" +""" # noqa: E501 MYSQL_PROMPT = PromptTemplate( input_variables=["input", "table_info", "top_k"], @@ -158,7 +156,7 @@ SQLQuery: SQL Query to run SQLResult: Result of the SQLQuery Answer: Final answer here -""" +""" # noqa: E501 MARIADB_PROMPT = PromptTemplate( input_variables=["input", "table_info", "top_k"], @@ -179,7 +177,7 @@ SQLQuery: SQL Query to run SQLResult: Result of the SQLQuery Answer: Final answer here -""" +""" # noqa: E501 ORACLE_PROMPT = PromptTemplate( input_variables=["input", "table_info", "top_k"], @@ -200,7 +198,7 @@ SQLQuery: SQL Query to run SQLResult: Result of the SQLQuery Answer: Final answer here -""" +""" # noqa: E501 POSTGRES_PROMPT = PromptTemplate( input_variables=["input", "table_info", "top_k"], @@ -221,7 +219,7 @@ SQLQuery: SQL Query to run SQLResult: Result of the SQLQuery Answer: Final answer here -""" +""" # noqa: E501 SQLITE_PROMPT = PromptTemplate( input_variables=["input", "table_info", "top_k"], @@ -241,7 +239,7 @@ SQLQuery: "SQL Query to run" SQLResult: "Result of the SQLQuery" Answer: "Final answer here" -""" +""" # noqa: E501 CLICKHOUSE_PROMPT = PromptTemplate( input_variables=["input", "table_info", "top_k"], @@ -261,7 +259,7 @@ SQLQuery: "SQL Query to run" SQLResult: "Result of the SQLQuery" Answer: "Final answer here" -""" +""" # noqa: E501 PRESTODB_PROMPT = PromptTemplate( input_variables=["input", "table_info", "top_k"], diff --git a/libs/langchain/langchain/chains/sql_database/query.py b/libs/langchain/langchain/chains/sql_database/query.py index 09e8e12319b..ef3ccce7155 100644 --- a/libs/langchain/langchain/chains/sql_database/query.py +++ b/libs/langchain/langchain/chains/sql_database/query.py @@ -121,7 +121,7 @@ def create_sql_query_chain( else: prompt_to_use = PROMPT if {"input", "top_k", "table_info"}.difference( - prompt_to_use.input_variables + list(prompt_to_use.partial_variables) + prompt_to_use.input_variables + list(prompt_to_use.partial_variables), ): msg = ( f"Prompt must have input variables: 'input', 'top_k', " diff --git a/libs/langchain/langchain/chains/structured_output/base.py b/libs/langchain/langchain/chains/structured_output/base.py index 8e64a268f92..88cccc5fea8 100644 --- a/libs/langchain/langchain/chains/structured_output/base.py +++ b/libs/langchain/langchain/chains/structured_output/base.py @@ -190,7 +190,9 @@ def create_structured_output_runnable( enforce_function_usage: bool = True, return_single: bool = True, mode: Literal[ - "openai-functions", "openai-tools", "openai-json" + "openai-functions", + "openai-tools", + "openai-json", ] = 
"openai-functions", **kwargs: Any, ) -> Runnable: @@ -391,7 +393,8 @@ def create_structured_output_runnable( """ # noqa: E501 # for backwards compatibility force_function_usage = kwargs.get( - "enforce_single_function_usage", enforce_function_usage + "enforce_single_function_usage", + enforce_function_usage, ) if mode == "openai-tools": @@ -428,7 +431,11 @@ def create_structured_output_runnable( ) raise ValueError(msg) return _create_openai_json_runnable( - output_schema, llm, prompt=prompt, output_parser=output_parser, **kwargs + output_schema, + llm, + prompt=prompt, + output_parser=output_parser, + **kwargs, ) msg = ( f"Invalid mode {mode}. Expected one of 'openai-tools', 'openai-functions', " @@ -454,7 +461,8 @@ def _create_openai_tools_runnable( "function": {"name": oai_tool["function"]["name"]}, } output_parser = output_parser or _get_openai_tool_output_parser( - tool, first_tool_only=first_tool_only + tool, + first_tool_only=first_tool_only, ) if prompt: return prompt | llm.bind(**llm_kwargs) | output_parser @@ -473,7 +481,8 @@ def _get_openai_tool_output_parser( else: key_name = convert_to_openai_tool(tool)["function"]["name"] output_parser = JsonOutputKeyToolsParser( - first_tool_only=first_tool_only, key_name=key_name + first_tool_only=first_tool_only, + key_name=key_name, ) return output_parser @@ -561,7 +570,8 @@ def _create_openai_functions_structured_output_runnable( function = _OutputFormatter output_parser = output_parser or PydanticAttrOutputFunctionsParser( - pydantic_schema=_OutputFormatter, attr_name="output" + pydantic_schema=_OutputFormatter, + attr_name="output", ) return create_openai_fn_runnable( [function], diff --git a/libs/langchain/langchain/chains/summarize/chain.py b/libs/langchain/langchain/chains/summarize/chain.py index 1675292b61b..acced32963e 100644 --- a/libs/langchain/langchain/chains/summarize/chain.py +++ b/libs/langchain/langchain/chains/summarize/chain.py @@ -20,7 +20,9 @@ class LoadingCallable(Protocol): """Interface for loading the combine documents chain.""" def __call__( - self, llm: BaseLanguageModel, **kwargs: Any + self, + llm: BaseLanguageModel, + **kwargs: Any, ) -> BaseCombineDocumentsChain: """Callable to load the combine documents chain.""" diff --git a/libs/langchain/langchain/chains/summarize/map_reduce_prompt.py b/libs/langchain/langchain/chains/summarize/map_reduce_prompt.py index 3cf06395c63..72031fd421e 100644 --- a/libs/langchain/langchain/chains/summarize/map_reduce_prompt.py +++ b/libs/langchain/langchain/chains/summarize/map_reduce_prompt.py @@ -1,4 +1,3 @@ -# flake8: noqa from langchain_core.prompts import PromptTemplate prompt_template = """Write a concise summary of the following: diff --git a/libs/langchain/langchain/chains/summarize/stuff_prompt.py b/libs/langchain/langchain/chains/summarize/stuff_prompt.py index 3cf06395c63..72031fd421e 100644 --- a/libs/langchain/langchain/chains/summarize/stuff_prompt.py +++ b/libs/langchain/langchain/chains/summarize/stuff_prompt.py @@ -1,4 +1,3 @@ -# flake8: noqa from langchain_core.prompts import PromptTemplate prompt_template = """Write a concise summary of the following: diff --git a/libs/langchain/langchain/chains/transform.py b/libs/langchain/langchain/chains/transform.py index 004d4a43efc..a241b8dc165 100644 --- a/libs/langchain/langchain/chains/transform.py +++ b/libs/langchain/langchain/chains/transform.py @@ -79,6 +79,6 @@ class TransformChain(Chain): return await self.atransform_cb(inputs) self._log_once( "TransformChain's atransform is not provided, falling" - " back to 
synchronous transform" + " back to synchronous transform", ) return self.transform_cb(inputs) diff --git a/libs/langchain/langchain/chat_models/baidu_qianfan_endpoint.py b/libs/langchain/langchain/chat_models/baidu_qianfan_endpoint.py index e9db4d9bc6c..eda066ba287 100644 --- a/libs/langchain/langchain/chat_models/baidu_qianfan_endpoint.py +++ b/libs/langchain/langchain/chat_models/baidu_qianfan_endpoint.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "QianfanChatEndpoint": "langchain_community.chat_models.baidu_qianfan_endpoint" + "QianfanChatEndpoint": "langchain_community.chat_models.baidu_qianfan_endpoint", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/chat_models/base.py b/libs/langchain/langchain/chat_models/base.py index 2c465d97567..9a5765d46a8 100644 --- a/libs/langchain/langchain/chat_models/base.py +++ b/libs/langchain/langchain/chat_models/base.py @@ -315,12 +315,14 @@ def init_chat_model( warnings.warn( f"{config_prefix=} has been set but no fields are configurable. Set " f"`configurable_fields=(...)` to specify the model params that are " - f"configurable." + f"configurable.", ) if not configurable_fields: return _init_chat_model_helper( - cast(str, model), model_provider=model_provider, **kwargs + cast(str, model), + model_provider=model_provider, + **kwargs, ) if model: kwargs["model"] = model @@ -334,7 +336,10 @@ def init_chat_model( def _init_chat_model_helper( - model: str, *, model_provider: Optional[str] = None, **kwargs: Any + model: str, + *, + model_provider: Optional[str] = None, + **kwargs: Any, ) -> BaseChatModel: model, model_provider = _parse_model(model, model_provider) if model_provider == "openai": @@ -566,7 +571,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]): else config_prefix ) self._queued_declarative_operations: list[tuple[str, tuple, dict]] = list( - queued_declarative_operations + queued_declarative_operations, ) def __getattr__(self, name: str) -> Any: @@ -578,7 +583,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]): # self._model()). def queue(*args: Any, **kwargs: Any) -> _ConfigurableModel: queued_declarative_operations = list( - self._queued_declarative_operations + self._queued_declarative_operations, ) queued_declarative_operations.append((name, args, kwargs)) return _ConfigurableModel( @@ -640,7 +645,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]): "with_config", (), {"config": remaining_config}, - ) + ), ) return _ConfigurableModel( default_config={**self._default_config, **model_params}, @@ -719,12 +724,18 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]): if isinstance(config, list): config = config[0] return self._model(config).batch( - inputs, config=config, return_exceptions=return_exceptions, **kwargs + inputs, + config=config, + return_exceptions=return_exceptions, + **kwargs, ) # If multiple configs default to Runnable.batch which uses executor to invoke # in parallel. 
return super().batch( - inputs, config=config, return_exceptions=return_exceptions, **kwargs + inputs, + config=config, + return_exceptions=return_exceptions, + **kwargs, ) async def abatch( @@ -741,12 +752,18 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]): if isinstance(config, list): config = config[0] return await self._model(config).abatch( - inputs, config=config, return_exceptions=return_exceptions, **kwargs + inputs, + config=config, + return_exceptions=return_exceptions, + **kwargs, ) # If multiple configs default to Runnable.batch which uses executor to invoke # in parallel. return await super().abatch( - inputs, config=config, return_exceptions=return_exceptions, **kwargs + inputs, + config=config, + return_exceptions=return_exceptions, + **kwargs, ) def batch_as_completed( @@ -763,13 +780,19 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]): if isinstance(config, list): config = config[0] yield from self._model(cast(RunnableConfig, config)).batch_as_completed( # type: ignore[call-overload] - inputs, config=config, return_exceptions=return_exceptions, **kwargs + inputs, + config=config, + return_exceptions=return_exceptions, + **kwargs, ) # If multiple configs default to Runnable.batch which uses executor to invoke # in parallel. else: yield from super().batch_as_completed( # type: ignore[call-overload] - inputs, config=config, return_exceptions=return_exceptions, **kwargs + inputs, + config=config, + return_exceptions=return_exceptions, + **kwargs, ) async def abatch_as_completed( @@ -786,16 +809,22 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]): if isinstance(config, list): config = config[0] async for x in self._model( - cast(RunnableConfig, config) + cast(RunnableConfig, config), ).abatch_as_completed( # type: ignore[call-overload] - inputs, config=config, return_exceptions=return_exceptions, **kwargs + inputs, + config=config, + return_exceptions=return_exceptions, + **kwargs, ): yield x # If multiple configs default to Runnable.batch which uses executor to invoke # in parallel. else: async for x in super().abatch_as_completed( # type: ignore[call-overload] - inputs, config=config, return_exceptions=return_exceptions, **kwargs + inputs, + config=config, + return_exceptions=return_exceptions, + **kwargs, ): yield x @@ -922,6 +951,8 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]): # Explicitly added to satisfy downstream linters. def with_structured_output( - self, schema: Union[dict, type[BaseModel]], **kwargs: Any + self, + schema: Union[dict, type[BaseModel]], + **kwargs: Any, ) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]: return self.__getattr__("with_structured_output")(schema, **kwargs) diff --git a/libs/langchain/langchain/chat_models/meta.py b/libs/langchain/langchain/chat_models/meta.py index 22fa97e08c4..72b0e623f86 100644 --- a/libs/langchain/langchain/chat_models/meta.py +++ b/libs/langchain/langchain/chat_models/meta.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. 
DEPRECATED_LOOKUP = { - "convert_messages_to_prompt_llama": "langchain_community.chat_models.meta" + "convert_messages_to_prompt_llama": "langchain_community.chat_models.meta", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/chat_models/pai_eas_endpoint.py b/libs/langchain/langchain/chat_models/pai_eas_endpoint.py index 2c3134f63e6..13a166c2ca2 100644 --- a/libs/langchain/langchain/chat_models/pai_eas_endpoint.py +++ b/libs/langchain/langchain/chat_models/pai_eas_endpoint.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "PaiEasChatEndpoint": "langchain_community.chat_models.pai_eas_endpoint" + "PaiEasChatEndpoint": "langchain_community.chat_models.pai_eas_endpoint", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/chat_models/promptlayer_openai.py b/libs/langchain/langchain/chat_models/promptlayer_openai.py index b7a3bb0ddf1..1f5c638982e 100644 --- a/libs/langchain/langchain/chat_models/promptlayer_openai.py +++ b/libs/langchain/langchain/chat_models/promptlayer_openai.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "PromptLayerChatOpenAI": "langchain_community.chat_models.promptlayer_openai" + "PromptLayerChatOpenAI": "langchain_community.chat_models.promptlayer_openai", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/document_loaders/azure_blob_storage_container.py b/libs/langchain/langchain/document_loaders/azure_blob_storage_container.py index 6301a8d6711..a2d8abd5514 100644 --- a/libs/langchain/langchain/document_loaders/azure_blob_storage_container.py +++ b/libs/langchain/langchain/document_loaders/azure_blob_storage_container.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "AzureBlobStorageContainerLoader": "langchain_community.document_loaders" + "AzureBlobStorageContainerLoader": "langchain_community.document_loaders", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/document_loaders/azure_blob_storage_file.py b/libs/langchain/langchain/document_loaders/azure_blob_storage_file.py index 71e30239c49..a30fb571c3d 100644 --- a/libs/langchain/langchain/document_loaders/azure_blob_storage_file.py +++ b/libs/langchain/langchain/document_loaders/azure_blob_storage_file.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. 
DEPRECATED_LOOKUP = { - "AzureBlobStorageFileLoader": "langchain_community.document_loaders" + "AzureBlobStorageFileLoader": "langchain_community.document_loaders", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/document_loaders/baiducloud_bos_file.py b/libs/langchain/langchain/document_loaders/baiducloud_bos_file.py index d18c9bfc742..c42c98b1c65 100644 --- a/libs/langchain/langchain/document_loaders/baiducloud_bos_file.py +++ b/libs/langchain/langchain/document_loaders/baiducloud_bos_file.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "BaiduBOSFileLoader": "langchain_community.document_loaders.baiducloud_bos_file" + "BaiduBOSFileLoader": "langchain_community.document_loaders.baiducloud_bos_file", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/document_loaders/college_confidential.py b/libs/langchain/langchain/document_loaders/college_confidential.py index 0d7b04c866f..58efa00c64a 100644 --- a/libs/langchain/langchain/document_loaders/college_confidential.py +++ b/libs/langchain/langchain/document_loaders/college_confidential.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "CollegeConfidentialLoader": "langchain_community.document_loaders" + "CollegeConfidentialLoader": "langchain_community.document_loaders", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/document_loaders/markdown.py b/libs/langchain/langchain/document_loaders/markdown.py index f2d7cce4ca6..45fb3b8eaf2 100644 --- a/libs/langchain/langchain/document_loaders/markdown.py +++ b/libs/langchain/langchain/document_loaders/markdown.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "UnstructuredMarkdownLoader": "langchain_community.document_loaders" + "UnstructuredMarkdownLoader": "langchain_community.document_loaders", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/document_loaders/org_mode.py b/libs/langchain/langchain/document_loaders/org_mode.py index cec468a2574..c017c971866 100644 --- a/libs/langchain/langchain/document_loaders/org_mode.py +++ b/libs/langchain/langchain/document_loaders/org_mode.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "UnstructuredOrgModeLoader": "langchain_community.document_loaders" + "UnstructuredOrgModeLoader": "langchain_community.document_loaders", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/document_loaders/parsers/generic.py b/libs/langchain/langchain/document_loaders/parsers/generic.py index 5d9c6501a64..60cc92cc4b7 100644 --- a/libs/langchain/langchain/document_loaders/parsers/generic.py +++ b/libs/langchain/langchain/document_loaders/parsers/generic.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. 
DEPRECATED_LOOKUP = { - "MimeTypeBasedParser": "langchain_community.document_loaders.parsers.generic" + "MimeTypeBasedParser": "langchain_community.document_loaders.parsers.generic", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/document_loaders/parsers/html/__init__.py b/libs/langchain/langchain/document_loaders/parsers/html/__init__.py index 107d992a341..ade7424a64c 100644 --- a/libs/langchain/langchain/document_loaders/parsers/html/__init__.py +++ b/libs/langchain/langchain/document_loaders/parsers/html/__init__.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "BS4HTMLParser": "langchain_community.document_loaders.parsers.html.bs4" + "BS4HTMLParser": "langchain_community.document_loaders.parsers.html.bs4", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/document_loaders/parsers/html/bs4.py b/libs/langchain/langchain/document_loaders/parsers/html/bs4.py index 107d992a341..ade7424a64c 100644 --- a/libs/langchain/langchain/document_loaders/parsers/html/bs4.py +++ b/libs/langchain/langchain/document_loaders/parsers/html/bs4.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "BS4HTMLParser": "langchain_community.document_loaders.parsers.html.bs4" + "BS4HTMLParser": "langchain_community.document_loaders.parsers.html.bs4", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/document_loaders/parsers/language/cobol.py b/libs/langchain/langchain/document_loaders/parsers/language/cobol.py index e80b5d65287..e9b64000406 100644 --- a/libs/langchain/langchain/document_loaders/parsers/language/cobol.py +++ b/libs/langchain/langchain/document_loaders/parsers/language/cobol.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "CobolSegmenter": "langchain_community.document_loaders.parsers.language.cobol" + "CobolSegmenter": "langchain_community.document_loaders.parsers.language.cobol", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/document_loaders/parsers/language/python.py b/libs/langchain/langchain/document_loaders/parsers/language/python.py index 79a702792d3..6a5215736ee 100644 --- a/libs/langchain/langchain/document_loaders/parsers/language/python.py +++ b/libs/langchain/langchain/document_loaders/parsers/language/python.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. 
DEPRECATED_LOOKUP = { - "PythonSegmenter": "langchain_community.document_loaders.parsers.language.python" + "PythonSegmenter": "langchain_community.document_loaders.parsers.language.python", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/document_loaders/parsers/msword.py b/libs/langchain/langchain/document_loaders/parsers/msword.py index 24b418508b4..2b22060bfe6 100644 --- a/libs/langchain/langchain/document_loaders/parsers/msword.py +++ b/libs/langchain/langchain/document_loaders/parsers/msword.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "MsWordParser": "langchain_community.document_loaders.parsers.msword" + "MsWordParser": "langchain_community.document_loaders.parsers.msword", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/document_loaders/parsers/registry.py b/libs/langchain/langchain/document_loaders/parsers/registry.py index 1c1cce71a05..4e94058a2b5 100644 --- a/libs/langchain/langchain/document_loaders/parsers/registry.py +++ b/libs/langchain/langchain/document_loaders/parsers/registry.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "get_parser": "langchain_community.document_loaders.parsers.registry" + "get_parser": "langchain_community.document_loaders.parsers.registry", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/document_loaders/powerpoint.py b/libs/langchain/langchain/document_loaders/powerpoint.py index f86145c20bf..309ccb7d0bd 100644 --- a/libs/langchain/langchain/document_loaders/powerpoint.py +++ b/libs/langchain/langchain/document_loaders/powerpoint.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "UnstructuredPowerPointLoader": "langchain_community.document_loaders" + "UnstructuredPowerPointLoader": "langchain_community.document_loaders", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/document_loaders/pyspark_dataframe.py b/libs/langchain/langchain/document_loaders/pyspark_dataframe.py index 1c870359a52..d6259deef95 100644 --- a/libs/langchain/langchain/document_loaders/pyspark_dataframe.py +++ b/libs/langchain/langchain/document_loaders/pyspark_dataframe.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "PySparkDataFrameLoader": "langchain_community.document_loaders.pyspark_dataframe" + "PySparkDataFrameLoader": "langchain_community.document_loaders.pyspark_dataframe", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/document_loaders/tencent_cos_directory.py b/libs/langchain/langchain/document_loaders/tencent_cos_directory.py index a456786da43..f0e0871fb53 100644 --- a/libs/langchain/langchain/document_loaders/tencent_cos_directory.py +++ b/libs/langchain/langchain/document_loaders/tencent_cos_directory.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. 
DEPRECATED_LOOKUP = { - "TencentCOSDirectoryLoader": "langchain_community.document_loaders" + "TencentCOSDirectoryLoader": "langchain_community.document_loaders", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/document_transformers/beautiful_soup_transformer.py b/libs/langchain/langchain/document_transformers/beautiful_soup_transformer.py index 8028cff805c..b4bc9f0d05a 100644 --- a/libs/langchain/langchain/document_transformers/beautiful_soup_transformer.py +++ b/libs/langchain/langchain/document_transformers/beautiful_soup_transformer.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "BeautifulSoupTransformer": "langchain_community.document_transformers" + "BeautifulSoupTransformer": "langchain_community.document_transformers", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/document_transformers/doctran_text_extract.py b/libs/langchain/langchain/document_transformers/doctran_text_extract.py index d6fcb4f6388..f3ae78abd1c 100644 --- a/libs/langchain/langchain/document_transformers/doctran_text_extract.py +++ b/libs/langchain/langchain/document_transformers/doctran_text_extract.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "DoctranPropertyExtractor": "langchain_community.document_transformers" + "DoctranPropertyExtractor": "langchain_community.document_transformers", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/document_transformers/doctran_text_qa.py b/libs/langchain/langchain/document_transformers/doctran_text_qa.py index 21dc7d2b014..1b31dd2b490 100644 --- a/libs/langchain/langchain/document_transformers/doctran_text_qa.py +++ b/libs/langchain/langchain/document_transformers/doctran_text_qa.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "DoctranQATransformer": "langchain_community.document_transformers" + "DoctranQATransformer": "langchain_community.document_transformers", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/document_transformers/doctran_text_translate.py b/libs/langchain/langchain/document_transformers/doctran_text_translate.py index e0c48863259..1e4364efdd9 100644 --- a/libs/langchain/langchain/document_transformers/doctran_text_translate.py +++ b/libs/langchain/langchain/document_transformers/doctran_text_translate.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. 
DEPRECATED_LOOKUP = { - "DoctranTextTranslator": "langchain_community.document_transformers" + "DoctranTextTranslator": "langchain_community.document_transformers", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/document_transformers/google_translate.py b/libs/langchain/langchain/document_transformers/google_translate.py index 7fb2506818c..b110fc5a76d 100644 --- a/libs/langchain/langchain/document_transformers/google_translate.py +++ b/libs/langchain/langchain/document_transformers/google_translate.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "GoogleTranslateTransformer": "langchain_community.document_transformers" + "GoogleTranslateTransformer": "langchain_community.document_transformers", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/document_transformers/html2text.py b/libs/langchain/langchain/document_transformers/html2text.py index 307d455955b..117fad22db7 100644 --- a/libs/langchain/langchain/document_transformers/html2text.py +++ b/libs/langchain/langchain/document_transformers/html2text.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "Html2TextTransformer": "langchain_community.document_transformers" + "Html2TextTransformer": "langchain_community.document_transformers", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/document_transformers/nuclia_text_transform.py b/libs/langchain/langchain/document_transformers/nuclia_text_transform.py index 877ec94ea76..be0d706feca 100644 --- a/libs/langchain/langchain/document_transformers/nuclia_text_transform.py +++ b/libs/langchain/langchain/document_transformers/nuclia_text_transform.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "NucliaTextTransformer": "langchain_community.document_transformers" + "NucliaTextTransformer": "langchain_community.document_transformers", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/embeddings/__init__.py b/libs/langchain/langchain/embeddings/__init__.py index 87b20ba4e80..0038281e703 100644 --- a/libs/langchain/langchain/embeddings/__init__.py +++ b/libs/langchain/langchain/embeddings/__init__.py @@ -83,7 +83,7 @@ class HypotheticalDocumentEmbedder: def __init__(self, *args: Any, **kwargs: Any): logger.warning( "Using a deprecated class. Please use " - "`from langchain.chains import HypotheticalDocumentEmbedder` instead" + "`from langchain.chains import HypotheticalDocumentEmbedder` instead", ) from langchain.chains.hyde.base import HypotheticalDocumentEmbedder as H @@ -93,7 +93,7 @@ class HypotheticalDocumentEmbedder: def from_llm(cls, *args: Any, **kwargs: Any) -> Any: logger.warning( "Using a deprecated class. 
Please use " - "`from langchain.chains import HypotheticalDocumentEmbedder` instead" + "`from langchain.chains import HypotheticalDocumentEmbedder` instead", ) from langchain.chains.hyde.base import HypotheticalDocumentEmbedder as H diff --git a/libs/langchain/langchain/embeddings/base.py b/libs/langchain/langchain/embeddings/base.py index 40f50418ae6..c0a6c8b891d 100644 --- a/libs/langchain/langchain/embeddings/base.py +++ b/libs/langchain/langchain/embeddings/base.py @@ -79,7 +79,9 @@ def _parse_model_string(model_name: str) -> tuple[str, str]: def _infer_model_and_provider( - model: str, *, provider: Optional[str] = None + model: str, + *, + provider: Optional[str] = None, ) -> tuple[str, str]: if not model.strip(): msg = "Model name cannot be empty" diff --git a/libs/langchain/langchain/embeddings/cache.py b/libs/langchain/langchain/embeddings/cache.py index 2fc73627a4f..2f60d096f4a 100644 --- a/libs/langchain/langchain/embeddings/cache.py +++ b/libs/langchain/langchain/embeddings/cache.py @@ -176,7 +176,7 @@ class CacheBackedEmbeddings(Embeddings): A list of embeddings for the given texts. """ vectors: list[Union[list[float], None]] = self.document_embedding_store.mget( - texts + texts, ) all_missing_indices: list[int] = [ i for i, vector in enumerate(vectors) if vector is None @@ -186,13 +186,14 @@ class CacheBackedEmbeddings(Embeddings): missing_texts = [texts[i] for i in missing_indices] missing_vectors = self.underlying_embeddings.embed_documents(missing_texts) self.document_embedding_store.mset( - list(zip(missing_texts, missing_vectors)) + list(zip(missing_texts, missing_vectors)), ) for index, updated_vector in zip(missing_indices, missing_vectors): vectors[index] = updated_vector return cast( - list[list[float]], vectors + list[list[float]], + vectors, ) # Nones should have been resolved by now async def aembed_documents(self, texts: list[str]) -> list[list[float]]: @@ -220,16 +221,17 @@ class CacheBackedEmbeddings(Embeddings): for missing_indices in batch_iterate(self.batch_size, all_missing_indices): missing_texts = [texts[i] for i in missing_indices] missing_vectors = await self.underlying_embeddings.aembed_documents( - missing_texts + missing_texts, ) await self.document_embedding_store.amset( - list(zip(missing_texts, missing_vectors)) + list(zip(missing_texts, missing_vectors)), ) for index, updated_vector in zip(missing_indices, missing_vectors): vectors[index] = updated_vector return cast( - list[list[float]], vectors + list[list[float]], + vectors, ) # Nones should have been resolved by now def embed_query(self, text: str) -> list[float]: @@ -288,7 +290,8 @@ class CacheBackedEmbeddings(Embeddings): batch_size: Optional[int] = None, query_embedding_cache: Union[bool, ByteStore] = False, key_encoder: Union[ - Callable[[str], str], Literal["sha1", "blake2b", "sha256", "sha512"] + Callable[[str], str], + Literal["sha1", "blake2b", "sha256", "sha512"], ] = "sha1", ) -> CacheBackedEmbeddings: """On-ramp that adds the necessary serialization and encoding to the store. 
diff --git a/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py b/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py index d395d317b1a..e31be77d2b8 100644 --- a/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py +++ b/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py @@ -145,7 +145,7 @@ class TrajectoryEvalChain(AgentTrajectoryEvaluator, LLMEvalChain): eval_chain: LLMChain """The language model chain used for evaluation.""" output_parser: TrajectoryOutputParser = Field( - default_factory=TrajectoryOutputParser + default_factory=TrajectoryOutputParser, ) """The output parser used to parse the output.""" return_reasoning: bool = False # :meta private: @@ -174,7 +174,7 @@ class TrajectoryEvalChain(AgentTrajectoryEvaluator, LLMEvalChain): f"""Tool {i}: {tool.name} Description: {tool.description}""" for i, tool in enumerate(self.agent_tools, 1) - ] + ], ) @staticmethod @@ -199,7 +199,7 @@ Tool used: {action.tool} Tool input: {action.tool_input} Tool output: {output}""" for i, (action, output) in enumerate(steps, 1) - ] + ], ) @staticmethod @@ -296,7 +296,8 @@ The following is the expected answer. Use this to measure correctness: chain_input["tool_descriptions"] = self._tools_description _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() raw_output = self.eval_chain.run( - chain_input, callbacks=_run_manager.get_child() + chain_input, + callbacks=_run_manager.get_child(), ) return cast(dict, self.output_parser.parse(raw_output)) @@ -320,7 +321,8 @@ The following is the expected answer. Use this to measure correctness: chain_input["tool_descriptions"] = self._tools_description _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() raw_output = await self.eval_chain.arun( - chain_input, callbacks=_run_manager.get_child() + chain_input, + callbacks=_run_manager.get_child(), ) return cast(dict, self.output_parser.parse(raw_output)) diff --git a/libs/langchain/langchain/evaluation/agents/trajectory_eval_prompt.py b/libs/langchain/langchain/evaluation/agents/trajectory_eval_prompt.py index adefc6500fb..01127451923 100644 --- a/libs/langchain/langchain/evaluation/agents/trajectory_eval_prompt.py +++ b/libs/langchain/langchain/evaluation/agents/trajectory_eval_prompt.py @@ -1,14 +1,11 @@ """Prompt for trajectory evaluation chain.""" -# flake8: noqa -from langchain_core.messages import HumanMessage, AIMessage, SystemMessage - +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from langchain_core.prompts.chat import ( ChatPromptTemplate, HumanMessagePromptTemplate, ) - EVAL_TEMPLATE = """An AI language model has been given access to the following set of tools to help answer a user's question. The tools given to the AI model are: @@ -39,7 +36,7 @@ i. Is the final answer helpful? ii. Does the AI language use a logical sequence of tools to answer the question? iii. Does the AI language model use the tools in a helpful way? iv. Does the AI language model use too many steps to answer the question? -v. Are the appropriate tools used to answer the question?""" +v. Are the appropriate tools used to answer the question?""" # noqa: E501 EXAMPLE_INPUT = """An AI language model has been given access to the following set of tools to help answer a user's question. @@ -84,7 +81,7 @@ i. Is the final answer helpful? ii. Does the AI language use a logical sequence of tools to answer the question? iii. Does the AI language model use the tools in a helpful way? iv. 
Does the AI language model use too many steps to answer the question? -v. Are the appropriate tools used to answer the question?""" +v. Are the appropriate tools used to answer the question?""" # noqa: E501 EXAMPLE_OUTPUT = """First, let's evaluate the final answer. The final uses good reasoning but is wrong. 2,857 divided by 305 is not 17.5.\ The model should have used the calculator to figure this out. Second does the model use a logical sequence of tools to answer the question?\ @@ -92,10 +89,10 @@ The way model uses the search is not helpful. The model should have used the sea The model didn't use the calculator tool and gave an incorrect answer. The search API should be used for current events or specific questions.\ The tools were not used in a helpful way. The model did not use too many steps to answer the question.\ The model did not use the appropriate tools to answer the question.\ - + Judgment: Given the good reasoning in the final answer but otherwise poor performance, we give the model a score of 2. -Score: 2""" +Score: 2""" # noqa: E501 EVAL_CHAT_PROMPT = ChatPromptTemplate.from_messages( messages=[ @@ -134,7 +131,7 @@ i. Is the final answer helpful? ii. Does the AI language use a logical sequence of tools to answer the question? iii. Does the AI language model use the tools in a helpful way? iv. Does the AI language model use too many steps to answer the question? -v. Are the appropriate tools used to answer the question?""" +v. Are the appropriate tools used to answer the question?""" # noqa: E501 TOOL_FREE_EVAL_CHAT_PROMPT = ChatPromptTemplate.from_messages( diff --git a/libs/langchain/langchain/evaluation/comparison/eval_chain.py b/libs/langchain/langchain/evaluation/comparison/eval_chain.py index 3802d915a5d..6d1d03a465b 100644 --- a/libs/langchain/langchain/evaluation/comparison/eval_chain.py +++ b/libs/langchain/langchain/evaluation/comparison/eval_chain.py @@ -187,7 +187,7 @@ class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMEvalChain, LLMChain): output_key: str = "results" #: :meta private: output_parser: BaseOutputParser = Field( - default_factory=PairwiseStringResultOutputParser + default_factory=PairwiseStringResultOutputParser, ) @classmethod @@ -259,7 +259,7 @@ class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMEvalChain, LLMChain): if not hasattr(llm, "model_name") or not llm.model_name.startswith("gpt-4"): logger.warning( "This chain was only tested with GPT-4. \ -Performance may be significantly worse with other models." +Performance may be significantly worse with other models.", ) expected_input_vars = {"prediction", "prediction_b", "input", "criteria"} diff --git a/libs/langchain/langchain/evaluation/comparison/prompt.py b/libs/langchain/langchain/evaluation/comparison/prompt.py index 1eb93ea1a0c..1655f20407c 100644 --- a/libs/langchain/langchain/evaluation/comparison/prompt.py +++ b/libs/langchain/langchain/evaluation/comparison/prompt.py @@ -3,9 +3,8 @@ This prompt is used to compare two responses and evaluate which one best follows the instructions and answers the question. The prompt is based on the paper from Zheng, et. al. 
https://arxiv.org/abs/2306.05685 -""" +""" # noqa: E501 -# flake8: noqa from langchain_core.prompts.chat import ChatPromptTemplate SYSTEM_MESSAGE = 'Please act as an impartial judge and evaluate the quality \ diff --git a/libs/langchain/langchain/evaluation/criteria/eval_chain.py b/libs/langchain/langchain/evaluation/criteria/eval_chain.py index d8c7d39d124..1f72c73c561 100644 --- a/libs/langchain/langchain/evaluation/criteria/eval_chain.py +++ b/libs/langchain/langchain/evaluation/criteria/eval_chain.py @@ -273,7 +273,8 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain): @classmethod def _resolve_prompt( - cls, prompt: Optional[BasePromptTemplate] = None + cls, + prompt: Optional[BasePromptTemplate] = None, ) -> BasePromptTemplate: expected_input_vars = {"input", "output", "criteria"} prompt_ = prompt or PROMPT @@ -529,7 +530,8 @@ class LabeledCriteriaEvalChain(CriteriaEvalChain): @classmethod def _resolve_prompt( - cls, prompt: Optional[BasePromptTemplate] = None + cls, + prompt: Optional[BasePromptTemplate] = None, ) -> BasePromptTemplate: expected_input_vars = {"input", "output", "criteria", "reference"} prompt_ = prompt or PROMPT_WITH_REFERENCES diff --git a/libs/langchain/langchain/evaluation/criteria/prompt.py b/libs/langchain/langchain/evaluation/criteria/prompt.py index e5ac19fe038..74d4569fb4d 100644 --- a/libs/langchain/langchain/evaluation/criteria/prompt.py +++ b/libs/langchain/langchain/evaluation/criteria/prompt.py @@ -1,4 +1,3 @@ -# flake8: noqa # Credit to https://github.com/openai/evals/tree/main from langchain_core.prompts import PromptTemplate @@ -13,7 +12,7 @@ template = """You are assessing a submitted answer on a given task or input base [Criteria]: {criteria} *** [END DATA] -Does the submission meet the Criteria? First, write out in a step by step manner your reasoning about each criterion to be sure that your conclusion is correct. Avoid simply stating the correct answers at the outset. Then print only the single character "Y" or "N" (without quotes or punctuation) on its own line corresponding to the correct answer of whether the submission meets all criteria. At the end, repeat just the letter again by itself on a new line.""" +Does the submission meet the Criteria? First, write out in a step by step manner your reasoning about each criterion to be sure that your conclusion is correct. Avoid simply stating the correct answers at the outset. Then print only the single character "Y" or "N" (without quotes or punctuation) on its own line corresponding to the correct answer of whether the submission meets all criteria. At the end, repeat just the letter again by itself on a new line.""" # noqa: E501 PROMPT = PromptTemplate( input_variables=["input", "output", "criteria"], template=template @@ -31,7 +30,7 @@ template = """You are assessing a submitted answer on a given task or input base [Reference]: {reference} *** [END DATA] -Does the submission meet the Criteria? First, write out in a step by step manner your reasoning about each criterion to be sure that your conclusion is correct. Avoid simply stating the correct answers at the outset. Then print only the single character "Y" or "N" (without quotes or punctuation) on its own line corresponding to the correct answer of whether the submission meets all criteria. At the end, repeat just the letter again by itself on a new line.""" +Does the submission meet the Criteria? First, write out in a step by step manner your reasoning about each criterion to be sure that your conclusion is correct. 
Avoid simply stating the correct answers at the outset. Then print only the single character "Y" or "N" (without quotes or punctuation) on its own line corresponding to the correct answer of whether the submission meets all criteria. At the end, repeat just the letter again by itself on a new line.""" # noqa: E501 PROMPT_WITH_REFERENCES = PromptTemplate( input_variables=["input", "output", "criteria", "reference"], template=template diff --git a/libs/langchain/langchain/evaluation/embedding_distance/base.py b/libs/langchain/langchain/evaluation/embedding_distance/base.py index ca374ba5648..9a173eedb4b 100644 --- a/libs/langchain/langchain/evaluation/embedding_distance/base.py +++ b/libs/langchain/langchain/evaluation/embedding_distance/base.py @@ -42,7 +42,7 @@ def _check_numpy() -> bool: "langchain will use a pure Python implementation for embedding distance " "operations, which may significantly impact performance, especially for large " "datasets. For optimal speed and efficiency, consider installing NumPy: " - "pip install numpy" + "pip install numpy", ) return False @@ -346,7 +346,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator): Dict[str, Any]: The computed score. """ vectors = self.embeddings.embed_documents( - [inputs["prediction"], inputs["reference"]] + [inputs["prediction"], inputs["reference"]], ) if _check_numpy(): np = _import_numpy() @@ -373,7 +373,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator): [ inputs["prediction"], inputs["reference"], - ] + ], ) if _check_numpy(): np = _import_numpy() @@ -451,7 +451,8 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator): class PairwiseEmbeddingDistanceEvalChain( - _EmbeddingDistanceChainMixin, PairwiseStringEvaluator + _EmbeddingDistanceChainMixin, + PairwiseStringEvaluator, ): """Use embedding distances to score semantic difference between two predictions. @@ -494,7 +495,7 @@ class PairwiseEmbeddingDistanceEvalChain( [ inputs["prediction"], inputs["prediction_b"], - ] + ], ) if _check_numpy(): np = _import_numpy() @@ -521,7 +522,7 @@ class PairwiseEmbeddingDistanceEvalChain( [ inputs["prediction"], inputs["prediction_b"], - ] + ], ) if _check_numpy(): np = _import_numpy() diff --git a/libs/langchain/langchain/evaluation/loading.py b/libs/langchain/langchain/evaluation/loading.py index 25c0f67c208..2c0a8bac111 100644 --- a/libs/langchain/langchain/evaluation/loading.py +++ b/libs/langchain/langchain/evaluation/loading.py @@ -73,7 +73,8 @@ def load_dataset(uri: str) -> list[dict]: _EVALUATOR_MAP: dict[ - EvaluatorType, Union[type[LLMEvalChain], type[Chain], type[StringEvaluator]] + EvaluatorType, + Union[type[LLMEvalChain], type[Chain], type[StringEvaluator]], ] = { EvaluatorType.QA: QAEvalChain, EvaluatorType.COT_QA: CotQAEvalChain, diff --git a/libs/langchain/langchain/evaluation/qa/eval_prompt.py b/libs/langchain/langchain/evaluation/qa/eval_prompt.py index d29a7858ac2..49f4d2fda2d 100644 --- a/libs/langchain/langchain/evaluation/qa/eval_prompt.py +++ b/libs/langchain/langchain/evaluation/qa/eval_prompt.py @@ -1,4 +1,3 @@ -# flake8: noqa from langchain_core.prompts import PromptTemplate template = """You are a teacher grading a quiz. @@ -10,12 +9,12 @@ STUDENT ANSWER: student's answer here TRUE ANSWER: true answer here GRADE: CORRECT or INCORRECT here -Grade the student answers based ONLY on their factual accuracy. Ignore differences in punctuation and phrasing between the student answer and true answer. 
It is OK if the student answer contains more information than the true answer, as long as it does not contain any conflicting statements. Begin! +Grade the student answers based ONLY on their factual accuracy. Ignore differences in punctuation and phrasing between the student answer and true answer. It is OK if the student answer contains more information than the true answer, as long as it does not contain any conflicting statements. Begin! QUESTION: {query} STUDENT ANSWER: {result} TRUE ANSWER: {answer} -GRADE:""" +GRADE:""" # noqa: E501 PROMPT = PromptTemplate( input_variables=["query", "result", "answer"], template=template ) @@ -29,12 +28,12 @@ CONTEXT: context the question is about here STUDENT ANSWER: student's answer here GRADE: CORRECT or INCORRECT here -Grade the student answers based ONLY on their factual accuracy. Ignore differences in punctuation and phrasing between the student answer and true answer. It is OK if the student answer contains more information than the true answer, as long as it does not contain any conflicting statements. Begin! +Grade the student answers based ONLY on their factual accuracy. Ignore differences in punctuation and phrasing between the student answer and true answer. It is OK if the student answer contains more information than the true answer, as long as it does not contain any conflicting statements. Begin! QUESTION: {query} CONTEXT: {context} STUDENT ANSWER: {result} -GRADE:""" +GRADE:""" # noqa: E501 CONTEXT_PROMPT = PromptTemplate( input_variables=["query", "context", "result"], template=context_template ) @@ -51,12 +50,12 @@ STUDENT ANSWER: student's answer here EXPLANATION: step by step reasoning here GRADE: CORRECT or INCORRECT here -Grade the student answers based ONLY on their factual accuracy. Ignore differences in punctuation and phrasing between the student answer and true answer. It is OK if the student answer contains more information than the true answer, as long as it does not contain any conflicting statements. Begin! +Grade the student answers based ONLY on their factual accuracy. Ignore differences in punctuation and phrasing between the student answer and true answer. It is OK if the student answer contains more information than the true answer, as long as it does not contain any conflicting statements. Begin! QUESTION: {query} CONTEXT: {context} STUDENT ANSWER: {result} -EXPLANATION:""" +EXPLANATION:""" # noqa: E501 COT_PROMPT = PromptTemplate( input_variables=["query", "context", "result"], template=cot_template ) @@ -72,7 +71,7 @@ template = """You are comparing a submitted answer to an expert answer on a give [Submission]: {result} *** [END DATA] -Compare the content and correctness of the submitted SQL with the expert answer. Ignore any differences in whitespace, style, or output column names. The submitted answer may either be correct or incorrect. Determine which case applies. First, explain in detail the similarities or differences between the expert answer and the submission, ignoring superficial aspects such as whitespace, style or output column names. Do not state the final answer in your initial explanation. Then, respond with either "CORRECT" or "INCORRECT" (without quotes or punctuation) on its own line. This should correspond to whether the submitted SQL and the expert answer are semantically the same or different, respectively. Then, repeat your final answer on a new line.""" +Compare the content and correctness of the submitted SQL with the expert answer. 
Ignore any differences in whitespace, style, or output column names. The submitted answer may either be correct or incorrect. Determine which case applies. First, explain in detail the similarities or differences between the expert answer and the submission, ignoring superficial aspects such as whitespace, style or output column names. Do not state the final answer in your initial explanation. Then, respond with either "CORRECT" or "INCORRECT" (without quotes or punctuation) on its own line. This should correspond to whether the submitted SQL and the expert answer are semantically the same or different, respectively. Then, repeat your final answer on a new line.""" # noqa: E501 SQL_PROMPT = PromptTemplate( input_variables=["query", "answer", "result"], template=template diff --git a/libs/langchain/langchain/evaluation/qa/generate_chain.py b/libs/langchain/langchain/evaluation/qa/generate_chain.py index 94cf36d45a7..41ad4a6ed50 100644 --- a/libs/langchain/langchain/evaluation/qa/generate_chain.py +++ b/libs/langchain/langchain/evaluation/qa/generate_chain.py @@ -13,7 +13,8 @@ from langchain.evaluation.qa.generate_prompt import PROMPT from langchain.output_parsers.regex import RegexParser _QA_OUTPUT_PARSER = RegexParser( - regex=r"QUESTION: (.*?)\n+ANSWER: (.*)", output_keys=["query", "answer"] + regex=r"QUESTION: (.*?)\n+ANSWER: (.*)", + output_keys=["query", "answer"], ) diff --git a/libs/langchain/langchain/evaluation/qa/generate_prompt.py b/libs/langchain/langchain/evaluation/qa/generate_prompt.py index 50dc318b72d..17697361436 100644 --- a/libs/langchain/langchain/evaluation/qa/generate_prompt.py +++ b/libs/langchain/langchain/evaluation/qa/generate_prompt.py @@ -1,8 +1,6 @@ -# flake8: noqa -from langchain.output_parsers.regex import RegexParser from langchain_core.prompts import PromptTemplate -template = """You are a teacher coming up with questions to ask on a quiz. +template = """You are a teacher coming up with questions to ask on a quiz. Given the following document, please generate a question and answer based on that document. 
Example Format: @@ -16,7 +14,7 @@ These questions should be detailed and be based explicitly on information in the {doc} -""" +""" # noqa: E501 PROMPT = PromptTemplate( input_variables=["doc"], template=template, diff --git a/libs/langchain/langchain/evaluation/schema.py b/libs/langchain/langchain/evaluation/schema.py index f83d8dbeedc..3d664bf80b5 100644 --- a/libs/langchain/langchain/evaluation/schema.py +++ b/libs/langchain/langchain/evaluation/schema.py @@ -221,7 +221,10 @@ class StringEvaluator(_EvalArgsMixin, ABC): """ # noqa: E501 self._check_evaluation_args(reference=reference, input_=input) return self._evaluate_strings( - prediction=prediction, reference=reference, input=input, **kwargs + prediction=prediction, + reference=reference, + input=input, + **kwargs, ) async def aevaluate_strings( @@ -244,7 +247,10 @@ class StringEvaluator(_EvalArgsMixin, ABC): """ # noqa: E501 self._check_evaluation_args(reference=reference, input_=input) return await self._aevaluate_strings( - prediction=prediction, reference=reference, input=input, **kwargs + prediction=prediction, + reference=reference, + input=input, + **kwargs, ) diff --git a/libs/langchain/langchain/evaluation/scoring/eval_chain.py b/libs/langchain/langchain/evaluation/scoring/eval_chain.py index 4ae86e00547..d998f4da900 100644 --- a/libs/langchain/langchain/evaluation/scoring/eval_chain.py +++ b/libs/langchain/langchain/evaluation/scoring/eval_chain.py @@ -175,7 +175,7 @@ class ScoreStringEvalChain(StringEvaluator, LLMEvalChain, LLMChain): output_key: str = "results" #: :meta private: output_parser: BaseOutputParser = Field( - default_factory=ScoreStringResultOutputParser + default_factory=ScoreStringResultOutputParser, ) normalize_by: Optional[float] = None """The value to normalize the score by, if specified.""" @@ -262,7 +262,7 @@ class ScoreStringEvalChain(StringEvaluator, LLMEvalChain, LLMChain): if not (hasattr(llm, "model_name") and not llm.model_name.startswith("gpt-4")): logger.warning( "This chain was only tested with GPT-4. \ -Performance may be significantly worse with other models." +Performance may be significantly worse with other models.", ) expected_input_vars = {"prediction", "input", "criteria"} diff --git a/libs/langchain/langchain/evaluation/scoring/prompt.py b/libs/langchain/langchain/evaluation/scoring/prompt.py index 4181c93a820..3aa6c59b650 100644 --- a/libs/langchain/langchain/evaluation/scoring/prompt.py +++ b/libs/langchain/langchain/evaluation/scoring/prompt.py @@ -5,7 +5,6 @@ and answers the question. The prompt is based on the paper from Zheng, et. al. https://arxiv.org/abs/2306.05685 """ -# flake8: noqa from langchain_core.prompts.chat import ChatPromptTemplate SYSTEM_MESSAGE = "You are a helpful assistant." diff --git a/libs/langchain/langchain/evaluation/string_distance/base.py b/libs/langchain/langchain/evaluation/string_distance/base.py index 6c51c9adef4..406b4b8747f 100644 --- a/libs/langchain/langchain/evaluation/string_distance/base.py +++ b/libs/langchain/langchain/evaluation/string_distance/base.py @@ -149,7 +149,8 @@ class _RapidFuzzChainMixin(Chain): Callable: The distance metric function. """ return _RapidFuzzChainMixin._get_metric( - self.distance, normalize_score=self.normalize_score + self.distance, + normalize_score=self.normalize_score, ) def compute_metric(self, a: str, b: str) -> float: @@ -373,7 +374,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi Dict[str, Any]: The evaluation results containing the score. 
""" return { - "score": self.compute_metric(inputs["prediction"], inputs["prediction_b"]) + "score": self.compute_metric(inputs["prediction"], inputs["prediction_b"]), } async def _acall( @@ -393,7 +394,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi Dict[str, Any]: The evaluation results containing the score. """ return { - "score": self.compute_metric(inputs["prediction"], inputs["prediction_b"]) + "score": self.compute_metric(inputs["prediction"], inputs["prediction_b"]), } def _evaluate_string_pairs( diff --git a/libs/langchain/langchain/indexes/_sql_record_manager.py b/libs/langchain/langchain/indexes/_sql_record_manager.py index b96af6f4a0b..d8e3af5443e 100644 --- a/libs/langchain/langchain/indexes/_sql_record_manager.py +++ b/libs/langchain/langchain/indexes/_sql_record_manager.py @@ -307,7 +307,7 @@ class SQLRecordManager(RecordManager): # Note: uses SQLite insert to make on_conflict_do_update work. # This code needs to be generalized a bit to work with more dialects. sqlite_insert_stmt: SqliteInsertType = sqlite_insert( - UpsertionRecord + UpsertionRecord, ).values(records_to_upsert) stmt = sqlite_insert_stmt.on_conflict_do_update( [UpsertionRecord.key, UpsertionRecord.namespace], @@ -323,7 +323,7 @@ class SQLRecordManager(RecordManager): # Note: uses postgresql insert to make on_conflict_do_update work. # This code needs to be generalized a bit to work with more dialects. pg_insert_stmt: PgInsertType = pg_insert(UpsertionRecord).values( - records_to_upsert + records_to_upsert, ) stmt = pg_insert_stmt.on_conflict_do_update( # type: ignore[assignment] "uix_key_namespace", # Name of constraint @@ -389,7 +389,7 @@ class SQLRecordManager(RecordManager): # Note: uses SQLite insert to make on_conflict_do_update work. # This code needs to be generalized a bit to work with more dialects. sqlite_insert_stmt: SqliteInsertType = sqlite_insert( - UpsertionRecord + UpsertionRecord, ).values(records_to_upsert) stmt = sqlite_insert_stmt.on_conflict_do_update( [UpsertionRecord.key, UpsertionRecord.namespace], @@ -405,7 +405,7 @@ class SQLRecordManager(RecordManager): # Note: uses SQLite insert to make on_conflict_do_update work. # This code needs to be generalized a bit to work with more dialects. 
pg_insert_stmt: PgInsertType = pg_insert(UpsertionRecord).values( - records_to_upsert + records_to_upsert, ) stmt = pg_insert_stmt.on_conflict_do_update( # type: ignore[assignment] "uix_key_namespace", # Name of constraint @@ -429,7 +429,7 @@ class SQLRecordManager(RecordManager): and_( UpsertionRecord.key.in_(keys), UpsertionRecord.namespace == self.namespace, - ) + ), ) records = filtered_query.all() found_keys = {r.key for r in records} @@ -445,8 +445,8 @@ class SQLRecordManager(RecordManager): and_( UpsertionRecord.key.in_(keys), UpsertionRecord.namespace == self.namespace, - ) - ) + ), + ), ) ) .scalars() @@ -467,7 +467,7 @@ class SQLRecordManager(RecordManager): session: Session with self._make_session() as session: query: Query = session.query(UpsertionRecord).filter( - UpsertionRecord.namespace == self.namespace + UpsertionRecord.namespace == self.namespace, ) if after: @@ -494,7 +494,7 @@ class SQLRecordManager(RecordManager): session: AsyncSession async with self._amake_session() as session: query: Query = select(UpsertionRecord.key).filter( # type: ignore[assignment] - UpsertionRecord.namespace == self.namespace + UpsertionRecord.namespace == self.namespace, ) # mypy does not recognize .all() or .filter() @@ -518,7 +518,7 @@ class SQLRecordManager(RecordManager): and_( UpsertionRecord.key.in_(keys), UpsertionRecord.namespace == self.namespace, - ) + ), ) filtered_query.delete() @@ -532,8 +532,8 @@ class SQLRecordManager(RecordManager): and_( UpsertionRecord.key.in_(keys), UpsertionRecord.namespace == self.namespace, - ) - ) + ), + ), ) await session.commit() diff --git a/libs/langchain/langchain/indexes/prompts/entity_extraction.py b/libs/langchain/langchain/indexes/prompts/entity_extraction.py index 416ba13eea9..ea959bc28d3 100644 --- a/libs/langchain/langchain/indexes/prompts/entity_extraction.py +++ b/libs/langchain/langchain/indexes/prompts/entity_extraction.py @@ -1,4 +1,3 @@ -# flake8: noqa from langchain_core.prompts.prompt import PromptTemplate _DEFAULT_ENTITY_EXTRACTION_TEMPLATE = """You are an AI assistant reading the transcript of a conversation between an AI and a human. Extract all of the proper nouns from the last line of conversation. As a guideline, a proper noun is generally capitalized. You should definitely extract all names and places. @@ -34,7 +33,7 @@ Conversation history (for reference only): Last line of conversation (for extraction): Human: {input} -Output:""" +Output:""" # noqa: E501 ENTITY_EXTRACTION_PROMPT = PromptTemplate( input_variables=["history", "input"], template=_DEFAULT_ENTITY_EXTRACTION_TEMPLATE ) diff --git a/libs/langchain/langchain/indexes/prompts/entity_summarization.py b/libs/langchain/langchain/indexes/prompts/entity_summarization.py index aa8ec6ef99f..89044f2e91c 100644 --- a/libs/langchain/langchain/indexes/prompts/entity_summarization.py +++ b/libs/langchain/langchain/indexes/prompts/entity_summarization.py @@ -1,4 +1,3 @@ -# flake8: noqa from langchain_core.prompts.prompt import PromptTemplate _DEFAULT_ENTITY_SUMMARIZATION_TEMPLATE = """You are an AI assistant helping a human keep track of facts about relevant people, places, and concepts in their life. Update the summary of the provided entity in the "Entity" section based on the last line of your conversation with the human. If you are writing the summary for the first time, return a single sentence. 
@@ -17,7 +16,7 @@ Existing summary of {entity}: Last line of conversation: Human: {input} -Updated summary:""" +Updated summary:""" # noqa: E501 ENTITY_SUMMARIZATION_PROMPT = PromptTemplate( input_variables=["entity", "summary", "history", "input"], diff --git a/libs/langchain/langchain/indexes/prompts/knowledge_triplet_extraction.py b/libs/langchain/langchain/indexes/prompts/knowledge_triplet_extraction.py index 47f59d9d115..b78b9098e22 100644 --- a/libs/langchain/langchain/indexes/prompts/knowledge_triplet_extraction.py +++ b/libs/langchain/langchain/indexes/prompts/knowledge_triplet_extraction.py @@ -1,5 +1,3 @@ -# flake8: noqa - from langchain_core.prompts.prompt import PromptTemplate KG_TRIPLE_DELIMITER = "<|>" @@ -25,7 +23,7 @@ _DEFAULT_KNOWLEDGE_TRIPLE_EXTRACTION_TEMPLATE = ( "END OF EXAMPLE\n\n" "EXAMPLE\n" "Oh huh. I know Descartes likes to drive antique scooters and play the mandolin.\n" - f"Output: (Descartes, likes to drive, antique scooters){KG_TRIPLE_DELIMITER}(Descartes, plays, mandolin)\n" + f"Output: (Descartes, likes to drive, antique scooters){KG_TRIPLE_DELIMITER}(Descartes, plays, mandolin)\n" # noqa: E501 "END OF EXAMPLE\n\n" "EXAMPLE\n" "{text}" diff --git a/libs/langchain/langchain/indexes/vectorstore.py b/libs/langchain/langchain/indexes/vectorstore.py index 5549b33ae72..e8435821769 100644 --- a/libs/langchain/langchain/indexes/vectorstore.py +++ b/libs/langchain/langchain/indexes/vectorstore.py @@ -58,7 +58,9 @@ class VectorStoreIndexWrapper(BaseModel): raise NotImplementedError(msg) retriever_kwargs = retriever_kwargs or {} chain = RetrievalQA.from_chain_type( - llm, retriever=self.vectorstore.as_retriever(**retriever_kwargs), **kwargs + llm, + retriever=self.vectorstore.as_retriever(**retriever_kwargs), + **kwargs, ) return chain.invoke({chain.input_key: question})[chain.output_key] @@ -91,7 +93,9 @@ class VectorStoreIndexWrapper(BaseModel): raise NotImplementedError(msg) retriever_kwargs = retriever_kwargs or {} chain = RetrievalQA.from_chain_type( - llm, retriever=self.vectorstore.as_retriever(**retriever_kwargs), **kwargs + llm, + retriever=self.vectorstore.as_retriever(**retriever_kwargs), + **kwargs, ) return (await chain.ainvoke({chain.input_key: question}))[chain.output_key] @@ -124,7 +128,9 @@ class VectorStoreIndexWrapper(BaseModel): raise NotImplementedError(msg) retriever_kwargs = retriever_kwargs or {} chain = RetrievalQAWithSourcesChain.from_chain_type( - llm, retriever=self.vectorstore.as_retriever(**retriever_kwargs), **kwargs + llm, + retriever=self.vectorstore.as_retriever(**retriever_kwargs), + **kwargs, ) return chain.invoke({chain.question_key: question}) @@ -157,7 +163,9 @@ class VectorStoreIndexWrapper(BaseModel): raise NotImplementedError(msg) retriever_kwargs = retriever_kwargs or {} chain = RetrievalQAWithSourcesChain.from_chain_type( - llm, retriever=self.vectorstore.as_retriever(**retriever_kwargs), **kwargs + llm, + retriever=self.vectorstore.as_retriever(**retriever_kwargs), + **kwargs, ) return await chain.ainvoke({chain.question_key: question}) @@ -174,7 +182,7 @@ def _get_in_memory_vectorstore() -> type[VectorStore]: warnings.warn( "Using InMemoryVectorStore as the default vectorstore." "This memory store won't persist data. 
You should explicitly" - "specify a vectorstore when using VectorstoreIndexCreator" + "specify a vectorstore when using VectorstoreIndexCreator", ) return InMemoryVectorStore @@ -183,7 +191,7 @@ class VectorstoreIndexCreator(BaseModel): """Logic for creating indexes.""" vectorstore_cls: type[VectorStore] = Field( - default_factory=_get_in_memory_vectorstore + default_factory=_get_in_memory_vectorstore, ) embedding: Embeddings text_splitter: TextSplitter = Field(default_factory=_get_default_text_splitter) @@ -219,8 +227,7 @@ class VectorstoreIndexCreator(BaseModel): """ docs = [] for loader in loaders: - async for doc in loader.alazy_load(): - docs.append(doc) + docs.extend([doc async for doc in loader.alazy_load()]) return await self.afrom_documents(docs) def from_documents(self, documents: list[Document]) -> VectorStoreIndexWrapper: @@ -234,12 +241,15 @@ class VectorstoreIndexCreator(BaseModel): """ sub_docs = self.text_splitter.split_documents(documents) vectorstore = self.vectorstore_cls.from_documents( - sub_docs, self.embedding, **self.vectorstore_kwargs + sub_docs, + self.embedding, + **self.vectorstore_kwargs, ) return VectorStoreIndexWrapper(vectorstore=vectorstore) async def afrom_documents( - self, documents: list[Document] + self, + documents: list[Document], ) -> VectorStoreIndexWrapper: """Asynchronously create a vectorstore index from a list of documents. @@ -251,6 +261,8 @@ class VectorstoreIndexCreator(BaseModel): """ sub_docs = self.text_splitter.split_documents(documents) vectorstore = await self.vectorstore_cls.afrom_documents( - sub_docs, self.embedding, **self.vectorstore_kwargs + sub_docs, + self.embedding, + **self.vectorstore_kwargs, ) return VectorStoreIndexWrapper(vectorstore=vectorstore) diff --git a/libs/langchain/langchain/llms/cloudflare_workersai.py b/libs/langchain/langchain/llms/cloudflare_workersai.py index 13b008e0849..a16ed2cd851 100644 --- a/libs/langchain/langchain/llms/cloudflare_workersai.py +++ b/libs/langchain/langchain/llms/cloudflare_workersai.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. 
 DEPRECATED_LOOKUP = {
-    "CloudflareWorkersAI": "langchain_community.llms.cloudflare_workersai"
+    "CloudflareWorkersAI": "langchain_community.llms.cloudflare_workersai",
 }
 
 _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
diff --git a/libs/langchain/langchain/memory/buffer.py b/libs/langchain/langchain/memory/buffer.py
index 1bd8a48b884..a3bd947e595 100644
--- a/libs/langchain/langchain/memory/buffer.py
+++ b/libs/langchain/langchain/memory/buffer.py
@@ -155,10 +155,12 @@ class ConversationStringBufferMemory(BaseMemory):
             output_key = self.output_key
         human = f"{self.human_prefix}: " + inputs[prompt_input_key]
         ai = f"{self.ai_prefix}: " + outputs[output_key]
-        self.buffer += "\n" + "\n".join([human, ai])
+        self.buffer += f"\n{human}\n{ai}"
 
     async def asave_context(
-        self, inputs: dict[str, Any], outputs: dict[str, str]
+        self,
+        inputs: dict[str, Any],
+        outputs: dict[str, str],
     ) -> None:
         """Save context from this conversation to buffer."""
         return self.save_context(inputs, outputs)
diff --git a/libs/langchain/langchain/memory/chat_memory.py b/libs/langchain/langchain/memory/chat_memory.py
index 4a18ce8d44d..ad2f666cff6 100644
--- a/libs/langchain/langchain/memory/chat_memory.py
+++ b/libs/langchain/langchain/memory/chat_memory.py
@@ -34,14 +34,16 @@ class BaseChatMemory(BaseMemory, ABC):
     """
 
     chat_memory: BaseChatMessageHistory = Field(
-        default_factory=InMemoryChatMessageHistory
+        default_factory=InMemoryChatMessageHistory,
     )
     output_key: Optional[str] = None
     input_key: Optional[str] = None
     return_messages: bool = False
 
     def _get_input_output(
-        self, inputs: dict[str, Any], outputs: dict[str, str]
+        self,
+        inputs: dict[str, Any],
+        outputs: dict[str, str],
     ) -> tuple[str, str]:
         if self.input_key is None:
             prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
@@ -55,7 +57,7 @@ class BaseChatMemory(BaseMemory, ABC):
                 warnings.warn(
                     f"'{self.__class__.__name__}' got multiple output keys:"
                     f" {outputs.keys()}. The default 'output' key is being used."
-                    f" If this is not desired, please manually set 'output_key'."
+                    f" If this is not desired, please manually set 'output_key'.",
                 )
             else:
                 msg = (
@@ -75,11 +77,13 @@ class BaseChatMemory(BaseMemory, ABC):
             [
                 HumanMessage(content=input_str),
                 AIMessage(content=output_str),
-            ]
+            ],
         )
 
     async def asave_context(
-        self, inputs: dict[str, Any], outputs: dict[str, str]
+        self,
+        inputs: dict[str, Any],
+        outputs: dict[str, str],
     ) -> None:
         """Save context from this conversation to buffer."""
         input_str, output_str = self._get_input_output(inputs, outputs)
@@ -87,7 +91,7 @@ class BaseChatMemory(BaseMemory, ABC):
             [
                 HumanMessage(content=input_str),
                 AIMessage(content=output_str),
-            ]
+            ],
         )
 
     def clear(self) -> None:
diff --git a/libs/langchain/langchain/memory/chat_message_histories/astradb.py b/libs/langchain/langchain/memory/chat_message_histories/astradb.py
index 3895f0633ce..feaccaa057e 100644
--- a/libs/langchain/langchain/memory/chat_message_histories/astradb.py
+++ b/libs/langchain/langchain/memory/chat_message_histories/astradb.py
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
 # Used to consolidate logic for raising deprecation warnings and
 # handling optional imports.
 DEPRECATED_LOOKUP = {
-    "AstraDBChatMessageHistory": "langchain_community.chat_message_histories"
+    "AstraDBChatMessageHistory": "langchain_community.chat_message_histories",
 }
 
 _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
diff --git a/libs/langchain/langchain/memory/chat_message_histories/cassandra.py b/libs/langchain/langchain/memory/chat_message_histories/cassandra.py
index e0c61cd2344..d473f255318 100644
--- a/libs/langchain/langchain/memory/chat_message_histories/cassandra.py
+++ b/libs/langchain/langchain/memory/chat_message_histories/cassandra.py
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
 # Used to consolidate logic for raising deprecation warnings and
 # handling optional imports.
 DEPRECATED_LOOKUP = {
-    "CassandraChatMessageHistory": "langchain_community.chat_message_histories"
+    "CassandraChatMessageHistory": "langchain_community.chat_message_histories",
 }
 
 _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
diff --git a/libs/langchain/langchain/memory/chat_message_histories/cosmos_db.py b/libs/langchain/langchain/memory/chat_message_histories/cosmos_db.py
index 6e9c8bf2f72..9e0fb34ccbf 100644
--- a/libs/langchain/langchain/memory/chat_message_histories/cosmos_db.py
+++ b/libs/langchain/langchain/memory/chat_message_histories/cosmos_db.py
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
 # Used to consolidate logic for raising deprecation warnings and
 # handling optional imports.
 DEPRECATED_LOOKUP = {
-    "CosmosDBChatMessageHistory": "langchain_community.chat_message_histories"
+    "CosmosDBChatMessageHistory": "langchain_community.chat_message_histories",
 }
 
 _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
diff --git a/libs/langchain/langchain/memory/chat_message_histories/dynamodb.py b/libs/langchain/langchain/memory/chat_message_histories/dynamodb.py
index e0c5df1a4d1..04564ec23fe 100644
--- a/libs/langchain/langchain/memory/chat_message_histories/dynamodb.py
+++ b/libs/langchain/langchain/memory/chat_message_histories/dynamodb.py
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
 # Used to consolidate logic for raising deprecation warnings and
 # handling optional imports.
 DEPRECATED_LOOKUP = {
-    "DynamoDBChatMessageHistory": "langchain_community.chat_message_histories"
+    "DynamoDBChatMessageHistory": "langchain_community.chat_message_histories",
 }
 
 _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
diff --git a/libs/langchain/langchain/memory/chat_message_histories/elasticsearch.py b/libs/langchain/langchain/memory/chat_message_histories/elasticsearch.py
index 8845e67a586..c76278ed683 100644
--- a/libs/langchain/langchain/memory/chat_message_histories/elasticsearch.py
+++ b/libs/langchain/langchain/memory/chat_message_histories/elasticsearch.py
@@ -11,7 +11,7 @@ if TYPE_CHECKING:
 # Used to consolidate logic for raising deprecation warnings and
 # handling optional imports.
 DEPRECATED_LOOKUP = {
-    "ElasticsearchChatMessageHistory": "langchain_community.chat_message_histories"
+    "ElasticsearchChatMessageHistory": "langchain_community.chat_message_histories",
 }
 
 _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
diff --git a/libs/langchain/langchain/memory/chat_message_histories/file.py b/libs/langchain/langchain/memory/chat_message_histories/file.py
index 2f9a0a9e68b..388af20534b 100644
--- a/libs/langchain/langchain/memory/chat_message_histories/file.py
+++ b/libs/langchain/langchain/memory/chat_message_histories/file.py
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
 # Used to consolidate logic for raising deprecation warnings and
 # handling optional imports.
 DEPRECATED_LOOKUP = {
-    "FileChatMessageHistory": "langchain_community.chat_message_histories"
+    "FileChatMessageHistory": "langchain_community.chat_message_histories",
 }
 
 _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
diff --git a/libs/langchain/langchain/memory/chat_message_histories/firestore.py b/libs/langchain/langchain/memory/chat_message_histories/firestore.py
index 0d81d24d554..23387ab41ef 100644
--- a/libs/langchain/langchain/memory/chat_message_histories/firestore.py
+++ b/libs/langchain/langchain/memory/chat_message_histories/firestore.py
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
 # Used to consolidate logic for raising deprecation warnings and
 # handling optional imports.
 DEPRECATED_LOOKUP = {
-    "FirestoreChatMessageHistory": "langchain_community.chat_message_histories"
+    "FirestoreChatMessageHistory": "langchain_community.chat_message_histories",
 }
 
 _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
diff --git a/libs/langchain/langchain/memory/chat_message_histories/momento.py b/libs/langchain/langchain/memory/chat_message_histories/momento.py
index 65d7cce3031..dbef9d1f535 100644
--- a/libs/langchain/langchain/memory/chat_message_histories/momento.py
+++ b/libs/langchain/langchain/memory/chat_message_histories/momento.py
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
 # Used to consolidate logic for raising deprecation warnings and
 # handling optional imports.
 DEPRECATED_LOOKUP = {
-    "MomentoChatMessageHistory": "langchain_community.chat_message_histories"
+    "MomentoChatMessageHistory": "langchain_community.chat_message_histories",
 }
 
 _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
diff --git a/libs/langchain/langchain/memory/chat_message_histories/mongodb.py b/libs/langchain/langchain/memory/chat_message_histories/mongodb.py
index f0f7db2c462..fb378731900 100644
--- a/libs/langchain/langchain/memory/chat_message_histories/mongodb.py
+++ b/libs/langchain/langchain/memory/chat_message_histories/mongodb.py
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
 # Used to consolidate logic for raising deprecation warnings and
 # handling optional imports.
 DEPRECATED_LOOKUP = {
-    "MongoDBChatMessageHistory": "langchain_community.chat_message_histories"
+    "MongoDBChatMessageHistory": "langchain_community.chat_message_histories",
 }
 
 _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
diff --git a/libs/langchain/langchain/memory/chat_message_histories/neo4j.py b/libs/langchain/langchain/memory/chat_message_histories/neo4j.py
index ac60cd7c349..c3c143bf556 100644
--- a/libs/langchain/langchain/memory/chat_message_histories/neo4j.py
+++ b/libs/langchain/langchain/memory/chat_message_histories/neo4j.py
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
 # Used to consolidate logic for raising deprecation warnings and
 # handling optional imports.
 DEPRECATED_LOOKUP = {
-    "Neo4jChatMessageHistory": "langchain_community.chat_message_histories"
+    "Neo4jChatMessageHistory": "langchain_community.chat_message_histories",
 }
 
 _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
diff --git a/libs/langchain/langchain/memory/chat_message_histories/postgres.py b/libs/langchain/langchain/memory/chat_message_histories/postgres.py
index eb05b7dba47..a98b499c2fc 100644
--- a/libs/langchain/langchain/memory/chat_message_histories/postgres.py
+++ b/libs/langchain/langchain/memory/chat_message_histories/postgres.py
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
 # Used to consolidate logic for raising deprecation warnings and
 # handling optional imports.
 DEPRECATED_LOOKUP = {
-    "PostgresChatMessageHistory": "langchain_community.chat_message_histories"
+    "PostgresChatMessageHistory": "langchain_community.chat_message_histories",
 }
 
 _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
diff --git a/libs/langchain/langchain/memory/chat_message_histories/redis.py b/libs/langchain/langchain/memory/chat_message_histories/redis.py
index a6dc08e2606..b312ec7ed04 100644
--- a/libs/langchain/langchain/memory/chat_message_histories/redis.py
+++ b/libs/langchain/langchain/memory/chat_message_histories/redis.py
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
 # Used to consolidate logic for raising deprecation warnings and
 # handling optional imports.
 DEPRECATED_LOOKUP = {
-    "RedisChatMessageHistory": "langchain_community.chat_message_histories"
+    "RedisChatMessageHistory": "langchain_community.chat_message_histories",
 }
 
 _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
diff --git a/libs/langchain/langchain/memory/chat_message_histories/rocksetdb.py b/libs/langchain/langchain/memory/chat_message_histories/rocksetdb.py
index 64116d5b129..17b589a4eed 100644
--- a/libs/langchain/langchain/memory/chat_message_histories/rocksetdb.py
+++ b/libs/langchain/langchain/memory/chat_message_histories/rocksetdb.py
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
 # Used to consolidate logic for raising deprecation warnings and
 # handling optional imports.
 DEPRECATED_LOOKUP = {
-    "RocksetChatMessageHistory": "langchain_community.chat_message_histories"
+    "RocksetChatMessageHistory": "langchain_community.chat_message_histories",
 }
 
 _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
diff --git a/libs/langchain/langchain/memory/chat_message_histories/singlestoredb.py b/libs/langchain/langchain/memory/chat_message_histories/singlestoredb.py
index dc6b795243f..0c7531d98c6 100644
--- a/libs/langchain/langchain/memory/chat_message_histories/singlestoredb.py
+++ b/libs/langchain/langchain/memory/chat_message_histories/singlestoredb.py
@@ -11,7 +11,7 @@ if TYPE_CHECKING:
 # Used to consolidate logic for raising deprecation warnings and
 # handling optional imports.
 DEPRECATED_LOOKUP = {
-    "SingleStoreDBChatMessageHistory": "langchain_community.chat_message_histories"
+    "SingleStoreDBChatMessageHistory": "langchain_community.chat_message_histories",
 }
 
 _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
diff --git a/libs/langchain/langchain/memory/chat_message_histories/streamlit.py b/libs/langchain/langchain/memory/chat_message_histories/streamlit.py
index b1a74824047..46a9be6f970 100644
--- a/libs/langchain/langchain/memory/chat_message_histories/streamlit.py
+++ b/libs/langchain/langchain/memory/chat_message_histories/streamlit.py
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
 # Used to consolidate logic for raising deprecation warnings and
 # handling optional imports.
 DEPRECATED_LOOKUP = {
-    "StreamlitChatMessageHistory": "langchain_community.chat_message_histories"
+    "StreamlitChatMessageHistory": "langchain_community.chat_message_histories",
 }
 
 _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
diff --git a/libs/langchain/langchain/memory/chat_message_histories/upstash_redis.py b/libs/langchain/langchain/memory/chat_message_histories/upstash_redis.py
index af4599f7bd5..37f2e1a53e2 100644
--- a/libs/langchain/langchain/memory/chat_message_histories/upstash_redis.py
+++ b/libs/langchain/langchain/memory/chat_message_histories/upstash_redis.py
@@ -11,7 +11,7 @@ if TYPE_CHECKING:
 # Used to consolidate logic for raising deprecation warnings and
 # handling optional imports.
 DEPRECATED_LOOKUP = {
-    "UpstashRedisChatMessageHistory": "langchain_community.chat_message_histories"
+    "UpstashRedisChatMessageHistory": "langchain_community.chat_message_histories",
 }
 
 _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
diff --git a/libs/langchain/langchain/memory/chat_message_histories/xata.py b/libs/langchain/langchain/memory/chat_message_histories/xata.py
index 43cfe67266a..b9e76a0489f 100644
--- a/libs/langchain/langchain/memory/chat_message_histories/xata.py
+++ b/libs/langchain/langchain/memory/chat_message_histories/xata.py
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
 # Used to consolidate logic for raising deprecation warnings and
 # handling optional imports.
DEPRECATED_LOOKUP = { - "XataChatMessageHistory": "langchain_community.chat_message_histories" + "XataChatMessageHistory": "langchain_community.chat_message_histories", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/memory/chat_message_histories/zep.py b/libs/langchain/langchain/memory/chat_message_histories/zep.py index 9b63e0fc22d..a0f48f16e0c 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/zep.py +++ b/libs/langchain/langchain/memory/chat_message_histories/zep.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "ZepChatMessageHistory": "langchain_community.chat_message_histories" + "ZepChatMessageHistory": "langchain_community.chat_message_histories", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/memory/combined.py b/libs/langchain/langchain/memory/combined.py index ca36237807d..5f21ddbf7a9 100644 --- a/libs/langchain/langchain/memory/combined.py +++ b/libs/langchain/langchain/memory/combined.py @@ -16,7 +16,8 @@ class CombinedMemory(BaseMemory): @field_validator("memories") @classmethod def check_repeated_memory_variable( - cls, value: list[BaseMemory] + cls, + value: list[BaseMemory], ) -> list[BaseMemory]: all_variables: set[str] = set() for val in value: @@ -40,7 +41,7 @@ class CombinedMemory(BaseMemory): warnings.warn( "When using CombinedMemory, " "input keys should be so the input is known. " - f" Was not set on {val}" + f" Was not set on {val}", ) return value diff --git a/libs/langchain/langchain/memory/entity.py b/libs/langchain/langchain/memory/entity.py index 6eb922c79c0..8e35dd9c0e0 100644 --- a/libs/langchain/langchain/memory/entity.py +++ b/libs/langchain/langchain/memory/entity.py @@ -154,7 +154,7 @@ class UpstashRedisEntityStore(BaseEntityStore): return self.delete(key) self.redis_client.set(f"{self.full_key_prefix}:{key}", value, ex=self.ttl) logger.debug( - f"Redis MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}" + f"Redis MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}", ) return None @@ -167,7 +167,8 @@ class UpstashRedisEntityStore(BaseEntityStore): def clear(self) -> None: def scan_and_delete(cursor: int) -> int: cursor, keys_to_delete = self.redis_client.scan( - cursor, f"{self.full_key_prefix}:*" + cursor, + f"{self.full_key_prefix}:*", ) self.redis_client.delete(*keys_to_delete) return cursor @@ -256,7 +257,7 @@ class RedisEntityStore(BaseEntityStore): return self.delete(key) self.redis_client.set(f"{self.full_key_prefix}:{key}", value, ex=self.ttl) logger.debug( - f"REDIS MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}" + f"REDIS MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}", ) return None @@ -274,7 +275,8 @@ class RedisEntityStore(BaseEntityStore): yield batch for keybatch in batched( - self.redis_client.scan_iter(f"{self.full_key_prefix}:*"), 500 + self.redis_client.scan_iter(f"{self.full_key_prefix}:*"), + 500, ): self.redis_client.delete(*keybatch) diff --git a/libs/langchain/langchain/memory/prompt.py b/libs/langchain/langchain/memory/prompt.py index c16e8e24931..f28de16f825 100644 --- a/libs/langchain/langchain/memory/prompt.py +++ b/libs/langchain/langchain/memory/prompt.py @@ -1,4 +1,3 @@ -# flake8: noqa from langchain_core.prompts.prompt import PromptTemplate _DEFAULT_ENTITY_MEMORY_CONVERSATION_TEMPLATE = 
"""You are an assistant to a human, powered by a large language model trained by OpenAI. @@ -16,7 +15,7 @@ Current conversation: {history} Last line: Human: {input} -You:""" +You:""" # noqa: E501 ENTITY_MEMORY_CONVERSATION_TEMPLATE = PromptTemplate( input_variables=["entities", "history", "input"], @@ -43,7 +42,7 @@ Current summary: New lines of conversation: {new_lines} -New summary:""" +New summary:""" # noqa: E501 SUMMARY_PROMPT = PromptTemplate( input_variables=["summary", "new_lines"], template=_DEFAULT_SUMMARIZER_TEMPLATE ) @@ -81,7 +80,7 @@ Conversation history (for reference only): Last line of conversation (for extraction): Human: {input} -Output:""" +Output:""" # noqa: E501 ENTITY_EXTRACTION_PROMPT = PromptTemplate( input_variables=["history", "input"], template=_DEFAULT_ENTITY_EXTRACTION_TEMPLATE ) @@ -102,7 +101,7 @@ Existing summary of {entity}: Last line of conversation: Human: {input} -Updated summary:""" +Updated summary:""" # noqa: E501 ENTITY_SUMMARIZATION_PROMPT = PromptTemplate( input_variables=["entity", "summary", "history", "input"], @@ -128,7 +127,7 @@ _DEFAULT_KNOWLEDGE_TRIPLE_EXTRACTION_TEMPLATE = ( "Person #1: It's a secret military base in Nevada.\n" "AI: What do you know about Nevada?\n" "Last line of conversation:\n" - "Person #1: It's a state in the US. It's also the number 1 producer of gold in the US.\n\n" + "Person #1: It's a state in the US. It's also the number 1 producer of gold in the US.\n\n" # noqa: E501 f"Output: (Nevada, is a, state){KG_TRIPLE_DELIMITER}(Nevada, is in, US)" f"{KG_TRIPLE_DELIMITER}(Nevada, is the number 1 producer of, gold)\n" "END OF EXAMPLE\n\n" @@ -145,12 +144,12 @@ _DEFAULT_KNOWLEDGE_TRIPLE_EXTRACTION_TEMPLATE = ( "EXAMPLE\n" "Conversation history:\n" "Person #1: What do you know about Descartes?\n" - "AI: Descartes was a French philosopher, mathematician, and scientist who lived in the 17th century.\n" - "Person #1: The Descartes I'm referring to is a standup comedian and interior designer from Montreal.\n" - "AI: Oh yes, He is a comedian and an interior designer. He has been in the industry for 30 years. His favorite food is baked bean pie.\n" + "AI: Descartes was a French philosopher, mathematician, and scientist who lived in the 17th century.\n" # noqa: E501 + "Person #1: The Descartes I'm referring to is a standup comedian and interior designer from Montreal.\n" # noqa: E501 + "AI: Oh yes, He is a comedian and an interior designer. He has been in the industry for 30 years. His favorite food is baked bean pie.\n" # noqa: E501 "Last line of conversation:\n" - "Person #1: Oh huh. I know Descartes likes to drive antique scooters and play the mandolin.\n" - f"Output: (Descartes, likes to drive, antique scooters){KG_TRIPLE_DELIMITER}(Descartes, plays, mandolin)\n" + "Person #1: Oh huh. 
I know Descartes likes to drive antique scooters and play the mandolin.\n" # noqa: E501 + f"Output: (Descartes, likes to drive, antique scooters){KG_TRIPLE_DELIMITER}(Descartes, plays, mandolin)\n" # noqa: E501 "END OF EXAMPLE\n\n" "Conversation history (for reference only):\n" "{history}" diff --git a/libs/langchain/langchain/memory/summary.py b/libs/langchain/langchain/memory/summary.py index 98825a10c37..659e49ddd3a 100644 --- a/libs/langchain/langchain/memory/summary.py +++ b/libs/langchain/langchain/memory/summary.py @@ -35,7 +35,9 @@ class SummarizerMixin(BaseModel): summary_message_cls: type[BaseMessage] = SystemMessage def predict_new_summary( - self, messages: list[BaseMessage], existing_summary: str + self, + messages: list[BaseMessage], + existing_summary: str, ) -> str: new_lines = get_buffer_string( messages, @@ -47,7 +49,9 @@ class SummarizerMixin(BaseModel): return chain.predict(summary=existing_summary, new_lines=new_lines) async def apredict_new_summary( - self, messages: list[BaseMessage], existing_summary: str + self, + messages: list[BaseMessage], + existing_summary: str, ) -> str: new_lines = get_buffer_string( messages, @@ -90,7 +94,8 @@ class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin): obj = cls(llm=llm, chat_memory=chat_memory, **kwargs) for i in range(0, len(obj.chat_memory.messages), summarize_step): obj.buffer = obj.predict_new_summary( - obj.chat_memory.messages[i : i + summarize_step], obj.buffer + obj.chat_memory.messages[i : i + summarize_step], + obj.buffer, ) return obj @@ -127,7 +132,8 @@ class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin): """Save context from this conversation to buffer.""" super().save_context(inputs, outputs) self.buffer = self.predict_new_summary( - self.chat_memory.messages[-2:], self.buffer + self.chat_memory.messages[-2:], + self.buffer, ) def clear(self) -> None: diff --git a/libs/langchain/langchain/memory/summary_buffer.py b/libs/langchain/langchain/memory/summary_buffer.py index 5dce678f22a..692ea0d81eb 100644 --- a/libs/langchain/langchain/memory/summary_buffer.py +++ b/libs/langchain/langchain/memory/summary_buffer.py @@ -51,14 +51,16 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin): buffer = self.chat_memory.messages if self.moving_summary_buffer != "": first_messages: list[BaseMessage] = [ - self.summary_message_cls(content=self.moving_summary_buffer) + self.summary_message_cls(content=self.moving_summary_buffer), ] buffer = first_messages + buffer if self.return_messages: final_buffer: Any = buffer else: final_buffer = get_buffer_string( - buffer, human_prefix=self.human_prefix, ai_prefix=self.ai_prefix + buffer, + human_prefix=self.human_prefix, + ai_prefix=self.ai_prefix, ) return {self.memory_key: final_buffer} @@ -67,14 +69,16 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin): buffer = await self.chat_memory.aget_messages() if self.moving_summary_buffer != "": first_messages: list[BaseMessage] = [ - self.summary_message_cls(content=self.moving_summary_buffer) + self.summary_message_cls(content=self.moving_summary_buffer), ] buffer = first_messages + buffer if self.return_messages: final_buffer: Any = buffer else: final_buffer = get_buffer_string( - buffer, human_prefix=self.human_prefix, ai_prefix=self.ai_prefix + buffer, + human_prefix=self.human_prefix, + ai_prefix=self.ai_prefix, ) return {self.memory_key: final_buffer} @@ -97,7 +101,9 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin): self.prune() async def 
asave_context( - self, inputs: dict[str, Any], outputs: dict[str, str] + self, + inputs: dict[str, Any], + outputs: dict[str, str], ) -> None: """Asynchronously save context from this conversation to buffer.""" await super().asave_context(inputs, outputs) @@ -113,7 +119,8 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin): pruned_memory.append(buffer.pop(0)) curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer) self.moving_summary_buffer = self.predict_new_summary( - pruned_memory, self.moving_summary_buffer + pruned_memory, + self.moving_summary_buffer, ) async def aprune(self) -> None: @@ -126,7 +133,8 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin): pruned_memory.append(buffer.pop(0)) curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer) self.moving_summary_buffer = await self.apredict_new_summary( - pruned_memory, self.moving_summary_buffer + pruned_memory, + self.moving_summary_buffer, ) def clear(self) -> None: diff --git a/libs/langchain/langchain/memory/vectorstore.py b/libs/langchain/langchain/memory/vectorstore.py index d99df56d4e5..6d54fd0ed76 100644 --- a/libs/langchain/langchain/memory/vectorstore.py +++ b/libs/langchain/langchain/memory/vectorstore.py @@ -52,7 +52,8 @@ class VectorStoreRetrieverMemory(BaseMemory): return self.input_key def _documents_to_memory_variables( - self, docs: list[Document] + self, + docs: list[Document], ) -> dict[str, Union[list[Document], str]]: result: Union[list[Document], str] if not self.return_docs: @@ -62,7 +63,8 @@ class VectorStoreRetrieverMemory(BaseMemory): return {self.memory_key: result} def load_memory_variables( - self, inputs: dict[str, Any] + self, + inputs: dict[str, Any], ) -> dict[str, Union[list[Document], str]]: """Return history buffer.""" input_key = self._get_prompt_input_key(inputs) @@ -71,7 +73,8 @@ class VectorStoreRetrieverMemory(BaseMemory): return self._documents_to_memory_variables(docs) async def aload_memory_variables( - self, inputs: dict[str, Any] + self, + inputs: dict[str, Any], ) -> dict[str, Union[list[Document], str]]: """Return history buffer.""" input_key = self._get_prompt_input_key(inputs) @@ -80,7 +83,9 @@ class VectorStoreRetrieverMemory(BaseMemory): return self._documents_to_memory_variables(docs) def _form_documents( - self, inputs: dict[str, Any], outputs: dict[str, str] + self, + inputs: dict[str, Any], + outputs: dict[str, str], ) -> list[Document]: """Format context from this conversation to buffer.""" # Each document should only include the current turn, not the chat history @@ -100,7 +105,9 @@ class VectorStoreRetrieverMemory(BaseMemory): self.retriever.add_documents(documents) async def asave_context( - self, inputs: dict[str, Any], outputs: dict[str, str] + self, + inputs: dict[str, Any], + outputs: dict[str, str], ) -> None: """Save context from this conversation to buffer.""" documents = self._form_documents(inputs, outputs) diff --git a/libs/langchain/langchain/memory/vectorstore_token_buffer_memory.py b/libs/langchain/langchain/memory/vectorstore_token_buffer_memory.py index ddefc55f7b7..480e037dd13 100644 --- a/libs/langchain/langchain/memory/vectorstore_token_buffer_memory.py +++ b/libs/langchain/langchain/memory/vectorstore_token_buffer_memory.py @@ -131,13 +131,13 @@ class ConversationVectorStoreTokenBufferMemory(ConversationTokenBufferMemory): previous_history = "" current_history = super().load_memory_variables(inputs) template = SystemMessagePromptTemplate.from_template( - self.previous_history_template + 
self.previous_history_template, ) messages = [ template.format( previous_history=previous_history, current_time=datetime.now().astimezone().strftime(TIMESTAMP_FORMAT), - ) + ), ] messages.extend(current_history[self.memory_key]) return {self.memory_key: messages} diff --git a/libs/langchain/langchain/model_laboratory.py b/libs/langchain/langchain/model_laboratory.py index d4552ab86d0..f90646ceb58 100644 --- a/libs/langchain/langchain/model_laboratory.py +++ b/libs/langchain/langchain/model_laboratory.py @@ -61,7 +61,9 @@ class ModelLaboratory: @classmethod def from_llms( - cls, llms: list[BaseLLM], prompt: Optional[PromptTemplate] = None + cls, + llms: list[BaseLLM], + prompt: Optional[PromptTemplate] = None, ) -> ModelLaboratory: """Initialize the ModelLaboratory with LLMs and an optional prompt. diff --git a/libs/langchain/langchain/output_parsers/__init__.py b/libs/langchain/langchain/output_parsers/__init__.py index bff4ddb8f20..1683a9d1423 100644 --- a/libs/langchain/langchain/output_parsers/__init__.py +++ b/libs/langchain/langchain/output_parsers/__init__.py @@ -49,7 +49,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "GuardrailsOutputParser": "langchain_community.output_parsers.rail_parser" + "GuardrailsOutputParser": "langchain_community.output_parsers.rail_parser", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/output_parsers/datetime.py b/libs/langchain/langchain/output_parsers/datetime.py index 3ad2e02aecc..42d58e0eca2 100644 --- a/libs/langchain/langchain/output_parsers/datetime.py +++ b/libs/langchain/langchain/output_parsers/datetime.py @@ -22,7 +22,7 @@ class DatetimeOutputParser(BaseOutputParser[datetime]): "2023-07-04T14:30:00.000000Z", "1999-12-31T23:59:59.999999Z", "2025-01-01T00:00:00.000000Z", - ] + ], ) else: try: @@ -32,7 +32,7 @@ class DatetimeOutputParser(BaseOutputParser[datetime]): now.strftime(self.format), (now.replace(year=now.year - 1)).strftime(self.format), (now - timedelta(days=1)).strftime(self.format), - ] + ], ) except ValueError: # Fallback if the format is very unusual diff --git a/libs/langchain/langchain/output_parsers/fix.py b/libs/langchain/langchain/output_parsers/fix.py index 34d214be51a..62b1bbf7e01 100644 --- a/libs/langchain/langchain/output_parsers/fix.py +++ b/libs/langchain/langchain/output_parsers/fix.py @@ -85,7 +85,7 @@ class OutputFixingParser(BaseOutputParser[T]): "instructions": self.parser.get_format_instructions(), "completion": completion, "error": repr(e), - } + }, ) except (NotImplementedError, AttributeError): # Case: self.parser does not have get_format_instructions @@ -93,7 +93,7 @@ class OutputFixingParser(BaseOutputParser[T]): { "completion": completion, "error": repr(e), - } + }, ) msg = "Failed to parse" @@ -122,7 +122,7 @@ class OutputFixingParser(BaseOutputParser[T]): "instructions": self.parser.get_format_instructions(), "completion": completion, "error": repr(e), - } + }, ) except (NotImplementedError, AttributeError): # Case: self.parser does not have get_format_instructions @@ -130,7 +130,7 @@ class OutputFixingParser(BaseOutputParser[T]): { "completion": completion, "error": repr(e), - } + }, ) msg = "Failed to parse" diff --git a/libs/langchain/langchain/output_parsers/format_instructions.py b/libs/langchain/langchain/output_parsers/format_instructions.py index d9e892f206c..229f0a64198 100644 --- 
a/libs/langchain/langchain/output_parsers/format_instructions.py +++ b/libs/langchain/langchain/output_parsers/format_instructions.py @@ -1,12 +1,10 @@ -# flake8: noqa - STRUCTURED_FORMAT_INSTRUCTIONS = """The output should be a markdown code snippet formatted in the following schema, including the leading and trailing "```json" and "```": ```json {{ {format} }} -```""" +```""" # noqa: E501 STRUCTURED_FORMAT_SIMPLE_INSTRUCTIONS = """ ```json @@ -24,7 +22,7 @@ the object {{"foo": ["bar", "baz"]}} is a well-formatted instance of the schema. Here is the output schema: ``` {schema} -```""" +```""" # noqa: E501 YAML_FORMAT_INSTRUCTIONS = """The output should be formatted as a YAML instance that conforms to the given JSON schema below. @@ -49,14 +47,14 @@ YAML_FORMAT_INSTRUCTIONS = """The output should be formatted as a YAML instance ``` habit: Using disposable water bottles for daily hydration. sustainable_alternative: Switch to a reusable water bottle to reduce plastic waste and decrease your environmental footprint. -``` +``` -Please follow the standard YAML formatting conventions with an indent of 2 spaces and make sure that the data types adhere strictly to the following JSON schema: +Please follow the standard YAML formatting conventions with an indent of 2 spaces and make sure that the data types adhere strictly to the following JSON schema: ``` {schema} ``` -Make sure to always enclose the YAML output in triple backticks (```). Please do not add anything other than valid YAML output!""" +Make sure to always enclose the YAML output in triple backticks (```). Please do not add anything other than valid YAML output!""" # noqa: E501 PANDAS_DATAFRAME_FORMAT_INSTRUCTIONS = """The output should be formatted as a string as the operation, followed by a colon, followed by the column or row to be queried on, followed by optional array parameters. 
@@ -78,4 +76,4 @@ Here are the possible columns: ``` {columns} ``` -""" +""" # noqa: E501 diff --git a/libs/langchain/langchain/output_parsers/pandas_dataframe.py b/libs/langchain/langchain/output_parsers/pandas_dataframe.py index e1e94c338be..b1eaedb318c 100644 --- a/libs/langchain/langchain/output_parsers/pandas_dataframe.py +++ b/libs/langchain/langchain/output_parsers/pandas_dataframe.py @@ -32,7 +32,9 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]): raise TypeError(msg) def parse_array( - self, array: str, original_request_params: str + self, + array: str, + original_request_params: str, ) -> tuple[list[Union[int, str]], str]: parsed_array: list[Union[int, str]] = [] @@ -90,7 +92,8 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]): array_exists = re.search(r"(\[.*?\])", request_params) if array_exists: parsed_array, stripped_request_params = self.parse_array( - array_exists.group(1), request_params + array_exists.group(1), + request_params, ) if request_type == "column": filtered_df = self.dataframe[ @@ -121,7 +124,8 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]): self.dataframe.index.isin(parsed_array) ] result[request_type] = getattr( - filtered_df[stripped_request_params], request_type + filtered_df[stripped_request_params], + request_type, )() else: if request_type == "column": @@ -130,7 +134,8 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]): result[request_params] = self.dataframe.iloc[int(request_params)] else: result[request_type] = getattr( - self.dataframe[request_params], request_type + self.dataframe[request_params], + request_type, )() except (AttributeError, IndexError, KeyError): if request_type not in {"column", "row"}: @@ -148,5 +153,5 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]): def get_format_instructions(self) -> str: return PANDAS_DATAFRAME_FORMAT_INSTRUCTIONS.format( - columns=", ".join(self.dataframe.columns) + columns=", ".join(self.dataframe.columns), ) diff --git a/libs/langchain/langchain/output_parsers/prompts.py b/libs/langchain/langchain/output_parsers/prompts.py index dd06a70c58b..e2e34adce63 100644 --- a/libs/langchain/langchain/output_parsers/prompts.py +++ b/libs/langchain/langchain/output_parsers/prompts.py @@ -1,4 +1,3 @@ -# flake8: noqa from langchain_core.prompts.prompt import PromptTemplate NAIVE_FIX = """Instructions: @@ -16,7 +15,7 @@ Error: {error} -------------- -Please try again. Please only respond with an answer that satisfies the constraints laid out in the Instructions:""" +Please try again. Please only respond with an answer that satisfies the constraints laid out in the Instructions:""" # noqa: E501 NAIVE_FIX_PROMPT = PromptTemplate.from_template(NAIVE_FIX) diff --git a/libs/langchain/langchain/output_parsers/rail_parser.py b/libs/langchain/langchain/output_parsers/rail_parser.py index 79637931982..daaf59dc7c7 100644 --- a/libs/langchain/langchain/output_parsers/rail_parser.py +++ b/libs/langchain/langchain/output_parsers/rail_parser.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. 
DEPRECATED_LOOKUP = { - "GuardrailsOutputParser": "langchain_community.output_parsers.rail_parser" + "GuardrailsOutputParser": "langchain_community.output_parsers.rail_parser", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/output_parsers/retry.py b/libs/langchain/langchain/output_parsers/retry.py index f9eb15dc94f..d0aed891210 100644 --- a/libs/langchain/langchain/output_parsers/retry.py +++ b/libs/langchain/langchain/output_parsers/retry.py @@ -30,7 +30,7 @@ Please try again:""" NAIVE_RETRY_PROMPT = PromptTemplate.from_template(NAIVE_COMPLETION_RETRY) NAIVE_RETRY_WITH_ERROR_PROMPT = PromptTemplate.from_template( - NAIVE_COMPLETION_RETRY_WITH_ERROR + NAIVE_COMPLETION_RETRY_WITH_ERROR, ) T = TypeVar("T") @@ -118,7 +118,7 @@ class RetryOutputParser(BaseOutputParser[T]): { "prompt": prompt_value.to_string(), "completion": completion, - } + }, ) msg = "Failed to parse" @@ -154,7 +154,7 @@ class RetryOutputParser(BaseOutputParser[T]): { "prompt": prompt_value.to_string(), "completion": completion, - } + }, ) msg = "Failed to parse" @@ -191,7 +191,8 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]): # Should be an LLMChain but we want to avoid top-level imports from langchain.chains retry_chain: Annotated[ Union[ - RunnableSerializable[RetryWithErrorOutputParserRetryChainInput, str], Any + RunnableSerializable[RetryWithErrorOutputParserRetryChainInput, str], + Any, ], SkipValidation(), ] @@ -245,7 +246,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]): "completion": completion, "prompt": prompt_value.to_string(), "error": repr(e), - } + }, ) msg = "Failed to parse" @@ -273,7 +274,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]): "prompt": prompt_value.to_string(), "completion": completion, "error": repr(e), - } + }, ) msg = "Failed to parse" diff --git a/libs/langchain/langchain/output_parsers/structured.py b/libs/langchain/langchain/output_parsers/structured.py index 2cde68ded84..5210a8bb4da 100644 --- a/libs/langchain/langchain/output_parsers/structured.py +++ b/libs/langchain/langchain/output_parsers/structured.py @@ -27,7 +27,9 @@ class ResponseSchema(BaseModel): def _get_sub_string(schema: ResponseSchema) -> str: return line_template.format( - name=schema.name, description=schema.description, type=schema.type + name=schema.name, + description=schema.description, + type=schema.type, ) @@ -39,7 +41,8 @@ class StructuredOutputParser(BaseOutputParser[dict[str, Any]]): @classmethod def from_response_schemas( - cls, response_schemas: list[ResponseSchema] + cls, + response_schemas: list[ResponseSchema], ) -> StructuredOutputParser: return cls(response_schemas=response_schemas) @@ -88,7 +91,7 @@ class StructuredOutputParser(BaseOutputParser[dict[str, Any]]): will be returned, without the introducing text. Defaults to False. 
""" schema_str = "\n".join( - [_get_sub_string(schema) for schema in self.response_schemas] + [_get_sub_string(schema) for schema in self.response_schemas], ) if only_json: return STRUCTURED_FORMAT_SIMPLE_INSTRUCTIONS.format(format=schema_str) diff --git a/libs/langchain/langchain/output_parsers/yaml.py b/libs/langchain/langchain/output_parsers/yaml.py index 4d86599ee71..3dc353c964f 100644 --- a/libs/langchain/langchain/output_parsers/yaml.py +++ b/libs/langchain/langchain/output_parsers/yaml.py @@ -18,7 +18,8 @@ class YamlOutputParser(BaseOutputParser[T]): pydantic_object: type[T] """The pydantic model to parse.""" pattern: re.Pattern = re.compile( - r"^```(?:ya?ml)?(?P[^`]*)", re.MULTILINE | re.DOTALL + r"^```(?:ya?ml)?(?P[^`]*)", + re.MULTILINE | re.DOTALL, ) """Regex pattern to match yaml code blocks within triple backticks with optional yaml or yml prefix.""" diff --git a/libs/langchain/langchain/prompts/__init__.py b/libs/langchain/langchain/prompts/__init__.py index 7a4ae58c78d..5c677c7fee8 100644 --- a/libs/langchain/langchain/prompts/__init__.py +++ b/libs/langchain/langchain/prompts/__init__.py @@ -65,7 +65,9 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. MODULE_LOOKUP = { - "NGramOverlapExampleSelector": "langchain_community.example_selectors.ngram_overlap" + "NGramOverlapExampleSelector": ( + "langchain_community.example_selectors.ngram_overlap" + ), } _import_attribute = create_importer(__file__, module_lookup=MODULE_LOOKUP) diff --git a/libs/langchain/langchain/prompts/example_selector/__init__.py b/libs/langchain/langchain/prompts/example_selector/__init__.py index 8e56a21b368..375315b7af9 100644 --- a/libs/langchain/langchain/prompts/example_selector/__init__.py +++ b/libs/langchain/langchain/prompts/example_selector/__init__.py @@ -21,7 +21,9 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. 
DEPRECATED_LOOKUPS = { - "NGramOverlapExampleSelector": "langchain_community.example_selectors.ngram_overlap" + "NGramOverlapExampleSelector": ( + "langchain_community.example_selectors.ngram_overlap" + ), } _import_attribute = create_importer(__file__, deprecated_lookups=DEPRECATED_LOOKUPS) diff --git a/libs/langchain/langchain/retrievers/contextual_compression.py b/libs/langchain/langchain/retrievers/contextual_compression.py index 313cc8880c9..0891bb913da 100644 --- a/libs/langchain/langchain/retrievers/contextual_compression.py +++ b/libs/langchain/langchain/retrievers/contextual_compression.py @@ -38,11 +38,15 @@ class ContextualCompressionRetriever(BaseRetriever): Sequence of relevant documents """ docs = self.base_retriever.invoke( - query, config={"callbacks": run_manager.get_child()}, **kwargs + query, + config={"callbacks": run_manager.get_child()}, + **kwargs, ) if docs: compressed_docs = self.base_compressor.compress_documents( - docs, query, callbacks=run_manager.get_child() + docs, + query, + callbacks=run_manager.get_child(), ) return list(compressed_docs) return [] @@ -63,11 +67,15 @@ class ContextualCompressionRetriever(BaseRetriever): List of relevant documents """ docs = await self.base_retriever.ainvoke( - query, config={"callbacks": run_manager.get_child()}, **kwargs + query, + config={"callbacks": run_manager.get_child()}, + **kwargs, ) if docs: compressed_docs = await self.base_compressor.acompress_documents( - docs, query, callbacks=run_manager.get_child() + docs, + query, + callbacks=run_manager.get_child(), ) return list(compressed_docs) return [] diff --git a/libs/langchain/langchain/retrievers/document_compressors/base.py b/libs/langchain/langchain/retrievers/document_compressors/base.py index 3613dba33a1..f007e1c7d2c 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/base.py +++ b/libs/langchain/langchain/retrievers/document_compressors/base.py @@ -32,13 +32,15 @@ class DocumentCompressorPipeline(BaseDocumentCompressor): if isinstance(_transformer, BaseDocumentCompressor): accepts_callbacks = ( signature(_transformer.compress_documents).parameters.get( - "callbacks" + "callbacks", ) is not None ) if accepts_callbacks: documents = _transformer.compress_documents( - documents, query, callbacks=callbacks + documents, + query, + callbacks=callbacks, ) else: documents = _transformer.compress_documents(documents, query) @@ -60,13 +62,15 @@ class DocumentCompressorPipeline(BaseDocumentCompressor): if isinstance(_transformer, BaseDocumentCompressor): accepts_callbacks = ( signature(_transformer.acompress_documents).parameters.get( - "callbacks" + "callbacks", ) is not None ) if accepts_callbacks: documents = await _transformer.acompress_documents( - documents, query, callbacks=callbacks + documents, + query, + callbacks=callbacks, ) else: documents = await _transformer.acompress_documents(documents, query) diff --git a/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py b/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py index d1409cd3ae2..cc156007e08 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py +++ b/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py @@ -80,7 +80,7 @@ class LLMChainExtractor(BaseDocumentCompressor): if len(output) == 0: continue compressed_docs.append( - Document(page_content=cast(str, output), metadata=doc.metadata) + Document(page_content=cast(str, output), metadata=doc.metadata), ) return compressed_docs @@ -98,7 +98,7 
@@ class LLMChainExtractor(BaseDocumentCompressor): if len(outputs[i]) == 0: continue compressed_docs.append( - Document(page_content=outputs[i], metadata=doc.metadata) + Document(page_content=outputs[i], metadata=doc.metadata), ) return compressed_docs diff --git a/libs/langchain/langchain/retrievers/document_compressors/chain_extract_prompt.py b/libs/langchain/langchain/retrievers/document_compressors/chain_extract_prompt.py index c27b8770cb4..9566513e950 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/chain_extract_prompt.py +++ b/libs/langchain/langchain/retrievers/document_compressors/chain_extract_prompt.py @@ -1,5 +1,4 @@ -# flake8: noqa -prompt_template = """Given the following question and context, extract any part of the context *AS IS* that is relevant to answer the question. If none of the context is relevant return {no_output_str}. +prompt_template = """Given the following question and context, extract any part of the context *AS IS* that is relevant to answer the question. If none of the context is relevant return {no_output_str}. Remember, *DO NOT* edit the extracted parts of the context. @@ -8,4 +7,4 @@ Remember, *DO NOT* edit the extracted parts of the context. >>> {{context}} >>> -Extracted relevant parts:""" +Extracted relevant parts:""" # noqa: E501 diff --git a/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py b/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py index 1b6afac32fc..b5c6f9d1879 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py +++ b/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py @@ -58,7 +58,8 @@ class LLMChainFilter(BaseDocumentCompressor): config = RunnableConfig(callbacks=callbacks) outputs = zip( self.llm_chain.batch( - [self.get_input(query, doc) for doc in documents], config=config + [self.get_input(query, doc) for doc in documents], + config=config, ), documents, ) @@ -89,7 +90,8 @@ class LLMChainFilter(BaseDocumentCompressor): config = RunnableConfig(callbacks=callbacks) outputs = zip( await self.llm_chain.abatch( - [self.get_input(query, doc) for doc in documents], config=config + [self.get_input(query, doc) for doc in documents], + config=config, ), documents, ) diff --git a/libs/langchain/langchain/retrievers/document_compressors/chain_filter_prompt.py b/libs/langchain/langchain/retrievers/document_compressors/chain_filter_prompt.py index 5376dfa2a18..e7ab5b8cb07 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/chain_filter_prompt.py +++ b/libs/langchain/langchain/retrievers/document_compressors/chain_filter_prompt.py @@ -1,4 +1,3 @@ -# flake8: noqa prompt_template = """Given the following question and context, return YES if the context is relevant to the question and NO if it isn't. 
> Question: {question} @@ -6,4 +5,4 @@ prompt_template = """Given the following question and context, return YES if the >>> {context} >>> -> Relevant (YES / NO):""" +> Relevant (YES / NO):""" # noqa: E501 diff --git a/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py index f878cb5096e..4b9c9eb2a20 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py +++ b/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py @@ -12,7 +12,9 @@ from pydantic import ConfigDict, model_validator @deprecated( - since="0.0.30", removal="1.0", alternative_import="langchain_cohere.CohereRerank" + since="0.0.30", + removal="1.0", + alternative_import="langchain_cohere.CohereRerank", ) class CohereRerank(BaseDocumentCompressor): """Document compressor that uses `Cohere Rerank API`.""" @@ -48,7 +50,9 @@ class CohereRerank(BaseDocumentCompressor): ) raise ImportError(msg) cohere_api_key = get_from_dict_or_env( - values, "cohere_api_key", "COHERE_API_KEY" + values, + "cohere_api_key", + "COHERE_API_KEY", ) client_name = values.get("user_agent", "langchain") values["client"] = cohere.Client(cohere_api_key, client_name=client_name) @@ -89,15 +93,10 @@ class CohereRerank(BaseDocumentCompressor): ) if hasattr(results, "results"): results = getattr(results, "results") - result_dicts = [] - for res in results: - result_dicts.append( - { - "index": res.index, - "relevance_score": res.relevance_score, - } - ) - return result_dicts + return [ + {"index": res.index, "relevance_score": res.relevance_score} + for res in results + ] def compress_documents( self, diff --git a/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py b/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py index 1826d2acdb5..d02ed3fc701 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py +++ b/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py @@ -76,7 +76,8 @@ class EmbeddingsFilter(BaseDocumentCompressor): raise ImportError(msg) from e stateful_documents = get_stateful_documents(documents) embedded_documents = _get_embeddings_from_stateful_docs( - self.embeddings, stateful_documents + self.embeddings, + stateful_documents, ) embedded_query = self.embeddings.embed_query(query) similarity = self.similarity_fn([embedded_query], embedded_documents)[0] @@ -85,7 +86,7 @@ class EmbeddingsFilter(BaseDocumentCompressor): included_idxs = np.argsort(similarity)[::-1][: self.k] if self.similarity_threshold is not None: similar_enough = np.where( - similarity[included_idxs] > self.similarity_threshold + similarity[included_idxs] > self.similarity_threshold, ) included_idxs = included_idxs[similar_enough] for i in included_idxs: @@ -118,7 +119,8 @@ class EmbeddingsFilter(BaseDocumentCompressor): raise ImportError(msg) from e stateful_documents = get_stateful_documents(documents) embedded_documents = await _aget_embeddings_from_stateful_docs( - self.embeddings, stateful_documents + self.embeddings, + stateful_documents, ) embedded_query = await self.embeddings.aembed_query(query) similarity = self.similarity_fn([embedded_query], embedded_documents)[0] @@ -127,7 +129,7 @@ class EmbeddingsFilter(BaseDocumentCompressor): included_idxs = np.argsort(similarity)[::-1][: self.k] if self.similarity_threshold is not None: similar_enough = np.where( - similarity[included_idxs] > self.similarity_threshold + 
similarity[included_idxs] > self.similarity_threshold, ) included_idxs = included_idxs[similar_enough] for i in included_idxs: diff --git a/libs/langchain/langchain/retrievers/document_compressors/flashrank_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/flashrank_rerank.py index f2196fa6250..9640098ecc6 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/flashrank_rerank.py +++ b/libs/langchain/langchain/retrievers/document_compressors/flashrank_rerank.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "FlashrankRerank": "langchain_community.document_compressors.flashrank_rerank" + "FlashrankRerank": "langchain_community.document_compressors.flashrank_rerank", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py index f0d9859363e..61da3a741aa 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py +++ b/libs/langchain/langchain/retrievers/document_compressors/listwise_rerank.py @@ -92,7 +92,8 @@ class LLMListwiseRerank(BaseDocumentCompressor): ) -> Sequence[Document]: """Filter down documents based on their relevance to the query.""" results = self.reranker.invoke( - {"documents": documents, "query": query}, config={"callbacks": callbacks} + {"documents": documents, "query": query}, + config={"callbacks": callbacks}, ) return results[: self.top_n] @@ -138,6 +139,6 @@ class LLMListwiseRerank(BaseDocumentCompressor): reranker = RunnablePassthrough.assign( ranking=RunnableLambda(_get_prompt_input) | _prompt - | llm.with_structured_output(RankDocuments) + | llm.with_structured_output(RankDocuments), ) | RunnableLambda(_parse_ranking) return cls(reranker=reranker, **kwargs) diff --git a/libs/langchain/langchain/retrievers/ensemble.py b/libs/langchain/langchain/retrievers/ensemble.py index b6c2e7d16b9..27d39ab0b50 100644 --- a/libs/langchain/langchain/retrievers/ensemble.py +++ b/libs/langchain/langchain/retrievers/ensemble.py @@ -89,7 +89,10 @@ class EnsembleRetriever(BaseRetriever): @override def invoke( - self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any + self, + input: str, + config: Optional[RunnableConfig] = None, + **kwargs: Any, ) -> list[Document]: from langchain_core.callbacks import CallbackManager @@ -123,7 +126,10 @@ class EnsembleRetriever(BaseRetriever): @override async def ainvoke( - self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any + self, + input: str, + config: Optional[RunnableConfig] = None, + **kwargs: Any, ) -> list[Document]: from langchain_core.callbacks import AsyncCallbackManager @@ -145,7 +151,9 @@ class EnsembleRetriever(BaseRetriever): ) try: result = await self.arank_fusion( - input, run_manager=run_manager, config=config + input, + run_manager=run_manager, + config=config, ) except Exception as e: await run_manager.on_retriever_error(e) @@ -218,7 +226,8 @@ class EnsembleRetriever(BaseRetriever): retriever.invoke( query, patch_config( - config, callbacks=run_manager.get_child(tag=f"retriever_{i + 1}") + config, + callbacks=run_manager.get_child(tag=f"retriever_{i + 1}"), ), ) for i, retriever in enumerate(self.retrievers) @@ -263,7 +272,7 @@ class EnsembleRetriever(BaseRetriever): ), ) for i, retriever in enumerate(self.retrievers) - ] + ], ) # 
Enforce that retrieved docs are Documents for each list in retriever_docs @@ -277,7 +286,8 @@ class EnsembleRetriever(BaseRetriever): return self.weighted_reciprocal_rank(retriever_docs) def weighted_reciprocal_rank( - self, doc_lists: list[list[Document]] + self, + doc_lists: list[list[Document]], ) -> list[Document]: """ Perform weighted Reciprocal Rank Fusion on multiple rank lists. diff --git a/libs/langchain/langchain/retrievers/google_cloud_documentai_warehouse.py b/libs/langchain/langchain/retrievers/google_cloud_documentai_warehouse.py index f310037e092..bb12e286248 100644 --- a/libs/langchain/langchain/retrievers/google_cloud_documentai_warehouse.py +++ b/libs/langchain/langchain/retrievers/google_cloud_documentai_warehouse.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "GoogleDocumentAIWarehouseRetriever": "langchain_community.retrievers" + "GoogleDocumentAIWarehouseRetriever": "langchain_community.retrievers", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/retrievers/merger_retriever.py b/libs/langchain/langchain/retrievers/merger_retriever.py index ce7663c7c85..bd14debdf47 100644 --- a/libs/langchain/langchain/retrievers/merger_retriever.py +++ b/libs/langchain/langchain/retrievers/merger_retriever.py @@ -53,7 +53,9 @@ class MergerRetriever(BaseRetriever): return await self.amerge_documents(query, run_manager) def merge_documents( - self, query: str, run_manager: CallbackManagerForRetrieverRun + self, + query: str, + run_manager: CallbackManagerForRetrieverRun, ) -> list[Document]: """ Merge the results of the retrievers. @@ -85,7 +87,9 @@ class MergerRetriever(BaseRetriever): return merged_documents async def amerge_documents( - self, query: str, run_manager: AsyncCallbackManagerForRetrieverRun + self, + query: str, + run_manager: AsyncCallbackManagerForRetrieverRun, ) -> list[Document]: """ Asynchronously merge the results of the retrievers. @@ -105,7 +109,7 @@ class MergerRetriever(BaseRetriever): config={"callbacks": run_manager.get_child(f"retriever_{i + 1}")}, ) for i, retriever in enumerate(self.retrievers) - ) + ), ) # Merge the results of the retrievers. diff --git a/libs/langchain/langchain/retrievers/multi_query.py b/libs/langchain/langchain/retrievers/multi_query.py index d2680adc986..df458c6386a 100644 --- a/libs/langchain/langchain/retrievers/multi_query.py +++ b/libs/langchain/langchain/retrievers/multi_query.py @@ -110,7 +110,9 @@ class MultiQueryRetriever(BaseRetriever): return self.unique_union(documents) async def agenerate_queries( - self, question: str, run_manager: AsyncCallbackManagerForRetrieverRun + self, + question: str, + run_manager: AsyncCallbackManagerForRetrieverRun, ) -> list[str]: """Generate queries based upon user input. 
@@ -121,7 +123,8 @@ class MultiQueryRetriever(BaseRetriever): List of LLM generated queries that are similar to the user input """ response = await self.llm_chain.ainvoke( - {"question": question}, config={"callbacks": run_manager.get_child()} + {"question": question}, + config={"callbacks": run_manager.get_child()}, ) lines = response["text"] if isinstance(self.llm_chain, LLMChain) else response if self.verbose: @@ -129,7 +132,9 @@ class MultiQueryRetriever(BaseRetriever): return lines async def aretrieve_documents( - self, queries: list[str], run_manager: AsyncCallbackManagerForRetrieverRun + self, + queries: list[str], + run_manager: AsyncCallbackManagerForRetrieverRun, ) -> list[Document]: """Run all LLM generated queries. @@ -142,10 +147,11 @@ class MultiQueryRetriever(BaseRetriever): document_lists = await asyncio.gather( *( self.retriever.ainvoke( - query, config={"callbacks": run_manager.get_child()} + query, + config={"callbacks": run_manager.get_child()}, ) for query in queries - ) + ), ) return [doc for docs in document_lists for doc in docs] @@ -170,7 +176,9 @@ class MultiQueryRetriever(BaseRetriever): return self.unique_union(documents) def generate_queries( - self, question: str, run_manager: CallbackManagerForRetrieverRun + self, + question: str, + run_manager: CallbackManagerForRetrieverRun, ) -> list[str]: """Generate queries based upon user input. @@ -181,7 +189,8 @@ class MultiQueryRetriever(BaseRetriever): List of LLM generated queries that are similar to the user input """ response = self.llm_chain.invoke( - {"question": question}, config={"callbacks": run_manager.get_child()} + {"question": question}, + config={"callbacks": run_manager.get_child()}, ) lines = response["text"] if isinstance(self.llm_chain, LLMChain) else response if self.verbose: @@ -189,7 +198,9 @@ class MultiQueryRetriever(BaseRetriever): return lines def retrieve_documents( - self, queries: list[str], run_manager: CallbackManagerForRetrieverRun + self, + queries: list[str], + run_manager: CallbackManagerForRetrieverRun, ) -> list[Document]: """Run all LLM generated queries. @@ -202,7 +213,8 @@ class MultiQueryRetriever(BaseRetriever): documents = [] for query in queries: docs = self.retriever.invoke( - query, config={"callbacks": run_manager.get_child()} + query, + config={"callbacks": run_manager.get_child()}, ) documents.extend(docs) return documents diff --git a/libs/langchain/langchain/retrievers/multi_vector.py b/libs/langchain/langchain/retrievers/multi_vector.py index b6a639cdff7..32ad274d0ef 100644 --- a/libs/langchain/langchain/retrievers/multi_vector.py +++ b/libs/langchain/langchain/retrievers/multi_vector.py @@ -55,7 +55,10 @@ class MultiVectorRetriever(BaseRetriever): return values def _get_relevant_documents( - self, query: str, *, run_manager: CallbackManagerForRetrieverRun + self, + query: str, + *, + run_manager: CallbackManagerForRetrieverRun, ) -> list[Document]: """Get documents relevant to a query. 
Args: @@ -66,12 +69,14 @@ class MultiVectorRetriever(BaseRetriever): """ if self.search_type == SearchType.mmr: sub_docs = self.vectorstore.max_marginal_relevance_search( - query, **self.search_kwargs + query, + **self.search_kwargs, ) elif self.search_type == SearchType.similarity_score_threshold: sub_docs_and_similarities = ( self.vectorstore.similarity_search_with_relevance_scores( - query, **self.search_kwargs + query, + **self.search_kwargs, ) ) sub_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities] @@ -87,7 +92,10 @@ class MultiVectorRetriever(BaseRetriever): return [d for d in docs if d is not None] async def _aget_relevant_documents( - self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun + self, + query: str, + *, + run_manager: AsyncCallbackManagerForRetrieverRun, ) -> list[Document]: """Asynchronously get documents relevant to a query. Args: @@ -98,18 +106,21 @@ class MultiVectorRetriever(BaseRetriever): """ if self.search_type == SearchType.mmr: sub_docs = await self.vectorstore.amax_marginal_relevance_search( - query, **self.search_kwargs + query, + **self.search_kwargs, ) elif self.search_type == SearchType.similarity_score_threshold: sub_docs_and_similarities = ( await self.vectorstore.asimilarity_search_with_relevance_scores( - query, **self.search_kwargs + query, + **self.search_kwargs, ) ) sub_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities] else: sub_docs = await self.vectorstore.asimilarity_search( - query, **self.search_kwargs + query, + **self.search_kwargs, ) # We do this to maintain the order of the ids that are returned diff --git a/libs/langchain/langchain/retrievers/parent_document_retriever.py b/libs/langchain/langchain/retrievers/parent_document_retriever.py index 5bda553184f..85ca11b97b4 100644 --- a/libs/langchain/langchain/retrievers/parent_document_retriever.py +++ b/libs/langchain/langchain/retrievers/parent_document_retriever.py @@ -132,7 +132,9 @@ class ParentDocumentRetriever(MultiVectorRetriever): and you don't want to re-add them. 
""" docs, full_docs = self._split_docs_for_adding( - documents, ids, add_to_docstore=add_to_docstore + documents, + ids, + add_to_docstore=add_to_docstore, ) self.vectorstore.add_documents(docs, **kwargs) if add_to_docstore: @@ -146,7 +148,9 @@ class ParentDocumentRetriever(MultiVectorRetriever): **kwargs: Any, ) -> None: docs, full_docs = self._split_docs_for_adding( - documents, ids, add_to_docstore=add_to_docstore + documents, + ids, + add_to_docstore=add_to_docstore, ) await self.vectorstore.aadd_documents(docs, **kwargs) if add_to_docstore: diff --git a/libs/langchain/langchain/retrievers/re_phraser.py b/libs/langchain/langchain/retrievers/re_phraser.py index 9de82e2e2ca..fb75201037a 100644 --- a/libs/langchain/langchain/retrievers/re_phraser.py +++ b/libs/langchain/langchain/retrievers/re_phraser.py @@ -71,11 +71,13 @@ class RePhraseQueryRetriever(BaseRetriever): Relevant documents for re-phrased question """ re_phrased_question = self.llm_chain.invoke( - query, {"callbacks": run_manager.get_child()} + query, + {"callbacks": run_manager.get_child()}, ) logger.info(f"Re-phrased question: {re_phrased_question}") return self.retriever.invoke( - re_phrased_question, config={"callbacks": run_manager.get_child()} + re_phrased_question, + config={"callbacks": run_manager.get_child()}, ) async def _aget_relevant_documents( diff --git a/libs/langchain/langchain/retrievers/self_query/base.py b/libs/langchain/langchain/retrievers/self_query/base.py index 90064311ee0..e0e3a2ac7fc 100644 --- a/libs/langchain/langchain/retrievers/self_query/base.py +++ b/libs/langchain/langchain/retrievers/self_query/base.py @@ -263,7 +263,7 @@ class SelfQueryRetriever(BaseRetriever): """Validate translator.""" if "structured_query_translator" not in values: values["structured_query_translator"] = _get_builtin_translator( - values["vectorstore"] + values["vectorstore"], ) return values @@ -273,10 +273,12 @@ class SelfQueryRetriever(BaseRetriever): return self.query_constructor def _prepare_query( - self, query: str, structured_query: StructuredQuery + self, + query: str, + structured_query: StructuredQuery, ) -> tuple[str, dict[str, Any]]: new_query, new_kwargs = self.structured_query_translator.visit_structured_query( - structured_query + structured_query, ) if structured_query.limit is not None: new_kwargs["k"] = structured_query.limit @@ -286,17 +288,24 @@ class SelfQueryRetriever(BaseRetriever): return new_query, search_kwargs def _get_docs_with_query( - self, query: str, search_kwargs: dict[str, Any] + self, + query: str, + search_kwargs: dict[str, Any], ) -> list[Document]: return self.vectorstore.search(query, self.search_type, **search_kwargs) async def _aget_docs_with_query( - self, query: str, search_kwargs: dict[str, Any] + self, + query: str, + search_kwargs: dict[str, Any], ) -> list[Document]: return await self.vectorstore.asearch(query, self.search_type, **search_kwargs) def _get_relevant_documents( - self, query: str, *, run_manager: CallbackManagerForRetrieverRun + self, + query: str, + *, + run_manager: CallbackManagerForRetrieverRun, ) -> list[Document]: """Get documents relevant for a query. 
@@ -307,7 +316,8 @@ class SelfQueryRetriever(BaseRetriever): List of relevant documents """ structured_query = self.query_constructor.invoke( - {"query": query}, config={"callbacks": run_manager.get_child()} + {"query": query}, + config={"callbacks": run_manager.get_child()}, ) if self.verbose: logger.info(f"Generated Query: {structured_query}") @@ -315,7 +325,10 @@ class SelfQueryRetriever(BaseRetriever): return self._get_docs_with_query(new_query, search_kwargs) async def _aget_relevant_documents( - self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun + self, + query: str, + *, + run_manager: AsyncCallbackManagerForRetrieverRun, ) -> list[Document]: """Get documents relevant for a query. @@ -326,7 +339,8 @@ class SelfQueryRetriever(BaseRetriever): List of relevant documents """ structured_query = await self.query_constructor.ainvoke( - {"query": query}, config={"callbacks": run_manager.get_child()} + {"query": query}, + config={"callbacks": run_manager.get_child()}, ) if self.verbose: logger.info(f"Generated Query: {structured_query}") @@ -372,7 +386,7 @@ class SelfQueryRetriever(BaseRetriever): **chain_kwargs, ) query_constructor = query_constructor.with_config( - run_name=QUERY_CONSTRUCTOR_RUN_NAME + run_name=QUERY_CONSTRUCTOR_RUN_NAME, ) return cls( # type: ignore[call-arg] query_constructor=query_constructor, diff --git a/libs/langchain/langchain/retrievers/time_weighted_retriever.py b/libs/langchain/langchain/retrievers/time_weighted_retriever.py index 3a2dd30d5f7..4bee5cedfd5 100644 --- a/libs/langchain/langchain/retrievers/time_weighted_retriever.py +++ b/libs/langchain/langchain/retrievers/time_weighted_retriever.py @@ -81,7 +81,8 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever): """Return documents that are salient to the query.""" docs_and_scores: list[tuple[Document, float]] docs_and_scores = self.vectorstore.similarity_search_with_relevance_scores( - query, **self.search_kwargs + query, + **self.search_kwargs, ) results = {} for fetched_doc, relevance in docs_and_scores: @@ -96,7 +97,8 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever): docs_and_scores: list[tuple[Document, float]] docs_and_scores = ( await self.vectorstore.asimilarity_search_with_relevance_scores( - query, **self.search_kwargs + query, + **self.search_kwargs, ) ) results = {} @@ -108,7 +110,8 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever): return results def _get_rescored_docs( - self, docs_and_scores: dict[Any, tuple[Document, Optional[float]]] + self, + docs_and_scores: dict[Any, tuple[Document, Optional[float]]], ) -> list[Document]: current_time = datetime.datetime.now() rescored_docs = [ @@ -126,7 +129,10 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever): return result def _get_relevant_documents( - self, query: str, *, run_manager: CallbackManagerForRetrieverRun + self, + query: str, + *, + run_manager: CallbackManagerForRetrieverRun, ) -> list[Document]: docs_and_scores = { doc.metadata["buffer_idx"]: (doc, self.default_salience) @@ -137,7 +143,10 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever): return self._get_rescored_docs(docs_and_scores) async def _aget_relevant_documents( - self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun + self, + query: str, + *, + run_manager: AsyncCallbackManagerForRetrieverRun, ) -> list[Document]: docs_and_scores = { doc.metadata["buffer_idx"]: (doc, self.default_salience) @@ -164,7 +173,9 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever): return self.vectorstore.add_documents(dup_docs, 
**kwargs) async def aadd_documents( - self, documents: list[Document], **kwargs: Any + self, + documents: list[Document], + **kwargs: Any, ) -> list[str]: """Add documents to vectorstore.""" current_time = kwargs.get("current_time") diff --git a/libs/langchain/langchain/smith/evaluation/config.py b/libs/langchain/langchain/smith/evaluation/config.py index 74b0c296d56..d8752c80c84 100644 --- a/libs/langchain/langchain/smith/evaluation/config.py +++ b/libs/langchain/langchain/smith/evaluation/config.py @@ -21,7 +21,8 @@ from langchain.evaluation.string_distance.base import ( ) RUN_EVALUATOR_LIKE = Callable[ - [Run, Optional[Example]], Union[EvaluationResult, EvaluationResults, dict] + [Run, Optional[Example]], + Union[EvaluationResult, EvaluationResults, dict], ] BATCH_EVALUATOR_LIKE = Callable[ [Sequence[Run], Optional[Sequence[Example]]], @@ -176,7 +177,9 @@ class RunEvalConfig(BaseModel): evaluator_type: EvaluatorType = EvaluatorType.CRITERIA def __init__( - self, criteria: Optional[CRITERIA_TYPE] = None, **kwargs: Any + self, + criteria: Optional[CRITERIA_TYPE] = None, + **kwargs: Any, ) -> None: super().__init__(criteria=criteria, **kwargs) # type: ignore[call-arg] @@ -196,7 +199,9 @@ class RunEvalConfig(BaseModel): evaluator_type: EvaluatorType = EvaluatorType.LABELED_CRITERIA def __init__( - self, criteria: Optional[CRITERIA_TYPE] = None, **kwargs: Any + self, + criteria: Optional[CRITERIA_TYPE] = None, + **kwargs: Any, ) -> None: super().__init__(criteria=criteria, **kwargs) # type: ignore[call-arg] diff --git a/libs/langchain/langchain/smith/evaluation/progress.py b/libs/langchain/langchain/smith/evaluation/progress.py index 4c1d1101662..7f7ba5f981a 100644 --- a/libs/langchain/langchain/smith/evaluation/progress.py +++ b/libs/langchain/langchain/smith/evaluation/progress.py @@ -14,7 +14,11 @@ class ProgressBarCallback(base_callbacks.BaseCallbackHandler): """A simple progress bar for the console.""" def __init__( - self, total: int, ncols: int = 50, end_with: str = "\n", **kwargs: Any + self, + total: int, + ncols: int = 50, + end_with: str = "\n", + **kwargs: Any, ): """Initialize the progress bar. diff --git a/libs/langchain/langchain/smith/evaluation/runner_utils.py b/libs/langchain/langchain/smith/evaluation/runner_utils.py index 45546f864cf..de37aed654a 100644 --- a/libs/langchain/langchain/smith/evaluation/runner_utils.py +++ b/libs/langchain/langchain/smith/evaluation/runner_utils.py @@ -133,7 +133,7 @@ class TestResult(dict): if "reference" in result: if isinstance(result["reference"], dict): r.update( - {f"reference.{k}": v for k, v in result["reference"].items()} + {f"reference.{k}": v for k, v in result["reference"].items()}, ) else: r["reference"] = result["reference"] @@ -143,7 +143,7 @@ class TestResult(dict): "error": result.get("Error"), "execution_time": result["execution_time"], "run_id": result.get("run_id"), - } + }, ) records.append(r) indices.append(example_id) @@ -462,7 +462,7 @@ def _determine_input_key( if run_inputs and input_key not in run_inputs: logger.warning( f"Input key {input_key} not in chain's specified" - f" input keys {run_inputs}. Evaluation behavior may be undefined." + f" input keys {run_inputs}. Evaluation behavior may be undefined.", ) elif run_inputs and len(run_inputs) == 1: input_key = run_inputs[0] @@ -470,7 +470,7 @@ def _determine_input_key( logger.warning( f"Chain expects multiple input keys: {run_inputs}," f" Evaluator is likely to fail. Evaluation behavior may be undefined." 
- " Specify an input_key in the RunEvalConfig to avoid this warning." + " Specify an input_key in the RunEvalConfig to avoid this warning.", ) return input_key @@ -486,7 +486,7 @@ def _determine_prediction_key( if run_outputs and prediction_key not in run_outputs: logger.warning( f"Prediction key {prediction_key} not in chain's specified" - f" output keys {run_outputs}. Evaluation behavior may be undefined." + f" output keys {run_outputs}. Evaluation behavior may be undefined.", ) elif run_outputs and len(run_outputs) == 1: prediction_key = run_outputs[0] @@ -494,7 +494,7 @@ def _determine_prediction_key( logger.warning( f"Chain expects multiple output keys: {run_outputs}," f" Evaluation behavior may be undefined. Specify a prediction_key" - " in the RunEvalConfig to avoid this warning." + " in the RunEvalConfig to avoid this warning.", ) return prediction_key @@ -623,7 +623,10 @@ def _load_run_evaluators( and any(isinstance(e, StringEvaluator) for e in config.custom_evaluators) ): input_key, prediction_key, reference_key = _get_keys( - config, run_inputs, run_outputs, example_outputs + config, + run_inputs, + run_outputs, + example_outputs, ) for eval_config in config.evaluators: run_evaluator = _construct_run_evaluator( @@ -650,7 +653,7 @@ def _load_run_evaluators( input_key=input_key, prediction_key=prediction_key, reference_key=reference_key, - ) + ), ) elif callable(custom_evaluator): run_evaluators.append(run_evaluator_dec(custom_evaluator)) @@ -700,7 +703,9 @@ async def _arun_llm( return await llm.ainvoke( prompt_or_messages, config=RunnableConfig( - callbacks=callbacks, tags=tags or [], metadata=metadata or {} + callbacks=callbacks, + tags=tags or [], + metadata=metadata or {}, ), ) msg = ( @@ -715,7 +720,9 @@ async def _arun_llm( llm_output: Union[str, BaseMessage] = await llm.ainvoke( prompt, config=RunnableConfig( - callbacks=callbacks, tags=tags or [], metadata=metadata or {} + callbacks=callbacks, + tags=tags or [], + metadata=metadata or {}, ), ) except InputFormatError: @@ -723,7 +730,9 @@ async def _arun_llm( llm_output = await llm.ainvoke( **llm_inputs, config=RunnableConfig( - callbacks=callbacks, tags=tags or [], metadata=metadata or {} + callbacks=callbacks, + tags=tags or [], + metadata=metadata or {}, ), ) return llm_output @@ -750,12 +759,16 @@ async def _arun_chain( output = await chain.ainvoke( val, config=RunnableConfig( - callbacks=callbacks, tags=tags or [], metadata=metadata or {} + callbacks=callbacks, + tags=tags or [], + metadata=metadata or {}, ), ) else: runnable_config = RunnableConfig( - tags=tags or [], callbacks=callbacks, metadata=metadata or {} + tags=tags or [], + callbacks=callbacks, + metadata=metadata or {}, ) output = await chain.ainvoke(inputs_, config=runnable_config) return output @@ -809,7 +822,7 @@ async def _arun_llm_or_chain( logger.warning( f"{chain_or_llm} failed for example {example.id} " f"with inputs {example.inputs}" - f"\n{e!r}" + f"\n{e!r}", ) result = EvalError(Error=e) return result @@ -852,7 +865,9 @@ def _run_llm( llm_output: Union[str, BaseMessage] = llm.invoke( prompt_or_messages, config=RunnableConfig( - callbacks=callbacks, tags=tags or [], metadata=metadata or {} + callbacks=callbacks, + tags=tags or [], + metadata=metadata or {}, ), ) else: @@ -868,7 +883,9 @@ def _run_llm( llm_output = llm.invoke( llm_prompts, config=RunnableConfig( - callbacks=callbacks, tags=tags or [], metadata=metadata or {} + callbacks=callbacks, + tags=tags or [], + metadata=metadata or {}, ), ) except InputFormatError: @@ -901,12 +918,16 @@ def 
_run_chain( output = chain.invoke( val, config=RunnableConfig( - callbacks=callbacks, tags=tags or [], metadata=metadata or {} + callbacks=callbacks, + tags=tags or [], + metadata=metadata or {}, ), ) else: runnable_config = RunnableConfig( - tags=tags or [], callbacks=callbacks, metadata=metadata or {} + tags=tags or [], + callbacks=callbacks, + metadata=metadata or {}, ) output = chain.invoke(inputs_, config=runnable_config) return output @@ -962,7 +983,7 @@ def _run_llm_or_chain( logger.warning( f"{chain_or_llm} failed for example {example.id} " f"with inputs {example.inputs}" - f"\nError Type: {error_type}, Message: {e}" + f"\nError Type: {error_type}, Message: {e}", ) result = EvalError(Error=e) return result @@ -1104,7 +1125,7 @@ class _DatasetRunContainer: eval_results = callback.logged_eval_results for (_, example_id), v in eval_results.items(): all_eval_results.setdefault(str(example_id), {}).update( - {"feedback": v} + {"feedback": v}, ) elif isinstance(callback, LangChainTracer): run = callback.latest_run @@ -1119,7 +1140,7 @@ class _DatasetRunContainer: "execution_time": execution_time, "run_id": run_id, "run": run, - } + }, ) all_runs[str(callback.example_id)] = run return cast(dict[str, _RowResult], all_eval_results), all_runs @@ -1157,7 +1178,8 @@ class _DatasetRunContainer: try: # Closing the project permits name changing and metric optimizations self.client.update_project( - self.project.id, end_time=datetime.now(timezone.utc) + self.project.id, + end_time=datetime.now(timezone.utc), ) except Exception as e: logger.debug(f"Failed to close project: {e!r}") @@ -1200,7 +1222,10 @@ class _DatasetRunContainer: run_metadata["revision_id"] = revision_id wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory) run_evaluators = _setup_evaluation( - wrapped_model, examples, evaluation, dataset.data_type or DataType.kv + wrapped_model, + examples, + evaluation, + dataset.data_type or DataType.kv, ) _validate_example_inputs(examples[0], wrapped_model, input_mapper) progress_bar = progress.ProgressBarCallback(len(examples)) @@ -1254,7 +1279,8 @@ def _display_aggregate_results(aggregate_results: pd.DataFrame) -> None: display(aggregate_results) else: formatted_string = aggregate_results.to_string( - float_format=lambda x: f"{x:.2f}", justify="right" + float_format=lambda x: f"{x:.2f}", + justify="right", ) print("\n Experiment Results:") # noqa: T201 print(formatted_string) # noqa: T201 @@ -1413,7 +1439,7 @@ def run_on_dataset( ), container.examples, container.configs, - ) + ), ) return container.finish(batch_results, verbose=verbose) @@ -1528,5 +1554,6 @@ or LangSmith's `RunEvaluator` classes. 
""" # noqa: E501 run_on_dataset.__doc__ = _RUN_ON_DATASET_DOCSTRING arun_on_dataset.__doc__ = _RUN_ON_DATASET_DOCSTRING.replace( - "run_on_dataset(", "await arun_on_dataset(" + "run_on_dataset(", + "await arun_on_dataset(", ) diff --git a/libs/langchain/langchain/smith/evaluation/string_run_evaluator.py b/libs/langchain/langchain/smith/evaluation/string_run_evaluator.py index 0c07887b97e..614eb82ebac 100644 --- a/libs/langchain/langchain/smith/evaluation/string_run_evaluator.py +++ b/libs/langchain/langchain/smith/evaluation/string_run_evaluator.py @@ -231,7 +231,7 @@ class StringExampleMapper(Serializable): return { "reference": self.serialize_chat_messages([output]) if isinstance(output, dict) and output.get("type") and output.get("data") - else output + else output, } def __call__(self, example: Example) -> dict[str, str]: @@ -283,7 +283,9 @@ class StringRunEvaluatorChain(Chain, RunEvaluator): def _prepare_output(self, output: dict[str, Any]) -> dict[str, Any]: evaluation_result = EvaluationResult( - key=self.name, comment=output.get("reasoning"), **output + key=self.name, + comment=output.get("reasoning"), + **output, ) if RUN_KEY in output: # TODO: Not currently surfaced. Update @@ -354,7 +356,8 @@ class StringRunEvaluatorChain(Chain, RunEvaluator): """Evaluate an example.""" try: result = await self.acall( - {"run": run, "example": example}, include_run_info=True + {"run": run, "example": example}, + include_run_info=True, ) return self._prepare_evaluator_output(result) except Exception as e: @@ -405,7 +408,8 @@ class StringRunEvaluatorChain(Chain, RunEvaluator): run_mapper: StringRunMapper = LLMStringRunMapper() elif run_type == "chain": run_mapper = ChainStringRunMapper( - input_key=input_key, prediction_key=prediction_key + input_key=input_key, + prediction_key=prediction_key, ) else: msg = f"Unsupported run type {run_type}. Expected one of 'llm' or 'chain'." diff --git a/libs/langchain/langchain/storage/encoder_backed.py b/libs/langchain/langchain/storage/encoder_backed.py index becc86dd7c2..956d3f6b395 100644 --- a/libs/langchain/langchain/storage/encoder_backed.py +++ b/libs/langchain/langchain/storage/encoder_backed.py @@ -106,7 +106,9 @@ class EncoderBackedStore(BaseStore[K, V]): await self.store.amdelete(encoded_keys) def yield_keys( - self, *, prefix: Optional[str] = None + self, + *, + prefix: Optional[str] = None, ) -> Union[Iterator[K], Iterator[str]]: """Get an iterator over keys that match the given prefix.""" # For the time being this does not return K, but str @@ -114,7 +116,9 @@ class EncoderBackedStore(BaseStore[K, V]): yield from self.store.yield_keys(prefix=prefix) async def ayield_keys( - self, *, prefix: Optional[str] = None + self, + *, + prefix: Optional[str] = None, ) -> Union[AsyncIterator[K], AsyncIterator[str]]: """Get an iterator over keys that match the given prefix.""" # For the time being this does not return K, but str diff --git a/libs/langchain/langchain/tools/google_finance/__init__.py b/libs/langchain/langchain/tools/google_finance/__init__.py index 109a6cf586d..64cda6ba4b7 100644 --- a/libs/langchain/langchain/tools/google_finance/__init__.py +++ b/libs/langchain/langchain/tools/google_finance/__init__.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. 
DEPRECATED_LOOKUP = { - "GoogleFinanceQueryRun": "langchain_community.tools.google_finance.tool" + "GoogleFinanceQueryRun": "langchain_community.tools.google_finance.tool", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/tools/google_finance/tool.py b/libs/langchain/langchain/tools/google_finance/tool.py index 807d3927e8e..541f335b0ab 100644 --- a/libs/langchain/langchain/tools/google_finance/tool.py +++ b/libs/langchain/langchain/tools/google_finance/tool.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "GoogleFinanceQueryRun": "langchain_community.tools.google_finance.tool" + "GoogleFinanceQueryRun": "langchain_community.tools.google_finance.tool", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/tools/google_scholar/__init__.py b/libs/langchain/langchain/tools/google_scholar/__init__.py index 922b4ee06de..75127ca80b1 100644 --- a/libs/langchain/langchain/tools/google_scholar/__init__.py +++ b/libs/langchain/langchain/tools/google_scholar/__init__.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "GoogleScholarQueryRun": "langchain_community.tools.google_scholar.tool" + "GoogleScholarQueryRun": "langchain_community.tools.google_scholar.tool", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/tools/google_scholar/tool.py b/libs/langchain/langchain/tools/google_scholar/tool.py index b9401e5741f..5fc84c430c9 100644 --- a/libs/langchain/langchain/tools/google_scholar/tool.py +++ b/libs/langchain/langchain/tools/google_scholar/tool.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "GoogleScholarQueryRun": "langchain_community.tools.google_scholar.tool" + "GoogleScholarQueryRun": "langchain_community.tools.google_scholar.tool", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/tools/google_trends/__init__.py b/libs/langchain/langchain/tools/google_trends/__init__.py index ecbbc296e4d..5ebf4eb51bc 100644 --- a/libs/langchain/langchain/tools/google_trends/__init__.py +++ b/libs/langchain/langchain/tools/google_trends/__init__.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "GoogleTrendsQueryRun": "langchain_community.tools.google_trends.tool" + "GoogleTrendsQueryRun": "langchain_community.tools.google_trends.tool", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/tools/google_trends/tool.py b/libs/langchain/langchain/tools/google_trends/tool.py index 8ff47ba2a67..05ce9829f0b 100644 --- a/libs/langchain/langchain/tools/google_trends/tool.py +++ b/libs/langchain/langchain/tools/google_trends/tool.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. 
DEPRECATED_LOOKUP = { - "GoogleTrendsQueryRun": "langchain_community.tools.google_trends.tool" + "GoogleTrendsQueryRun": "langchain_community.tools.google_trends.tool", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/utilities/dalle_image_generator.py b/libs/langchain/langchain/utilities/dalle_image_generator.py index b62c1034901..31a7ceaf766 100644 --- a/libs/langchain/langchain/utilities/dalle_image_generator.py +++ b/libs/langchain/langchain/utilities/dalle_image_generator.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "DallEAPIWrapper": "langchain_community.utilities.dalle_image_generator" + "DallEAPIWrapper": "langchain_community.utilities.dalle_image_generator", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/utilities/dataforseo_api_search.py b/libs/langchain/langchain/utilities/dataforseo_api_search.py index 14345a56fa7..5436d432790 100644 --- a/libs/langchain/langchain/utilities/dataforseo_api_search.py +++ b/libs/langchain/langchain/utilities/dataforseo_api_search.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "DataForSeoAPIWrapper": "langchain_community.utilities.dataforseo_api_search" + "DataForSeoAPIWrapper": "langchain_community.utilities.dataforseo_api_search", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/utilities/reddit_search.py b/libs/langchain/langchain/utilities/reddit_search.py index fa8288ad7f3..4e61040fcd4 100644 --- a/libs/langchain/langchain/utilities/reddit_search.py +++ b/libs/langchain/langchain/utilities/reddit_search.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. DEPRECATED_LOOKUP = { - "RedditSearchAPIWrapper": "langchain_community.utilities.reddit_search" + "RedditSearchAPIWrapper": "langchain_community.utilities.reddit_search", } _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP) diff --git a/libs/langchain/langchain/utilities/tavily_search.py b/libs/langchain/langchain/utilities/tavily_search.py index e3bae5866ff..b552c10607f 100644 --- a/libs/langchain/langchain/utilities/tavily_search.py +++ b/libs/langchain/langchain/utilities/tavily_search.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: # Used to consolidate logic for raising deprecation warnings and # handling optional imports. 
 DEPRECATED_LOOKUP = {
-    "TavilySearchAPIWrapper": "langchain_community.utilities.tavily_search"
+    "TavilySearchAPIWrapper": "langchain_community.utilities.tavily_search",
 }

 _import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml
index 9827a2b058c..6aefa38a9c3 100644
--- a/libs/langchain/pyproject.toml
+++ b/libs/langchain/pyproject.toml
@@ -143,8 +143,58 @@ ignore-regex = ".*(Stati Uniti|Tense=Pres).*"
 ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin"

 [tool.ruff.lint]
-select = ["A", "C4", "D", "E", "EM", "F", "FBT", "I", "PGH003", "PIE", "RET", "RUF", "S", "SIM", "T201", "UP", "W"]
-ignore = ["RUF012"]
+select = [
+    "A", # flake8-builtins
+    "ASYNC", # flake8-async
+    "C4", # flake8-comprehensions
+    "COM", # flake8-commas
+    "D", # pydocstyle
+    "DOC", # pydoclint
+    "E", # pycodestyle error
+    "EM", # flake8-errmsg
+    "F", # pyflakes
+    "FA", # flake8-future-annotations
+    "FBT", # flake8-boolean-trap
+    "FLY", # flynt
+    "I", # isort
+    "ICN", # flake8-import-conventions
+    "INT", # flake8-gettext
+    "ISC", # flake8-implicit-str-concat
+    "PGH", # pygrep-hooks
+    "PIE", # flake8-pie
+    "PERF", # Perflint
+    "PYI", # flake8-pyi
+    "Q", # flake8-quotes
+    "RET", # flake8-return
+    "RSE", # flake8-raise
+    "RUF", # ruff
+    "S", # flake8-bandit
+    "SLF", # flake8-self
+    "SLOT", # flake8-slots
+    "SIM", # flake8-simplify
+    "T10", # flake8-debugger
+    "T20", # flake8-print
+    "TID", # flake8-tidy-imports
+    "UP", # pyupgrade
+    "W", # pycodestyle warning
+    "YTT", # flake8-2020
+]
+ignore = [
+    "D100",
+    "D101",
+    "D102",
+    "D103",
+    "D104",
+    "D105",
+    "D107",
+    "COM812", # Messes with the formatter
+    "ISC001", # Messes with the formatter
+    "PERF203", # Rarely useful
+    "S112", # Rarely useful
+    "RUF012", # Doesn't play well with Pydantic
+    "SLF001", # Private member access
+]
+
 pydocstyle.convention = "google"
 pyupgrade.keep-runtime-typing = true

@@ -154,8 +204,12 @@ pyupgrade.keep-runtime-typing = true

 [tool.ruff.lint.extend-per-file-ignores]
 "tests/**/*.py" = [
-    "S101", # Tests need assertions
-    "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
+    "S101", # Tests need assertions
+    "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
+    "SLF001", # Private member access in tests
+]
+"langchain/chains/constitutional_ai/principles.py" = [
+    "E501", # Line too long
 ]

 [tool.coverage.run]
diff --git a/libs/langchain/tests/integration_tests/cache/fake_embeddings.py b/libs/langchain/tests/integration_tests/cache/fake_embeddings.py
index e04f295782b..9f318b60cfe 100644
--- a/libs/langchain/tests/integration_tests/cache/fake_embeddings.py
+++ b/libs/langchain/tests/integration_tests/cache/fake_embeddings.py
@@ -44,7 +44,7 @@ class ConsistentFakeEmbeddings(FakeEmbeddings):
             if text not in self.known_texts:
                 self.known_texts.append(text)
             vector = [1.0] * (self.dimensionality - 1) + [
-                float(self.known_texts.index(text))
+                float(self.known_texts.index(text)),
             ]
             out_vectors.append(vector)
         return out_vectors
diff --git a/libs/langchain/tests/integration_tests/chains/openai_functions/test_openapi.py b/libs/langchain/tests/integration_tests/chains/openai_functions/test_openapi.py
index 9db8b957698..9bc89d36fad 100644
--- a/libs/langchain/tests/integration_tests/chains/openai_functions/test_openapi.py
+++
b/libs/langchain/tests/integration_tests/chains/openai_functions/test_openapi.py @@ -22,8 +22,8 @@ api_spec = { "description": "Limit the number of results", }, ], - } - } + }, + }, }, } diff --git a/libs/langchain/tests/integration_tests/chat_models/test_base.py b/libs/langchain/tests/integration_tests/chat_models/test_base.py index 34ce44990ae..1c915f7f588 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_base.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_base.py @@ -31,9 +31,9 @@ async def test_init_chat_model_chain() -> None: chain = prompt | model_with_config output = chain.invoke({"input": "bar"}) assert isinstance(output, AIMessage) - events = [] - async for event in chain.astream_events({"input": "bar"}, version="v2"): - events.append(event) + events = [ + event async for event in chain.astream_events({"input": "bar"}, version="v2") + ] assert events diff --git a/libs/langchain/tests/integration_tests/evaluation/embedding_distance/test_embedding.py b/libs/langchain/tests/integration_tests/evaluation/embedding_distance/test_embedding.py index 1de7b3ec93b..1f8318d429c 100644 --- a/libs/langchain/tests/integration_tests/evaluation/embedding_distance/test_embedding.py +++ b/libs/langchain/tests/integration_tests/evaluation/embedding_distance/test_embedding.py @@ -23,7 +23,7 @@ def vectors() -> tuple[np.ndarray, np.ndarray]: 0.891773, 0.96366276, 0.38344152, - ] + ], ) vector_b = np.array( [ @@ -37,7 +37,7 @@ def vectors() -> tuple[np.ndarray, np.ndarray]: 0.83261985, 0.77815675, 0.87001215, - ] + ], ) return vector_a, vector_b @@ -130,7 +130,8 @@ def test_pairwise_embedding_distance_eval_chain_embedding_distance( ) -> None: """Test the embedding distance.""" result = pairwise_embedding_distance_eval_chain.evaluate_string_pairs( - prediction="A single cat", prediction_b="A single cat" + prediction="A single cat", + prediction_b="A single cat", ) assert np.isclose(result["score"], 0.0) diff --git a/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_listwise_rerank.py b/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_listwise_rerank.py index b7e6496dd8b..135073c61c7 100644 --- a/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_listwise_rerank.py +++ b/libs/langchain/tests/integration_tests/retrievers/document_compressors/test_listwise_rerank.py @@ -15,7 +15,8 @@ def test_list_rerank() -> None: ] reranker = LLMListwiseRerank.from_llm( - llm=ChatOpenAI(model="gpt-3.5-turbo"), top_n=3 + llm=ChatOpenAI(model="gpt-3.5-turbo"), + top_n=3, ) compressed_docs = reranker.compress_documents(documents, "Who is steve") assert len(compressed_docs) == 3 diff --git a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_log.py b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_log.py index edcd1b59139..7ae553b5701 100644 --- a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_log.py +++ b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_log.py @@ -5,7 +5,7 @@ from langchain.agents.format_scratchpad.log import format_log_to_str def test_single_agent_action_observation() -> None: intermediate_steps = [ - (AgentAction(tool="Tool1", tool_input="input1", log="Log1"), "Observation1") + (AgentAction(tool="Tool1", tool_input="input1", log="Log1"), "Observation1"), ] expected_result = "Log1\nObservation: Observation1\nThought: " assert format_log_to_str(intermediate_steps) == expected_result @@ -25,7 +25,7 @@ Observation3\nThought: """ def 
test_custom_prefixes() -> None: intermediate_steps = [ - (AgentAction(tool="Tool1", tool_input="input1", log="Log1"), "Observation1") + (AgentAction(tool="Tool1", tool_input="input1", log="Log1"), "Observation1"), ] observation_prefix = "Custom Observation: " llm_prefix = "Custom Thought: " diff --git a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_log_to_messages.py b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_log_to_messages.py index deb6eaceed0..1cf6044c409 100644 --- a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_log_to_messages.py +++ b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_log_to_messages.py @@ -6,7 +6,7 @@ from langchain.agents.format_scratchpad.log_to_messages import format_log_to_mes def test_single_intermediate_step_default_response() -> None: intermediate_steps = [ - (AgentAction(tool="Tool1", tool_input="input1", log="Log1"), "Observation1") + (AgentAction(tool="Tool1", tool_input="input1", log="Log1"), "Observation1"), ] expected_result = [AIMessage(content="Log1"), HumanMessage(content="Observation1")] assert format_log_to_messages(intermediate_steps) == expected_result @@ -31,7 +31,7 @@ def test_multiple_intermediate_steps_default_response() -> None: def test_custom_template_tool_response() -> None: intermediate_steps = [ - (AgentAction(tool="Tool1", tool_input="input1", log="Log1"), "Observation1") + (AgentAction(tool="Tool1", tool_input="input1", log="Log1"), "Observation1"), ] template_tool_response = "Response: {observation}" expected_result = [ @@ -40,7 +40,8 @@ def test_custom_template_tool_response() -> None: ] assert ( format_log_to_messages( - intermediate_steps, template_tool_response=template_tool_response + intermediate_steps, + template_tool_response=template_tool_response, ) == expected_result ) diff --git a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_openai_functions.py b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_openai_functions.py index 031721fc247..ff804a2d723 100644 --- a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_openai_functions.py +++ b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_openai_functions.py @@ -11,32 +11,41 @@ def test_calls_convert_agent_action_to_messages() -> None: "function_call": { "name": "tool1", "arguments": "input1", - } + }, } message1 = AIMessage(content="", additional_kwargs=additional_kwargs1) action1 = AgentActionMessageLog( - tool="tool1", tool_input="input1", log="log1", message_log=[message1] + tool="tool1", + tool_input="input1", + log="log1", + message_log=[message1], ) additional_kwargs2 = { "function_call": { "name": "tool2", "arguments": "input2", - } + }, } message2 = AIMessage(content="", additional_kwargs=additional_kwargs2) action2 = AgentActionMessageLog( - tool="tool2", tool_input="input2", log="log2", message_log=[message2] + tool="tool2", + tool_input="input2", + log="log2", + message_log=[message2], ) additional_kwargs3 = { "function_call": { "name": "tool3", "arguments": "input3", - } + }, } message3 = AIMessage(content="", additional_kwargs=additional_kwargs3) action3 = AgentActionMessageLog( - tool="tool3", tool_input="input3", log="log3", message_log=[message3] + tool="tool3", + tool_input="input3", + log="log3", + message_log=[message3], ) intermediate_steps = [ diff --git a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_openai_tools.py b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_openai_tools.py index 2753571b4be..04d33d12f4e 
100644 --- a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_openai_tools.py +++ b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_openai_tools.py @@ -15,7 +15,7 @@ def test_calls_convert_agent_action_to_messages() -> None: "id": "call_abcd12345", "function": {"arguments": '{"a": 3, "b": 5}', "name": "add"}, "type": "function", - } + }, ], } message1 = AIMessage(content="", additional_kwargs=additional_kwargs1) @@ -27,7 +27,7 @@ def test_calls_convert_agent_action_to_messages() -> None: "id": "call_abcd54321", "function": {"arguments": '{"a": 3, "b": 5}', "name": "subtract"}, "type": "function", - } + }, ], } message2 = AIMessage(content="", additional_kwargs=additional_kwargs2) @@ -53,7 +53,7 @@ def test_calls_convert_agent_action_to_messages() -> None: message4 = AIMessage( content="", tool_calls=[ - ToolCall(name="exponentiate", args={"a": 3, "b": 5}, id="call_abc02468") + ToolCall(name="exponentiate", args={"a": 3, "b": 5}, id="call_abc02468"), ], ) actions4 = parse_ai_message_to_openai_tool_action(message4) diff --git a/libs/langchain/tests/unit_tests/agents/output_parsers/test_openai_functions.py b/libs/langchain/tests/unit_tests/agents/output_parsers/test_openai_functions.py index c251d2ea9c9..7dc4188c79e 100644 --- a/libs/langchain/tests/unit_tests/agents/output_parsers/test_openai_functions.py +++ b/libs/langchain/tests/unit_tests/agents/output_parsers/test_openai_functions.py @@ -35,7 +35,7 @@ def test_func_call() -> None: msg = AIMessage( content="LLM thoughts.", additional_kwargs={ - "function_call": {"name": "foo", "arguments": '{"param": 42}'} + "function_call": {"name": "foo", "arguments": '{"param": 42}'}, }, ) result = parser.invoke(msg) @@ -71,7 +71,7 @@ def test_func_call_oldstyle() -> None: msg = AIMessage( content="LLM thoughts.", additional_kwargs={ - "function_call": {"name": "foo", "arguments": '{"__arg1": "42"}'} + "function_call": {"name": "foo", "arguments": '{"__arg1": "42"}'}, }, ) result = parser.invoke(msg) diff --git a/libs/langchain/tests/unit_tests/agents/output_parsers/test_react_json_single_input.py b/libs/langchain/tests/unit_tests/agents/output_parsers/test_react_json_single_input.py index 935ff301747..b0277c602dd 100644 --- a/libs/langchain/tests/unit_tests/agents/output_parsers/test_react_json_single_input.py +++ b/libs/langchain/tests/unit_tests/agents/output_parsers/test_react_json_single_input.py @@ -18,7 +18,9 @@ def test_action() -> None: """ output = parser.invoke(_input) expected_output = AgentAction( - tool="search", tool_input="what is the temperature in SF?", log=_input + tool="search", + tool_input="what is the temperature in SF?", + log=_input, ) assert output == expected_output @@ -30,6 +32,7 @@ def test_finish() -> None: Final Answer: The temperature is 100""" output = parser.invoke(_input) expected_output = AgentFinish( - return_values={"output": "The temperature is 100"}, log=_input + return_values={"output": "The temperature is 100"}, + log=_input, ) assert output == expected_output diff --git a/libs/langchain/tests/unit_tests/agents/output_parsers/test_react_single_input.py b/libs/langchain/tests/unit_tests/agents/output_parsers/test_react_single_input.py index 07365ccb5d3..1eb3855d6f7 100644 --- a/libs/langchain/tests/unit_tests/agents/output_parsers/test_react_single_input.py +++ b/libs/langchain/tests/unit_tests/agents/output_parsers/test_react_single_input.py @@ -15,7 +15,9 @@ Action: search Action Input: what is the temperature in SF?""" output = parser.invoke(_input) expected_output = AgentAction( - 
tool="search", tool_input="what is the temperature in SF?", log=_input + tool="search", + tool_input="what is the temperature in SF?", + log=_input, ) assert output == expected_output @@ -27,7 +29,8 @@ def test_finish() -> None: Final Answer: The temperature is 100""" output = parser.invoke(_input) expected_output = AgentFinish( - return_values={"output": "The temperature is 100"}, log=_input + return_values={"output": "The temperature is 100"}, + log=_input, ) assert output == expected_output diff --git a/libs/langchain/tests/unit_tests/agents/output_parsers/test_self_ask.py b/libs/langchain/tests/unit_tests/agents/output_parsers/test_self_ask.py index c5d5658b064..6e93c3ad4b0 100644 --- a/libs/langchain/tests/unit_tests/agents/output_parsers/test_self_ask.py +++ b/libs/langchain/tests/unit_tests/agents/output_parsers/test_self_ask.py @@ -9,14 +9,18 @@ def test_follow_up() -> None: _input = "Follow up: what is two + 2" output = parser.invoke(_input) expected_output = AgentAction( - tool="Intermediate Answer", tool_input="what is two + 2", log=_input + tool="Intermediate Answer", + tool_input="what is two + 2", + log=_input, ) assert output == expected_output # Test that also handles one word by default _input = "Followup: what is two + 2" output = parser.invoke(_input) expected_output = AgentAction( - tool="Intermediate Answer", tool_input="what is two + 2", log=_input + tool="Intermediate Answer", + tool_input="what is two + 2", + log=_input, ) assert output == expected_output @@ -27,7 +31,9 @@ def test_follow_up_custom() -> None: _input = "Now: what is two + 2" output = parser.invoke(_input) expected_output = AgentAction( - tool="Intermediate Answer", tool_input="what is two + 2", log=_input + tool="Intermediate Answer", + tool_input="what is two + 2", + log=_input, ) assert output == expected_output diff --git a/libs/langchain/tests/unit_tests/agents/test_agent.py b/libs/langchain/tests/unit_tests/agents/test_agent.py index 3008d40a268..96bf1c3650c 100644 --- a/libs/langchain/tests/unit_tests/agents/test_agent.py +++ b/libs/langchain/tests/unit_tests/agents/test_agent.py @@ -212,12 +212,12 @@ def test_agent_stream() -> None: tool="Search", tool_input="misalignment", log="FooBarBaz\nAction: Search\nAction Input: misalignment", - ) + ), ], "messages": [ AIMessage( - content="FooBarBaz\nAction: Search\nAction Input: misalignment" - ) + content="FooBarBaz\nAction: Search\nAction Input: misalignment", + ), ], }, { @@ -229,7 +229,7 @@ def test_agent_stream() -> None: log="FooBarBaz\nAction: Search\nAction Input: misalignment", ), observation="Results for: misalignment", - ) + ), ], "messages": [HumanMessage(content="Results for: misalignment")], }, @@ -239,12 +239,12 @@ def test_agent_stream() -> None: tool="Search", tool_input="something else", log="FooBarBaz\nAction: Search\nAction Input: something else", - ) + ), ], "messages": [ AIMessage( - content="FooBarBaz\nAction: Search\nAction Input: something else" - ) + content="FooBarBaz\nAction: Search\nAction Input: something else", + ), ], }, { @@ -256,14 +256,14 @@ def test_agent_stream() -> None: log="FooBarBaz\nAction: Search\nAction Input: something else", ), observation="Results for: something else", - ) + ), ], "messages": [HumanMessage(content="Results for: something else")], }, { "output": "curses foiled again", "messages": [ - AIMessage(content="Oh well\nFinal Answer: curses foiled again") + AIMessage(content="Oh well\nFinal Answer: curses foiled again"), ], }, ] @@ -302,7 +302,7 @@ def test_agent_stream() -> None: 
AIMessage(content="FooBarBaz\nAction: Search\nAction Input: misalignment"), HumanMessage(content="Results for: misalignment"), AIMessage( - content="FooBarBaz\nAction: Search\nAction Input: something else" + content="FooBarBaz\nAction: Search\nAction Input: something else", ), HumanMessage(content="Results for: something else"), AIMessage(content="Oh well\nFinal Answer: curses foiled again"), @@ -371,7 +371,7 @@ def test_agent_tool_return_direct_in_intermediate_steps() -> None: def test_agent_with_new_prefix_suffix() -> None: """Test agent initialization kwargs with new prefix and suffix.""" fake_llm = FakeListLLM( - responses=["FooBarBaz\nAction: Search\nAction Input: misalignment"] + responses=["FooBarBaz\nAction: Search\nAction Input: misalignment"], ) tools = [ Tool( @@ -402,7 +402,7 @@ def test_agent_with_new_prefix_suffix() -> None: def test_agent_lookup_tool() -> None: """Test agent lookup tool.""" fake_llm = FakeListLLM( - responses=["FooBarBaz\nAction: Search\nAction Input: misalignment"] + responses=["FooBarBaz\nAction: Search\nAction Input: misalignment"], ) tools = [ Tool( @@ -456,7 +456,7 @@ async def test_runnable_agent() -> None: [ ("system", "You are Cat Agent 007"), ("human", "{question}"), - ] + ], ) def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]: @@ -476,7 +476,8 @@ async def test_runnable_agent() -> None: # Batch result = await asyncio.to_thread( - executor.batch, [{"question": "hello"}, {"question": "hello"}] + executor.batch, + [{"question": "hello"}, {"question": "hello"}], ) assert result == [ {"foo": "meow", "question": "hello"}, @@ -493,7 +494,7 @@ async def test_runnable_agent() -> None: # Stream results = await asyncio.to_thread(list, executor.stream({"question": "hello"})) assert results == [ - {"foo": "meow", "messages": [AIMessage(content="hard-coded-message")]} + {"foo": "meow", "messages": [AIMessage(content="hard-coded-message")]}, ] # astream @@ -504,7 +505,7 @@ async def test_runnable_agent() -> None: "messages": [ AIMessage(content="hard-coded-message"), ], - } + }, ] # stream log @@ -516,7 +517,7 @@ async def test_runnable_agent() -> None: for log_record in results: for op in log_record.ops: # type: ignore[attr-defined] if op["op"] == "add" and isinstance(op["value"], AIMessageChunk): - messages.append(op["value"]) + messages.append(op["value"]) # noqa: PERF401 assert messages != [] @@ -538,7 +539,7 @@ async def test_runnable_agent_with_function_calls() -> None: [ AIMessage(content="looking for pet..."), AIMessage(content="Found Pet"), - ] + ], ) model = GenericFakeChatModel(messages=infinite_cycle) @@ -546,7 +547,7 @@ async def test_runnable_agent_with_function_calls() -> None: [ ("system", "You are Cat Agent 007"), ("human", "{question}"), - ] + ], ) parser_responses = cycle( @@ -594,8 +595,10 @@ async def test_runnable_agent_with_function_calls() -> None: { "actions": [ AgentAction( - tool="find_pet", tool_input={"pet": "cat"}, log="find_pet()" - ) + tool="find_pet", + tool_input={"pet": "cat"}, + log="find_pet()", + ), ], "messages": [AIMessage(content="find_pet()")], }, @@ -604,10 +607,12 @@ async def test_runnable_agent_with_function_calls() -> None: "steps": [ AgentStep( action=AgentAction( - tool="find_pet", tool_input={"pet": "cat"}, log="find_pet()" + tool="find_pet", + tool_input={"pet": "cat"}, + log="find_pet()", ), observation="Spying from under the bed.", - ) + ), ], }, {"foo": "meow", "messages": [AIMessage(content="hard-coded-message")]}, @@ -617,19 +622,15 @@ async def test_runnable_agent_with_function_calls() -> 
None: messages = [] async for patch in executor.astream_log({"question": "hello"}): - for op in patch.ops: - if op["op"] != "add": - continue - - value = op["value"] - - if not isinstance(value, AIMessageChunk): - continue - - if value.content == "": # Then it's a function invocation message - continue - - messages.append(value.content) + messages.extend( + [ + op["value"].content + for op in patch.ops + if op["op"] == "add" + and isinstance(op["value"], AIMessageChunk) + and op["value"].content != "" + ] + ) assert messages == ["looking", " ", "for", " ", "pet...", "Found", " ", "Pet"] @@ -641,7 +642,7 @@ async def test_runnable_with_multi_action_per_step() -> None: [ AIMessage(content="looking for pet..."), AIMessage(content="Found Pet"), - ] + ], ) model = GenericFakeChatModel(messages=infinite_cycle) @@ -649,7 +650,7 @@ async def test_runnable_with_multi_action_per_step() -> None: [ ("system", "You are Cat Agent 007"), ("human", "{question}"), - ] + ], ) parser_responses = cycle( @@ -714,14 +715,16 @@ async def test_runnable_with_multi_action_per_step() -> None: { "actions": [ AgentAction( - tool="find_pet", tool_input={"pet": "cat"}, log="find_pet()" - ) + tool="find_pet", + tool_input={"pet": "cat"}, + log="find_pet()", + ), ], "messages": [AIMessage(content="find_pet()")], }, { "actions": [ - AgentAction(tool="pet_pet", tool_input={"pet": "cat"}, log="pet_pet()") + AgentAction(tool="pet_pet", tool_input={"pet": "cat"}, log="pet_pet()"), ], "messages": [AIMessage(content="pet_pet()")], }, @@ -731,25 +734,29 @@ async def test_runnable_with_multi_action_per_step() -> None: "steps": [ AgentStep( action=AgentAction( - tool="find_pet", tool_input={"pet": "cat"}, log="find_pet()" + tool="find_pet", + tool_input={"pet": "cat"}, + log="find_pet()", ), observation="Spying from under the bed.", - ) + ), ], }, { "messages": [ HumanMessage( - content="pet_pet is not a valid tool, try one of [find_pet]." 
- ) + content="pet_pet is not a valid tool, try one of [find_pet].", + ), ], "steps": [ AgentStep( action=AgentAction( - tool="pet_pet", tool_input={"pet": "cat"}, log="pet_pet()" + tool="pet_pet", + tool_input={"pet": "cat"}, + log="pet_pet()", ), observation="pet_pet is not a valid tool, try one of [find_pet].", - ) + ), ], }, {"foo": "meow", "messages": [AIMessage(content="hard-coded-message")]}, @@ -792,7 +799,7 @@ def _make_func_invocation(name: str, **kwargs: Any) -> AIMessage: "function_call": { "name": name, "arguments": json.dumps(kwargs), - } + }, }, ) @@ -823,7 +830,7 @@ async def test_openai_agent_with_streaming() -> None: [ _make_func_invocation("find_pet", pet="cat"), AIMessage(content="The cat is spying from under the bed."), - ] + ], ) model = GenericFakeChatModel(messages=infinite_cycle) @@ -843,7 +850,7 @@ async def test_openai_agent_with_streaming() -> None: MessagesPlaceholder( variable_name="agent_scratchpad", ), - ] + ], ) # type error due to base tool type below -- would need to be adjusted on tool @@ -875,18 +882,18 @@ async def test_openai_agent_with_streaming() -> None: "function_call": { "arguments": '{"pet": "cat"}', "name": "find_pet", - } + }, }, "content": "", "name": None, "response_metadata": {}, "type": "AIMessageChunk", - } + }, ], "tool": "find_pet", "tool_input": {"pet": "cat"}, "type": "AgentActionMessageLog", - } + }, ], "messages": [ { @@ -894,7 +901,7 @@ async def test_openai_agent_with_streaming() -> None: "function_call": { "arguments": '{"pet": "cat"}', "name": "find_pet", - } + }, }, "content": "", "example": False, @@ -905,7 +912,7 @@ async def test_openai_agent_with_streaming() -> None: "tool_calls": [], "type": "AIMessageChunk", "usage_metadata": None, - } + }, ], }, { @@ -916,7 +923,7 @@ async def test_openai_agent_with_streaming() -> None: "name": "find_pet", "response_metadata": {}, "type": "function", - } + }, ], "steps": [ { @@ -927,7 +934,7 @@ async def test_openai_agent_with_streaming() -> None: "type": "AgentActionMessageLog", }, "observation": "Spying from under the bed.", - } + }, ], }, { @@ -942,7 +949,7 @@ async def test_openai_agent_with_streaming() -> None: "tool_calls": [], "type": "ai", "usage_metadata": None, - } + }, ], "output": "The cat is spying from under the bed.", }, @@ -1016,10 +1023,10 @@ async def test_openai_agent_tools_agent() -> None: { "find_pet": {"pet": "cat"}, "check_time": {}, - } + }, ), AIMessage(content="The cat is spying from under the bed."), - ] + ], ) GenericFakeChatModel.bind_tools = lambda self, x: self # type: ignore[assignment,misc] @@ -1045,7 +1052,7 @@ async def test_openai_agent_tools_agent() -> None: MessagesPlaceholder( variable_name="agent_scratchpad", ), - ] + ], ) # type error due to base tool type below -- would need to be adjusted on tool @@ -1098,12 +1105,12 @@ async def test_openai_agent_tools_agent() -> None: }, "id": "1", }, - ] + ], }, - ) + ), ], tool_call_id="0", - ) + ), ], "messages": [ _AnyIdAIMessageChunk( @@ -1124,9 +1131,9 @@ async def test_openai_agent_tools_agent() -> None: }, "id": "1", }, - ] + ], }, - ) + ), ], }, { @@ -1154,12 +1161,12 @@ async def test_openai_agent_tools_agent() -> None: }, "id": "1", }, - ] + ], }, - ) + ), ], tool_call_id="1", - ) + ), ], "messages": [ _AnyIdAIMessageChunk( @@ -1180,16 +1187,17 @@ async def test_openai_agent_tools_agent() -> None: }, "id": "1", }, - ] + ], }, - ) + ), ], }, { "messages": [ FunctionMessage( - content="Spying from under the bed.", name="find_pet" - ) + content="Spying from under the bed.", + name="find_pet", + ), ], 
"steps": [ AgentStep( @@ -1216,14 +1224,14 @@ async def test_openai_agent_tools_agent() -> None: }, "id": "1", }, - ] + ], }, - ) + ), ], tool_call_id="0", ), observation="Spying from under the bed.", - ) + ), ], }, { @@ -1231,7 +1239,7 @@ async def test_openai_agent_tools_agent() -> None: FunctionMessage( content="check_time is not a valid tool, try one of [find_pet].", # noqa: E501 name="check_time", - ) + ), ], "steps": [ AgentStep( @@ -1258,20 +1266,20 @@ async def test_openai_agent_tools_agent() -> None: }, "id": "1", }, - ] + ], }, - ) + ), ], tool_call_id="1", ), observation="check_time is not a valid tool, " "try one of [find_pet].", - ) + ), ], }, { "messages": [ - AIMessage(content="The cat is spying from under the bed.") + AIMessage(content="The cat is spying from under the bed."), ], "output": "The cat is spying from under the bed.", }, diff --git a/libs/langchain/tests/unit_tests/agents/test_agent_async.py b/libs/langchain/tests/unit_tests/agents/test_agent_async.py index 62c1039ed7b..a9af6e4c351 100644 --- a/libs/langchain/tests/unit_tests/agents/test_agent_async.py +++ b/libs/langchain/tests/unit_tests/agents/test_agent_async.py @@ -183,12 +183,12 @@ async def test_agent_stream() -> None: tool="Search", tool_input="misalignment", log="FooBarBaz\nAction: Search\nAction Input: misalignment", - ) + ), ], "messages": [ AIMessage( - content="FooBarBaz\nAction: Search\nAction Input: misalignment" - ) + content="FooBarBaz\nAction: Search\nAction Input: misalignment", + ), ], }, { @@ -200,7 +200,7 @@ async def test_agent_stream() -> None: log="FooBarBaz\nAction: Search\nAction Input: misalignment", ), observation="Results for: misalignment", - ) + ), ], "messages": [HumanMessage(content="Results for: misalignment")], }, @@ -210,12 +210,12 @@ async def test_agent_stream() -> None: tool="Search", tool_input="something else", log="FooBarBaz\nAction: Search\nAction Input: something else", - ) + ), ], "messages": [ AIMessage( - content="FooBarBaz\nAction: Search\nAction Input: something else" - ) + content="FooBarBaz\nAction: Search\nAction Input: something else", + ), ], }, { @@ -227,14 +227,14 @@ async def test_agent_stream() -> None: log="FooBarBaz\nAction: Search\nAction Input: something else", ), observation="Results for: something else", - ) + ), ], "messages": [HumanMessage(content="Results for: something else")], }, { "output": "curses foiled again", "messages": [ - AIMessage(content="Oh well\nFinal Answer: curses foiled again") + AIMessage(content="Oh well\nFinal Answer: curses foiled again"), ], }, ] @@ -273,7 +273,7 @@ async def test_agent_stream() -> None: AIMessage(content="FooBarBaz\nAction: Search\nAction Input: misalignment"), HumanMessage(content="Results for: misalignment"), AIMessage( - content="FooBarBaz\nAction: Search\nAction Input: something else" + content="FooBarBaz\nAction: Search\nAction Input: something else", ), HumanMessage(content="Results for: something else"), AIMessage(content="Oh well\nFinal Answer: curses foiled again"), diff --git a/libs/langchain/tests/unit_tests/agents/test_agent_iterator.py b/libs/langchain/tests/unit_tests/agents/test_agent_iterator.py index ebd52ff92c5..f1c01b7630e 100644 --- a/libs/langchain/tests/unit_tests/agents/test_agent_iterator.py +++ b/libs/langchain/tests/unit_tests/agents/test_agent_iterator.py @@ -21,9 +21,7 @@ def test_agent_iterator_bad_action() -> None: agent = _get_agent() agent_iter = agent.iter(inputs="when was langchain made") - outputs = [] - for step in agent_iter: - outputs.append(step) + outputs = 
list(agent_iter) assert isinstance(outputs[-1], dict) assert outputs[-1]["output"] == "curses foiled again" @@ -38,9 +36,7 @@ def test_agent_iterator_stopped_early() -> None: agent = _get_agent(max_iterations=1) agent_iter = agent.iter(inputs="when was langchain made") - outputs = [] - for step in agent_iter: - outputs.append(step) + outputs = list(agent_iter) # NOTE: we don't use agent.run like in the test for the regular agent executor, # so the dict structure for outputs stays intact assert isinstance(outputs[-1], dict) @@ -70,10 +66,8 @@ async def test_agent_async_iterator_stopped_early() -> None: agent = _get_agent(max_iterations=1) agent_async_iter = agent.iter(inputs="when was langchain made") - outputs = [] assert isinstance(agent_async_iter, AgentExecutorIterator) - async for step in agent_async_iter: - outputs.append(step) + outputs = list(agent_async_iter) assert isinstance(outputs[-1], dict) assert ( @@ -125,12 +119,12 @@ def test_agent_iterator_with_callbacks() -> None: verbose=True, ) agent_iter = agent.iter( - inputs="when was langchain made", callbacks=[handler1], include_run_info=True + inputs="when was langchain made", + callbacks=[handler1], + include_run_info=True, ) - outputs = [] - for step in agent_iter: - outputs.append(step) + outputs = list(agent_iter) assert isinstance(outputs[-1], dict) assert outputs[-1]["output"] == "curses foiled again" assert isinstance(outputs[-1][RUN_KEY].run_id, UUID) @@ -186,7 +180,10 @@ async def test_agent_async_iterator_with_callbacks() -> None: ] agent = initialize_agent( - tools, fake_llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True + tools, + fake_llm, + agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, + verbose=True, ) agent_async_iter = agent.iter( inputs="when was langchain made", @@ -195,9 +192,7 @@ async def test_agent_async_iterator_with_callbacks() -> None: ) assert isinstance(agent_async_iter, AgentExecutorIterator) - outputs = [] - async for step in agent_async_iter: - outputs.append(step) + outputs = list(agent_async_iter) assert outputs[-1]["output"] == "curses foiled again" assert isinstance(outputs[-1][RUN_KEY].run_id, UUID) @@ -334,9 +329,7 @@ def test_agent_iterator_empty_input() -> None: agent = _get_agent() agent_iter = agent.iter(inputs="") - outputs = [] - for step in agent_iter: - outputs.append(step) + outputs = list(agent_iter) assert isinstance(outputs[-1], dict) assert outputs[-1]["output"] # Check if there is an output @@ -352,9 +345,7 @@ def test_agent_iterator_custom_stopping_condition() -> None: agent_iter = CustomAgentExecutorIterator(agent, inputs="when was langchain made") - outputs = [] - for step in agent_iter: - outputs.append(step) + outputs = list(agent_iter) assert len(outputs) == 2 # Check if the custom stopping condition is respected @@ -379,7 +370,10 @@ def test_agent_iterator_failing_tool() -> None: ] agent = initialize_agent( - tools, fake_llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True + tools, + fake_llm, + agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, + verbose=True, ) agent_iter = agent.iter(inputs="when was langchain made") diff --git a/libs/langchain/tests/unit_tests/agents/test_mrkl.py b/libs/langchain/tests/unit_tests/agents/test_mrkl.py index a464b1eea92..3bb5678d5c1 100644 --- a/libs/langchain/tests/unit_tests/agents/test_mrkl.py +++ b/libs/langchain/tests/unit_tests/agents/test_mrkl.py @@ -153,7 +153,7 @@ def test_from_chains() -> None: expected_tools_prompt, FORMAT_INSTRUCTIONS.format(tool_names=expected_tool_names), SUFFIX, - ] + ], ) prompt = 
agent.llm_chain.prompt assert isinstance(prompt, PromptTemplate) diff --git a/libs/langchain/tests/unit_tests/agents/test_openai_functions_multi.py b/libs/langchain/tests/unit_tests/agents/test_openai_functions_multi.py index 7d18fde1222..83bd63b9a7a 100644 --- a/libs/langchain/tests/unit_tests/agents/test_openai_functions_multi.py +++ b/libs/langchain/tests/unit_tests/agents/test_openai_functions_multi.py @@ -35,7 +35,7 @@ class TestParseAIMessage: msg = AIMessage( content="LLM thoughts.", additional_kwargs={ - "function_call": {"name": "foo", "arguments": f'{{"actions": {act}}}'} + "function_call": {"name": "foo", "arguments": f'{{"actions": {act}}}'}, }, ) result = _parse_ai_message(msg) @@ -59,7 +59,7 @@ class TestParseAIMessage: msg = AIMessage( content="LLM thoughts.", additional_kwargs={ - "function_call": {"name": "foo", "arguments": f'{{"actions": {act}}}'} + "function_call": {"name": "foo", "arguments": f'{{"actions": {act}}}'}, }, ) result = _parse_ai_message(msg) diff --git a/libs/langchain/tests/unit_tests/agents/test_structured_chat.py b/libs/langchain/tests/unit_tests/agents/test_structured_chat.py index 61bb6cdee4c..18c833fd4e1 100644 --- a/libs/langchain/tests/unit_tests/agents/test_structured_chat.py +++ b/libs/langchain/tests/unit_tests/agents/test_structured_chat.py @@ -121,7 +121,7 @@ class TestCreatePrompt: # Test: Output should be a ChatPromptTemplate with sys and human messages. def test_create_prompt_output(self) -> None: prompt = StructuredChatAgent.create_prompt( - [Tool(name="foo", description="Test tool FOO", func=lambda x: x)] + [Tool(name="foo", description="Test tool FOO", func=lambda x: x)], ) assert isinstance(prompt, ChatPromptTemplate) @@ -132,7 +132,7 @@ class TestCreatePrompt: # Test: Format with a single tool. def test_system_message_single_tool(self) -> None: prompt: Any = StructuredChatAgent.create_prompt( - [Tool(name="foo", description="Test tool FOO", func=lambda x: x)] + [Tool(name="foo", description="Test tool FOO", func=lambda x: x)], ) actual = prompt.messages[0].prompt.format() @@ -176,7 +176,7 @@ class TestCreatePrompt: Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. Thought: - """ # noqa: E501 + """, # noqa: E501 ).strip() assert actual == expected @@ -197,7 +197,7 @@ class TestCreatePrompt: [ Tool(name="foo", description="Test tool FOO", func=lambda x: x), Tool(name="bar", description="Test tool BAR", func=lambda x: x), - ] + ], ) actual = prompt.messages[0].prompt.format() @@ -243,7 +243,7 @@ class TestCreatePrompt: Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. 
Thought: - """ # noqa: E501 + """, # noqa: E501 ).strip() assert actual == expected diff --git a/libs/langchain/tests/unit_tests/callbacks/tracers/test_logging.py b/libs/langchain/tests/unit_tests/callbacks/tracers/test_logging.py index 2679aef5c5b..db12393d096 100644 --- a/libs/langchain/tests/unit_tests/callbacks/tracers/test_logging.py +++ b/libs/langchain/tests/unit_tests/callbacks/tracers/test_logging.py @@ -8,7 +8,8 @@ from langchain.callbacks.tracers import LoggingCallbackHandler def test_logging( - caplog: pytest.LogCaptureFixture, capsys: pytest.CaptureFixture[str] + caplog: pytest.LogCaptureFixture, + capsys: pytest.CaptureFixture[str], ) -> None: # Set up a Logger and a handler so we can check the Logger's handlers work too logger = logging.getLogger("test_logging") diff --git a/libs/langchain/tests/unit_tests/chains/test_base.py b/libs/langchain/tests/unit_tests/chains/test_base.py index 231df806f9e..819197e2ef2 100644 --- a/libs/langchain/tests/unit_tests/chains/test_base.py +++ b/libs/langchain/tests/unit_tests/chains/test_base.py @@ -22,7 +22,8 @@ class FakeMemory(BaseMemory): return ["baz"] def load_memory_variables( - self, inputs: Optional[dict[str, Any]] = None + self, + inputs: Optional[dict[str, Any]] = None, ) -> dict[str, str]: """Return baz variable.""" return {"baz": "foo"} diff --git a/libs/langchain/tests/unit_tests/chains/test_combine_documents.py b/libs/langchain/tests/unit_tests/chains/test_combine_documents.py index 655a13445bc..fd26b36569d 100644 --- a/libs/langchain/tests/unit_tests/chains/test_combine_documents.py +++ b/libs/langchain/tests/unit_tests/chains/test_combine_documents.py @@ -123,7 +123,8 @@ async def test_format_doc_with_metadata() -> None: """Test format doc on a valid document.""" doc = Document(page_content="foo", metadata={"bar": "baz"}) prompt = PromptTemplate( - input_variables=["page_content", "bar"], template="{page_content}, {bar}" + input_variables=["page_content", "bar"], + template="{page_content}, {bar}", ) expected_output = "foo, baz" output = format_document(doc, prompt) @@ -136,7 +137,8 @@ async def test_format_doc_missing_metadata() -> None: """Test format doc on a document with missing metadata.""" doc = Document(page_content="foo") prompt = PromptTemplate( - input_variables=["page_content", "bar"], template="{page_content}, {bar}" + input_variables=["page_content", "bar"], + template="{page_content}, {bar}", ) with pytest.raises(ValueError): format_document(doc, prompt) diff --git a/libs/langchain/tests/unit_tests/chains/test_conversation.py b/libs/langchain/tests/unit_tests/chains/test_conversation.py index 87bd8df3e81..2dabee35831 100644 --- a/libs/langchain/tests/unit_tests/chains/test_conversation.py +++ b/libs/langchain/tests/unit_tests/chains/test_conversation.py @@ -54,7 +54,7 @@ async def test_memory_async() -> None: memory = ConversationBufferMemory(memory_key="foo", ai_prefix="Assistant") await memory.asave_context({"input": "bar"}, {"output": "foo"}) assert await memory.aload_memory_variables({}) == { - "foo": "Human: bar\nAssistant: foo" + "foo": "Human: bar\nAssistant: foo", } diff --git a/libs/langchain/tests/unit_tests/chains/test_conversation_retrieval.py b/libs/langchain/tests/unit_tests/chains/test_conversation_retrieval.py index f0e2d489f9e..905661165b2 100644 --- a/libs/langchain/tests/unit_tests/chains/test_conversation_retrieval.py +++ b/libs/langchain/tests/unit_tests/chains/test_conversation_retrieval.py @@ -16,7 +16,10 @@ async def test_simplea() -> None: llm = FakeListLLM(responses=[answer]) 
retriever = SequentialRetriever(sequential_responses=[[]]) memory = ConversationBufferMemory( # type: ignore[call-arg] - k=1, output_key="answer", memory_key="chat_history", return_messages=True + k=1, + output_key="answer", + memory_key="chat_history", + return_messages=True, ) qa_chain = ConversationalRetrievalChain.from_llm( llm=llm, @@ -37,10 +40,13 @@ async def test_fixed_message_response_when_docs_founda() -> None: answer = "I know the answer!" llm = FakeListLLM(responses=[answer]) retriever = SequentialRetriever( - sequential_responses=[[Document(page_content=answer)]] + sequential_responses=[[Document(page_content=answer)]], ) memory = ConversationBufferMemory( # type: ignore[call-arg] - k=1, output_key="answer", memory_key="chat_history", return_messages=True + k=1, + output_key="answer", + memory_key="chat_history", + return_messages=True, ) qa_chain = ConversationalRetrievalChain.from_llm( llm=llm, @@ -62,7 +68,10 @@ def test_fixed_message_response_when_no_docs_found() -> None: llm = FakeListLLM(responses=[answer]) retriever = SequentialRetriever(sequential_responses=[[]]) memory = ConversationBufferMemory( # type: ignore[call-arg] - k=1, output_key="answer", memory_key="chat_history", return_messages=True + k=1, + output_key="answer", + memory_key="chat_history", + return_messages=True, ) qa_chain = ConversationalRetrievalChain.from_llm( llm=llm, @@ -83,10 +92,13 @@ def test_fixed_message_response_when_docs_found() -> None: answer = "I know the answer!" llm = FakeListLLM(responses=[answer]) retriever = SequentialRetriever( - sequential_responses=[[Document(page_content=answer)]] + sequential_responses=[[Document(page_content=answer)]], ) memory = ConversationBufferMemory( # type: ignore[call-arg] - k=1, output_key="answer", memory_key="chat_history", return_messages=True + k=1, + output_key="answer", + memory_key="chat_history", + return_messages=True, ) qa_chain = ConversationalRetrievalChain.from_llm( llm=llm, diff --git a/libs/langchain/tests/unit_tests/chains/test_history_aware_retriever.py b/libs/langchain/tests/unit_tests/chains/test_history_aware_retriever.py index 78d8ab326c2..0955d66b178 100644 --- a/libs/langchain/tests/unit_tests/chains/test_history_aware_retriever.py +++ b/libs/langchain/tests/unit_tests/chains/test_history_aware_retriever.py @@ -24,6 +24,6 @@ def test_create() -> None: { "input": "What is the answer?", "chat_history": ["hi", "hi"], - } + }, ) assert output == expected_output diff --git a/libs/langchain/tests/unit_tests/chains/test_hyde.py b/libs/langchain/tests/unit_tests/chains/test_hyde.py index 990b9b4226b..52e210a89a3 100644 --- a/libs/langchain/tests/unit_tests/chains/test_hyde.py +++ b/libs/langchain/tests/unit_tests/chains/test_hyde.py @@ -64,7 +64,9 @@ def test_hyde_from_llm() -> None: """Test loading HyDE from all prompts.""" for key in PROMPT_MAP: embedding = HypotheticalDocumentEmbedder.from_llm( - FakeLLM(), FakeEmbeddings(), key + FakeLLM(), + FakeEmbeddings(), + key, ) embedding.embed_query("foo") @@ -73,6 +75,8 @@ def test_hyde_from_llm_with_multiple_n() -> None: """Test loading HyDE from all prompts.""" for key in PROMPT_MAP: embedding = HypotheticalDocumentEmbedder.from_llm( - FakeLLM(n=8), FakeEmbeddings(), key + FakeLLM(n=8), + FakeEmbeddings(), + key, ) embedding.embed_query("foo") diff --git a/libs/langchain/tests/unit_tests/chains/test_llm_checker.py b/libs/langchain/tests/unit_tests/chains/test_llm_checker.py index 5d1ff95939f..f2d33002813 100644 --- a/libs/langchain/tests/unit_tests/chains/test_llm_checker.py +++ 
b/libs/langchain/tests/unit_tests/chains/test_llm_checker.py @@ -19,7 +19,7 @@ def fake_llm_checker_chain() -> LLMCheckerChain: """Fake LLMCheckerChain for testing.""" queries = { _CREATE_DRAFT_ANSWER_TEMPLATE.format( - question="Which mammal lays the biggest eggs?" + question="Which mammal lays the biggest eggs?", ): "I don't know which mammal layers the biggest eggs.", _LIST_ASSERTIONS_TEMPLATE.format( statement="I don't know which mammal layers the biggest eggs.", diff --git a/libs/langchain/tests/unit_tests/chains/test_llm_summarization_checker.py b/libs/langchain/tests/unit_tests/chains/test_llm_summarization_checker.py index ff1b457cd20..7f9401b3961 100644 --- a/libs/langchain/tests/unit_tests/chains/test_llm_summarization_checker.py +++ b/libs/langchain/tests/unit_tests/chains/test_llm_summarization_checker.py @@ -1,5 +1,3 @@ -# flake8: noqa E501 - """Test LLMSummarization functionality.""" import pytest diff --git a/libs/langchain/tests/unit_tests/chains/test_memory.py b/libs/langchain/tests/unit_tests/chains/test_memory.py index 2959f15ae80..d62e9007e60 100644 --- a/libs/langchain/tests/unit_tests/chains/test_memory.py +++ b/libs/langchain/tests/unit_tests/chains/test_memory.py @@ -33,5 +33,5 @@ def test_readonly_memory(memory: BaseMemory) -> None: memory.save_context({"input": "bar"}, {"output": "foo"}) assert read_only_memory.load_memory_variables({}) == memory.load_memory_variables( - {} + {}, ) diff --git a/libs/langchain/tests/unit_tests/chains/test_qa_with_sources.py b/libs/langchain/tests/unit_tests/chains/test_qa_with_sources.py index 825284e2949..9788ac900a8 100644 --- a/libs/langchain/tests/unit_tests/chains/test_qa_with_sources.py +++ b/libs/langchain/tests/unit_tests/chains/test_qa_with_sources.py @@ -84,7 +84,9 @@ from tests.unit_tests.llms.fake_llm import FakeLLM ], ) def test_spliting_answer_into_answer_and_sources( - text: str, answer: str, sources: str + text: str, + answer: str, + sources: str, ) -> None: qa_chain = QAWithSourcesChain.from_llm(FakeLLM()) generated_answer, generated_sources = qa_chain._split_sources(text) diff --git a/libs/langchain/tests/unit_tests/chains/test_sequential.py b/libs/langchain/tests/unit_tests/chains/test_sequential.py index 02a9b274a40..fe5aa283b97 100644 --- a/libs/langchain/tests/unit_tests/chains/test_sequential.py +++ b/libs/langchain/tests/unit_tests/chains/test_sequential.py @@ -84,7 +84,9 @@ def test_sequential_usage_memory() -> None: chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"]) chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"]) chain = SequentialChain( # type: ignore[call-arg] - memory=memory, chains=[chain_1, chain_2], input_variables=["foo"] + memory=memory, + chains=[chain_1, chain_2], + input_variables=["foo"], ) output = chain({"foo": "123"}) expected_output = {"baz": "123foofoo", "foo": "123", "zab": "rab"} @@ -94,7 +96,9 @@ def test_sequential_usage_memory() -> None: chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"]) with pytest.raises(ValueError): SequentialChain( # type: ignore[call-arg] - memory=memory, chains=[chain_1, chain_2], input_variables=["foo"] + memory=memory, + chains=[chain_1, chain_2], + input_variables=["foo"], ) @@ -103,7 +107,9 @@ def test_sequential_internal_chain_use_memory() -> None: memory = ConversationBufferMemory(memory_key="bla") memory.save_context({"input": "yo"}, {"output": "ya"}) chain_1 = FakeChain( - input_variables=["foo", "bla"], output_variables=["bar"], memory=memory + input_variables=["foo", "bla"], + 
output_variables=["bar"], + memory=memory, ) chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"]) chain = SequentialChain(chains=[chain_1, chain_2], input_variables=["foo"]) # type: ignore[call-arg] @@ -188,13 +194,19 @@ async def test_simple_sequential_functionality_with_callbacks(*, isAsync: bool) handler_2 = FakeCallbackHandler() handler_3 = FakeCallbackHandler() chain_1 = FakeChain( - input_variables=["foo"], output_variables=["bar"], callbacks=[handler_1] + input_variables=["foo"], + output_variables=["bar"], + callbacks=[handler_1], ) chain_2 = FakeChain( - input_variables=["bar"], output_variables=["baz"], callbacks=[handler_2] + input_variables=["bar"], + output_variables=["baz"], + callbacks=[handler_2], ) chain_3 = FakeChain( - input_variables=["jack"], output_variables=["baf"], callbacks=[handler_3] + input_variables=["jack"], + output_variables=["baf"], + callbacks=[handler_3], ) chain = SimpleSequentialChain(chains=[chain_1, chain_2, chain_3]) if isAsync: diff --git a/libs/langchain/tests/unit_tests/chains/test_summary_buffer_memory.py b/libs/langchain/tests/unit_tests/chains/test_summary_buffer_memory.py index f6651c50c23..f85bd8ad777 100644 --- a/libs/langchain/tests/unit_tests/chains/test_summary_buffer_memory.py +++ b/libs/langchain/tests/unit_tests/chains/test_summary_buffer_memory.py @@ -40,7 +40,9 @@ def test_summary_buffer_memory_summary() -> None: """Test ConversationSummaryBufferMemory when only buffer.""" llm = FakeLLM(queries={0: "summary"}, sequential_responses=True) memory = ConversationSummaryBufferMemory( - llm=llm, memory_key="baz", max_token_limit=5 + llm=llm, + memory_key="baz", + max_token_limit=5, ) memory.save_context({"input": "bar"}, {"output": "foo"}) memory.save_context({"input": "bar1"}, {"output": "foo1"}) @@ -53,7 +55,9 @@ async def test_summary_buffer_memory_summary_async() -> None: """Test ConversationSummaryBufferMemory when only buffer.""" llm = FakeLLM(queries={0: "summary"}, sequential_responses=True) memory = ConversationSummaryBufferMemory( - llm=llm, memory_key="baz", max_token_limit=5 + llm=llm, + memory_key="baz", + max_token_limit=5, ) await memory.asave_context({"input": "bar"}, {"output": "foo"}) await memory.asave_context({"input": "bar1"}, {"output": "foo1"}) diff --git a/libs/langchain/tests/unit_tests/chat_models/test_base.py b/libs/langchain/tests/unit_tests/chat_models/test_base.py index 4d314936033..e86fb892584 100644 --- a/libs/langchain/tests/unit_tests/chat_models/test_base.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_base.py @@ -40,10 +40,13 @@ def test_all_imports() -> None: ) def test_init_chat_model(model_name: str, model_provider: Optional[str]) -> None: llm1: BaseChatModel = init_chat_model( - model_name, model_provider=model_provider, api_key="foo" + model_name, + model_provider=model_provider, + api_key="foo", ) llm2: BaseChatModel = init_chat_model( - f"{model_provider}:{model_name}", api_key="foo" + f"{model_provider}:{model_name}", + api_key="foo", ) assert llm1.dict() == llm2.dict() @@ -60,7 +63,9 @@ def test_init_unknown_provider() -> None: @pytest.mark.requires("langchain_openai") @mock.patch.dict( - os.environ, {"OPENAI_API_KEY": "foo", "ANTHROPIC_API_KEY": "bar"}, clear=True + os.environ, + {"OPENAI_API_KEY": "foo", "ANTHROPIC_API_KEY": "bar"}, + clear=True, ) def test_configurable() -> None: model = init_chat_model() @@ -85,7 +90,7 @@ def test_configurable() -> None: # Can call declarative methods even without a default model. 
model_with_tools = model.bind_tools( - [{"name": "foo", "description": "foo", "parameters": {}}] + [{"name": "foo", "description": "foo", "parameters": {}}], ) # Check that original model wasn't mutated by declarative operation. @@ -93,7 +98,8 @@ def test_configurable() -> None: # Can iteratively call declarative methods. model_with_config = model_with_tools.with_config( - RunnableConfig(tags=["foo"]), configurable={"model": "gpt-4o"} + RunnableConfig(tags=["foo"]), + configurable={"model": "gpt-4o"}, ) assert model_with_config.model_name == "gpt-4o" # type: ignore[attr-defined] @@ -147,8 +153,8 @@ def test_configurable() -> None: { "type": "function", "function": {"name": "foo", "description": "foo", "parameters": {}}, - } - ] + }, + ], }, "config": {"tags": ["foo"], "configurable": {}}, "config_factories": [], @@ -159,7 +165,9 @@ def test_configurable() -> None: @pytest.mark.requires("langchain_openai", "langchain_anthropic") @mock.patch.dict( - os.environ, {"OPENAI_API_KEY": "foo", "ANTHROPIC_API_KEY": "bar"}, clear=True + os.environ, + {"OPENAI_API_KEY": "foo", "ANTHROPIC_API_KEY": "bar"}, + clear=True, ) def test_configurable_with_default() -> None: model = init_chat_model("gpt-4o", configurable_fields="any", config_prefix="bar") @@ -183,7 +191,7 @@ def test_configurable_with_default() -> None: assert model.model_name == "gpt-4o" model_with_tools = model.bind_tools( - [{"name": "foo", "description": "foo", "parameters": {}}] + [{"name": "foo", "description": "foo", "parameters": {}}], ) model_with_config = model_with_tools.with_config( @@ -217,7 +225,7 @@ def test_configurable_with_default() -> None: "stream_usage": True, }, "kwargs": { - "tools": [{"name": "foo", "description": "foo", "input_schema": {}}] + "tools": [{"name": "foo", "description": "foo", "input_schema": {}}], }, "config": {"tags": ["foo"], "configurable": {}}, "config_factories": [], diff --git a/libs/langchain/tests/unit_tests/conftest.py b/libs/langchain/tests/unit_tests/conftest.py index e9bf5a26518..2b3977b8b47 100644 --- a/libs/langchain/tests/unit_tests/conftest.py +++ b/libs/langchain/tests/unit_tests/conftest.py @@ -12,7 +12,8 @@ from pytest import Config, Function, Parser def blockbuster() -> Iterator[None]: with blockbuster_ctx("langchain") as bb: bb.functions["io.TextIOWrapper.read"].can_block_in( - "langchain/__init__.py", "" + "langchain/__init__.py", + "", ) for func in ["os.stat", "os.path.abspath"]: @@ -20,18 +21,21 @@ def blockbuster() -> Iterator[None]: bb.functions[func] .can_block_in("langchain_core/runnables/base.py", "__repr__") .can_block_in( - "langchain_core/beta/runnables/context.py", "aconfig_with_context" + "langchain_core/beta/runnables/context.py", + "aconfig_with_context", ) ) for func in ["os.stat", "io.TextIOWrapper.read"]: bb.functions[func].can_block_in( - "langsmith/client.py", "_default_retry_config" + "langsmith/client.py", + "_default_retry_config", ) for bb_function in bb.functions.values(): bb_function.can_block_in( - "freezegun/api.py", "_get_cached_module_attributes" + "freezegun/api.py", + "_get_cached_module_attributes", ) yield @@ -122,11 +126,11 @@ def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> # If the package is not installed, we immediately break # and mark the test as skipped. 
item.add_marker( - pytest.mark.skip(reason=f"Requires pkg: `{pkg}`") + pytest.mark.skip(reason=f"Requires pkg: `{pkg}`"), ) break else: if only_extended: item.add_marker( - pytest.mark.skip(reason="Skipping not an extended test.") + pytest.mark.skip(reason="Skipping not an extended test."), ) diff --git a/libs/langchain/tests/unit_tests/embeddings/test_base.py b/libs/langchain/tests/unit_tests/embeddings/test_base.py index 5ca91949745..5df628ad19b 100644 --- a/libs/langchain/tests/unit_tests/embeddings/test_base.py +++ b/libs/langchain/tests/unit_tests/embeddings/test_base.py @@ -40,7 +40,8 @@ def test_parse_model_string_errors() -> None: _parse_model_string("openai:") with pytest.raises( - ValueError, match="Provider 'invalid-provider' is not supported" + ValueError, + match="Provider 'invalid-provider' is not supported", ): _parse_model_string("invalid-provider:model-name") @@ -57,11 +58,13 @@ def test_infer_model_and_provider() -> None: ) assert _infer_model_and_provider( - model="text-embedding-3-small", provider="openai" + model="text-embedding-3-small", + provider="openai", ) == ("openai", "text-embedding-3-small") assert _infer_model_and_provider( - model="ft:text-embedding-3-small", provider="openai" + model="ft:text-embedding-3-small", + provider="openai", ) == ("openai", "ft:text-embedding-3-small") assert _infer_model_and_provider(model="openai:ft:text-embedding-3-small") == ( diff --git a/libs/langchain/tests/unit_tests/embeddings/test_caching.py b/libs/langchain/tests/unit_tests/embeddings/test_caching.py index 1d06e71efb2..b4ab2c93f3d 100644 --- a/libs/langchain/tests/unit_tests/embeddings/test_caching.py +++ b/libs/langchain/tests/unit_tests/embeddings/test_caching.py @@ -34,7 +34,9 @@ def cache_embeddings() -> CacheBackedEmbeddings: store = InMemoryStore() embeddings = MockEmbeddings() return CacheBackedEmbeddings.from_bytes_store( - embeddings, store, namespace="test_namespace" + embeddings, + store, + namespace="test_namespace", ) @@ -44,7 +46,10 @@ def cache_embeddings_batch() -> CacheBackedEmbeddings: store = InMemoryStore() embeddings = MockEmbeddings() return CacheBackedEmbeddings.from_bytes_store( - embeddings, store, namespace="test_namespace", batch_size=3 + embeddings, + store, + namespace="test_namespace", + batch_size=3, ) @@ -155,7 +160,10 @@ def test_blake2b_encoder() -> None: store = InMemoryStore() emb = MockEmbeddings() cbe = CacheBackedEmbeddings.from_bytes_store( - emb, store, namespace="ns_", key_encoder="blake2b" + emb, + store, + namespace="ns_", + key_encoder="blake2b", ) text = "blake" @@ -171,7 +179,10 @@ def test_sha256_encoder() -> None: store = InMemoryStore() emb = MockEmbeddings() cbe = CacheBackedEmbeddings.from_bytes_store( - emb, store, namespace="ns_", key_encoder="sha256" + emb, + store, + namespace="ns_", + key_encoder="sha256", ) text = "foo" @@ -187,7 +198,10 @@ def test_sha512_encoder() -> None: store = InMemoryStore() emb = MockEmbeddings() cbe = CacheBackedEmbeddings.from_bytes_store( - emb, store, namespace="ns_", key_encoder="sha512" + emb, + store, + namespace="ns_", + key_encoder="sha512", ) text = "foo" diff --git a/libs/langchain/tests/unit_tests/evaluation/agents/test_eval_chain.py b/libs/langchain/tests/unit_tests/evaluation/agents/test_eval_chain.py index 7395036f8fe..50deade41d3 100644 --- a/libs/langchain/tests/unit_tests/evaluation/agents/test_eval_chain.py +++ b/libs/langchain/tests/unit_tests/evaluation/agents/test_eval_chain.py @@ -77,7 +77,7 @@ but otherwise poor performance, we give the model a score of 2.""", with 
pytest.raises(OutputParserException): trajectory_output_parser.parse( """Judgment: Given the good reasoning in the final answer -but otherwise poor performance, we give the model a score of 2.""" +but otherwise poor performance, we give the model a score of 2.""", ) with pytest.raises(OutputParserException): @@ -85,7 +85,7 @@ but otherwise poor performance, we give the model a score of 2.""" """Judgment: Given the good reasoning in the final answer but otherwise poor performance, we give the model a score of 2. -Score: 9""" +Score: 9""", ) with pytest.raises(OutputParserException): @@ -93,7 +93,7 @@ Score: 9""" """Judgment: Given the good reasoning in the final answer but otherwise poor performance, we give the model a score of 2. -Score: 10""" +Score: 10""", ) with pytest.raises(OutputParserException): @@ -101,7 +101,7 @@ Score: 10""" """Judgment: Given the good reasoning in the final answer but otherwise poor performance, we give the model a score of 2. -Score: 0.1""" +Score: 0.1""", ) with pytest.raises(OutputParserException): @@ -109,7 +109,7 @@ Score: 0.1""" """Judgment: Given the good reasoning in the final answer but otherwise poor performance, we give the model a score of 2. -Score: One""" +Score: One""", ) @@ -181,7 +181,7 @@ def test_old_api_works(intermediate_steps: list[tuple[AgentAction, str]]) -> Non "question": "What is your favorite food?", "agent_trajectory": intermediate_steps, "answer": "I like pie.", - } + }, ) assert res["score"] == 1.0 @@ -191,6 +191,6 @@ def test_old_api_works(intermediate_steps: list[tuple[AgentAction, str]]) -> Non "agent_trajectory": intermediate_steps, "answer": "I like pie.", "reference": "Paris", - } + }, ) assert res["score"] == 0.0 diff --git a/libs/langchain/tests/unit_tests/evaluation/criteria/test_eval_chain.py b/libs/langchain/tests/unit_tests/evaluation/criteria/test_eval_chain.py index a58e53d97df..fa6605b8eed 100644 --- a/libs/langchain/tests/unit_tests/evaluation/criteria/test_eval_chain.py +++ b/libs/langchain/tests/unit_tests/evaluation/criteria/test_eval_chain.py @@ -15,10 +15,10 @@ from tests.unit_tests.llms.fake_llm import FakeLLM def test_resolve_criteria_str() -> None: assert CriteriaEvalChain.resolve_criteria("helpfulness") == { - "helpfulness": _SUPPORTED_CRITERIA[Criteria.HELPFULNESS] + "helpfulness": _SUPPORTED_CRITERIA[Criteria.HELPFULNESS], } assert CriteriaEvalChain.resolve_criteria("correctness") == { - "correctness": _SUPPORTED_CRITERIA[Criteria.CORRECTNESS] + "correctness": _SUPPORTED_CRITERIA[Criteria.CORRECTNESS], } @@ -62,20 +62,23 @@ def test_CriteriaResultOutputParser_parse(text: str, want: dict) -> None: @pytest.mark.parametrize("criterion", list(Criteria)) def test_resolve_criteria_enum(criterion: Criteria) -> None: assert CriteriaEvalChain.resolve_criteria(criterion) == { - criterion.value: _SUPPORTED_CRITERIA[criterion] + criterion.value: _SUPPORTED_CRITERIA[criterion], } def test_criteria_eval_chain() -> None: chain = CriteriaEvalChain.from_llm( llm=FakeLLM( - queries={"text": "The meaning of life\nY"}, sequential_responses=True + queries={"text": "The meaning of life\nY"}, + sequential_responses=True, ), criteria={"my criterion": "my criterion description"}, ) with pytest.warns(UserWarning, match=chain._skip_reference_warning): result = chain.evaluate_strings( - prediction="my prediction", reference="my reference", input="my input" + prediction="my prediction", + reference="my reference", + input="my input", ) assert result["reasoning"] == "The meaning of life" diff --git 
a/libs/langchain/tests/unit_tests/evaluation/exact_match/test_base.py b/libs/langchain/tests/unit_tests/evaluation/exact_match/test_base.py index 00fc7b25a09..e9623d75ea6 100644 --- a/libs/langchain/tests/unit_tests/evaluation/exact_match/test_base.py +++ b/libs/langchain/tests/unit_tests/evaluation/exact_match/test_base.py @@ -21,13 +21,15 @@ def test_default_exact_matching( prediction = "Mindy is the CTO" reference = "Mindy is the CTO" result = exact_match_string_evaluator.evaluate_strings( - prediction=prediction, reference=reference + prediction=prediction, + reference=reference, ) assert result["score"] == 1.0 reference = "Mindy is the CEO" result = exact_match_string_evaluator.evaluate_strings( - prediction=prediction, reference=reference + prediction=prediction, + reference=reference, ) assert result["score"] == 0.0 @@ -38,12 +40,14 @@ def test_exact_matching_with_ignore_case( prediction = "Mindy is the CTO" reference = "mindy is the cto" result = exact_match_string_evaluator_ignore_case.evaluate_strings( - prediction=prediction, reference=reference + prediction=prediction, + reference=reference, ) assert result["score"] == 1.0 reference = "mindy is the CEO" result = exact_match_string_evaluator_ignore_case.evaluate_strings( - prediction=prediction, reference=reference + prediction=prediction, + reference=reference, ) assert result["score"] == 0.0 diff --git a/libs/langchain/tests/unit_tests/evaluation/parsing/test_base.py b/libs/langchain/tests/unit_tests/evaluation/parsing/test_base.py index 82fe26e30ff..d414e12c8a5 100644 --- a/libs/langchain/tests/unit_tests/evaluation/parsing/test_base.py +++ b/libs/langchain/tests/unit_tests/evaluation/parsing/test_base.py @@ -84,7 +84,8 @@ def test_json_equality_evaluator_evaluate_strings_equal( prediction = '{"a": 1}' reference = '{"a": 1}' result = json_equality_evaluator.evaluate_strings( - prediction=prediction, reference=reference + prediction=prediction, + reference=reference, ) assert result == {"score": True} @@ -95,7 +96,8 @@ def test_json_equality_evaluator_evaluate_strings_not_equal( prediction = '{"a": 1}' reference = '{"a": 2}' result = json_equality_evaluator.evaluate_strings( - prediction=prediction, reference=reference + prediction=prediction, + reference=reference, ) assert result == {"score": False} @@ -166,7 +168,7 @@ def test_json_equality_evaluator_evaluate_lists_permutation_invariant() -> None: "[" + ",".join( [f'{{"a": {i + 1}, "b": {i + 2}}}' for i in range(999)] - + ['{"a": 1000, "b": 1001}'] + + ['{"a": 1000, "b": 1001}'], ) + "]" ) diff --git a/libs/langchain/tests/unit_tests/evaluation/parsing/test_json_distance.py b/libs/langchain/tests/unit_tests/evaluation/parsing/test_json_distance.py index e0bee70ddca..fc3596422a2 100644 --- a/libs/langchain/tests/unit_tests/evaluation/parsing/test_json_distance.py +++ b/libs/langchain/tests/unit_tests/evaluation/parsing/test_json_distance.py @@ -45,7 +45,8 @@ def test_json_distance_evaluator_evaluate_strings_simple_diff( prediction = '{"a": 1}' reference = '{"a": 2}' result = json_distance_evaluator._evaluate_strings( - prediction=prediction, reference=reference + prediction=prediction, + reference=reference, ) # Only 1 character flipped pytest.approx(1 / 7, result["score"]) @@ -58,7 +59,8 @@ def test_json_distance_evaluator_evaluate_strings_complex_diff( prediction = '{"a":1, "b": {"c": 2, "d": 3}}' reference = '{"a": 1, "b": {"c": 2, "d": 4}}' result = json_distance_evaluator._evaluate_strings( - prediction=prediction, reference=reference + prediction=prediction, + 
reference=reference, ) # Only 1 character flipped pytest.approx(1 / len(reference.replace(" ", "")), result["score"]) @@ -71,7 +73,8 @@ def test_json_distance_evaluator_evaluate_strings_list_diff( prediction = '[{"a": 1, "b": 2}, {"a": 2, "b": 3}]' reference = '[{"a": 1, "b": 2}, {"a": 2, "b": 4}]' result = json_distance_evaluator._evaluate_strings( - prediction=prediction, reference=reference + prediction=prediction, + reference=reference, ) # Again only 1 character flipped pytest.approx(1 / len(reference.replace(" ", "")), result["score"]) @@ -84,7 +87,8 @@ def test_json_distance_evaluator_evaluate_strings_list_same( prediction = '[{"a": 1, "b": 2}, {"a": 2, "b": 3}]' reference = '[{"b": 2, "a": 1}, {"b": 3, "a": 2}]' result = json_distance_evaluator._evaluate_strings( - prediction=prediction, reference=reference + prediction=prediction, + reference=reference, ) assert result["score"] == 0 @@ -96,10 +100,12 @@ def test_json_distance_evaluator_evaluate_strings_list_diff_length( prediction = '[{"a": 1, "b": 2}, {"a": 2, "b": 3}]' reference = '[{"a": 1, "b": 2}]' result = json_distance_evaluator._evaluate_strings( - prediction=prediction, reference=reference + prediction=prediction, + reference=reference, ) pytest.approx( - len('{"a":2,"b":3}') / len(reference.replace(" ", "")), result["score"] + len('{"a":2,"b":3}') / len(reference.replace(" ", "")), + result["score"], ) diff --git a/libs/langchain/tests/unit_tests/evaluation/parsing/test_json_schema.py b/libs/langchain/tests/unit_tests/evaluation/parsing/test_json_schema.py index d200c358270..7847deed7db 100644 --- a/libs/langchain/tests/unit_tests/evaluation/parsing/test_json_schema.py +++ b/libs/langchain/tests/unit_tests/evaluation/parsing/test_json_schema.py @@ -39,7 +39,8 @@ def test_json_schema_evaluator_valid_prediction( "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, } result = json_schema_evaluator._evaluate_strings( - prediction=prediction, reference=reference + prediction=prediction, + reference=reference, ) assert result["score"] is True @@ -54,7 +55,8 @@ def test_json_schema_evaluator_invalid_prediction( "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, } result = json_schema_evaluator._evaluate_strings( - prediction=prediction, reference=reference + prediction=prediction, + reference=reference, ) assert result["score"] is False assert "reasoning" in result @@ -71,7 +73,8 @@ def test_json_schema_evaluator_missing_property( "required": ["name", "age"], } result = json_schema_evaluator._evaluate_strings( - prediction=prediction, reference=reference + prediction=prediction, + reference=reference, ) assert result["score"] is False assert "reasoning" in result diff --git a/libs/langchain/tests/unit_tests/evaluation/qa/test_eval_chain.py b/libs/langchain/tests/unit_tests/evaluation/qa/test_eval_chain.py index 72ce29c31dd..5d9c5cda163 100644 --- a/libs/langchain/tests/unit_tests/evaluation/qa/test_eval_chain.py +++ b/libs/langchain/tests/unit_tests/evaluation/qa/test_eval_chain.py @@ -19,7 +19,8 @@ from tests.unit_tests.llms.fake_llm import FakeLLM @pytest.mark.skipif( - sys.platform.startswith("win"), reason="Test not supported on Windows" + sys.platform.startswith("win"), + reason="Test not supported on Windows", ) def test_eval_chain() -> None: """Test a simple eval chain.""" @@ -34,7 +35,8 @@ def test_eval_chain() -> None: @pytest.mark.skipif( - sys.platform.startswith("win"), reason="Test not supported on Windows" + sys.platform.startswith("win"), + reason="Test not supported on 
Windows", ) @pytest.mark.parametrize("chain_cls", [ContextQAEvalChain, CotQAEvalChain]) def test_context_eval_chain(chain_cls: type[ContextQAEvalChain]) -> None: @@ -76,11 +78,14 @@ def test_returns_expected_results( chain_cls: type[LLMChain], ) -> None: fake_llm = FakeLLM( - queries={"text": "The meaning of life\nCORRECT"}, sequential_responses=True + queries={"text": "The meaning of life\nCORRECT"}, + sequential_responses=True, ) chain = chain_cls.from_llm(fake_llm) # type: ignore[attr-defined] results = chain.evaluate_strings( - prediction="my prediction", reference="my reference", input="my input" + prediction="my prediction", + reference="my reference", + input="my input", ) assert results["score"] == 1 diff --git a/libs/langchain/tests/unit_tests/evaluation/regex_match/test_base.py b/libs/langchain/tests/unit_tests/evaluation/regex_match/test_base.py index 4b725d7ddeb..ed6ab865f23 100644 --- a/libs/langchain/tests/unit_tests/evaluation/regex_match/test_base.py +++ b/libs/langchain/tests/unit_tests/evaluation/regex_match/test_base.py @@ -23,13 +23,15 @@ def test_default_regex_matching( prediction = "Mindy is the CTO" reference = "^Mindy.*CTO$" result = regex_match_string_evaluator.evaluate_strings( - prediction=prediction, reference=reference + prediction=prediction, + reference=reference, ) assert result["score"] == 1.0 reference = "^Mike.*CEO$" result = regex_match_string_evaluator.evaluate_strings( - prediction=prediction, reference=reference + prediction=prediction, + reference=reference, ) assert result["score"] == 0.0 @@ -40,6 +42,7 @@ def test_regex_matching_with_ignore_case( prediction = "Mindy is the CTO" reference = "^mindy.*cto$" result = regex_match_string_evaluator_ignore_case.evaluate_strings( - prediction=prediction, reference=reference + prediction=prediction, + reference=reference, ) assert result["score"] == 1.0 diff --git a/libs/langchain/tests/unit_tests/evaluation/string_distance/test_base.py b/libs/langchain/tests/unit_tests/evaluation/string_distance/test_base.py index d31180ad462..abd38bdfac5 100644 --- a/libs/langchain/tests/unit_tests/evaluation/string_distance/test_base.py +++ b/libs/langchain/tests/unit_tests/evaluation/string_distance/test_base.py @@ -31,10 +31,13 @@ async def test_zero_distance_async(distance: StringDistance) -> None: @pytest.mark.parametrize("distance", list(StringDistance)) @pytest.mark.parametrize("normalize_score", [True, False]) def test_zero_distance_pairwise( - *, distance: StringDistance, normalize_score: bool + *, + distance: StringDistance, + normalize_score: bool, ) -> None: eval_chain = PairwiseStringDistanceEvalChain( - distance=distance, normalize_score=normalize_score + distance=distance, + normalize_score=normalize_score, ) string = "三人行则必有我师" result = eval_chain.evaluate_string_pairs(prediction=string, prediction_b=string) @@ -48,7 +51,8 @@ async def test_zero_distance_pairwise_async(distance: StringDistance) -> None: eval_chain = PairwiseStringDistanceEvalChain(distance=distance) string = "三人行则必有我师" result = await eval_chain.aevaluate_string_pairs( - prediction=string, prediction_b=string + prediction=string, + prediction_b=string, ) assert "score" in result assert result["score"] == 0 @@ -64,7 +68,8 @@ valid_distances = [ @pytest.mark.parametrize("normalize_score", [True, False]) def test_non_zero_distance(*, distance: StringDistance, normalize_score: bool) -> None: eval_chain = StringDistanceEvalChain( - distance=distance, normalize_score=normalize_score + distance=distance, + normalize_score=normalize_score, ) 
prediction = "I like to eat apples." reference = "I like apples." @@ -82,7 +87,8 @@ async def test_non_zero_distance_async(distance: StringDistance) -> None: prediction = "I like to eat apples." reference = "I like apples." result = await eval_chain.aevaluate_strings( - prediction=prediction, reference=reference + prediction=prediction, + reference=reference, ) assert "score" in result assert 0 < result["score"] < 1.0 @@ -95,7 +101,8 @@ def test_non_zero_distance_pairwise(distance: StringDistance) -> None: prediction = "I like to eat apples." reference = "I like apples." result = eval_chain.evaluate_string_pairs( - prediction=prediction, prediction_b=reference + prediction=prediction, + prediction_b=reference, ) assert "score" in result assert 0 < result["score"] < 1.0 @@ -108,7 +115,8 @@ async def test_non_zero_distance_pairwise_async(distance: StringDistance) -> Non prediction = "I like to eat apples." reference = "I like apples." result = await eval_chain.aevaluate_string_pairs( - prediction=prediction, prediction_b=reference + prediction=prediction, + prediction_b=reference, ) assert "score" in result assert 0 < result["score"] < 1.0 diff --git a/libs/langchain/tests/unit_tests/indexes/test_indexing.py b/libs/langchain/tests/unit_tests/indexes/test_indexing.py index 47457615fe8..c4e8b8557e3 100644 --- a/libs/langchain/tests/unit_tests/indexes/test_indexing.py +++ b/libs/langchain/tests/unit_tests/indexes/test_indexing.py @@ -110,7 +110,7 @@ class InMemoryVectorStore(VectorStore): **kwargs: Any, ) -> list[str]: """Add the given texts to the store (insert behavior).""" - raise NotImplementedError() + raise NotImplementedError @classmethod def from_texts( @@ -121,13 +121,16 @@ class InMemoryVectorStore(VectorStore): **kwargs: Any, ) -> VST: """Create a vector store from a list of texts.""" - raise NotImplementedError() + raise NotImplementedError def similarity_search( - self, query: str, k: int = 4, **kwargs: Any + self, + query: str, + k: int = 4, + **kwargs: Any, ) -> list[Document]: """Find the most similar documents to the given query.""" - raise NotImplementedError() + raise NotImplementedError @pytest.fixture @@ -164,7 +167,8 @@ def upserting_vector_store() -> InMemoryVectorStore: def test_indexing_same_content( - record_manager: SQLRecordManager, vector_store: InMemoryVectorStore + record_manager: SQLRecordManager, + vector_store: InMemoryVectorStore, ) -> None: """Indexing some content to confirm it gets added only once.""" loader = ToyLoader( @@ -175,7 +179,7 @@ def test_indexing_same_content( Document( page_content="This is another document.", ), - ] + ], ) assert index(loader, record_manager, vector_store) == { @@ -199,7 +203,8 @@ def test_indexing_same_content( @pytest.mark.requires("aiosqlite") async def test_aindexing_same_content( - arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore + arecord_manager: SQLRecordManager, + vector_store: InMemoryVectorStore, ) -> None: """Indexing some content to confirm it gets added only once.""" loader = ToyLoader( @@ -210,7 +215,7 @@ async def test_aindexing_same_content( Document( page_content="This is another document.", ), - ] + ], ) assert await aindex(loader, arecord_manager, vector_store) == { @@ -233,7 +238,8 @@ async def test_aindexing_same_content( def test_index_simple_delete_full( - record_manager: SQLRecordManager, vector_store: InMemoryVectorStore + record_manager: SQLRecordManager, + vector_store: InMemoryVectorStore, ) -> None: """Indexing some content to confirm it gets added only once.""" loader = 
ToyLoader( @@ -244,11 +250,13 @@ def test_index_simple_delete_full( Document( page_content="This is another document.", ), - ] + ], ) with patch.object( - record_manager, "get_time", return_value=datetime(2021, 1, 1).timestamp() + record_manager, + "get_time", + return_value=datetime(2021, 1, 1).timestamp(), ): assert index(loader, record_manager, vector_store, cleanup="full") == { "num_added": 2, @@ -258,7 +266,9 @@ def test_index_simple_delete_full( } with patch.object( - record_manager, "get_time", return_value=datetime(2021, 1, 1).timestamp() + record_manager, + "get_time", + return_value=datetime(2021, 1, 1).timestamp(), ): assert index(loader, record_manager, vector_store, cleanup="full") == { "num_added": 0, @@ -275,11 +285,13 @@ def test_index_simple_delete_full( Document( page_content="This is another document.", # <-- Same as original ), - ] + ], ) with patch.object( - record_manager, "get_time", return_value=datetime(2021, 1, 2).timestamp() + record_manager, + "get_time", + return_value=datetime(2021, 1, 2).timestamp(), ): assert index(loader, record_manager, vector_store, cleanup="full") == { "num_added": 1, @@ -297,7 +309,9 @@ def test_index_simple_delete_full( # Attempt to index again verify that nothing changes with patch.object( - record_manager, "get_time", return_value=datetime(2021, 1, 2).timestamp() + record_manager, + "get_time", + return_value=datetime(2021, 1, 2).timestamp(), ): assert index(loader, record_manager, vector_store, cleanup="full") == { "num_added": 0, @@ -309,7 +323,8 @@ def test_index_simple_delete_full( @pytest.mark.requires("aiosqlite") async def test_aindex_simple_delete_full( - arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore + arecord_manager: SQLRecordManager, + vector_store: InMemoryVectorStore, ) -> None: """Indexing some content to confirm it gets added only once.""" loader = ToyLoader( @@ -320,11 +335,13 @@ async def test_aindex_simple_delete_full( Document( page_content="This is another document.", ), - ] + ], ) with patch.object( - arecord_manager, "aget_time", return_value=datetime(2021, 1, 1).timestamp() + arecord_manager, + "aget_time", + return_value=datetime(2021, 1, 1).timestamp(), ): assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == { "num_added": 2, @@ -334,7 +351,9 @@ async def test_aindex_simple_delete_full( } with patch.object( - arecord_manager, "aget_time", return_value=datetime(2021, 1, 1).timestamp() + arecord_manager, + "aget_time", + return_value=datetime(2021, 1, 1).timestamp(), ): assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == { "num_added": 0, @@ -351,11 +370,13 @@ async def test_aindex_simple_delete_full( Document( page_content="This is another document.", # <-- Same as original ), - ] + ], ) with patch.object( - arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp() + arecord_manager, + "aget_time", + return_value=datetime(2021, 1, 2).timestamp(), ): assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == { "num_added": 1, @@ -373,7 +394,9 @@ async def test_aindex_simple_delete_full( # Attempt to index again verify that nothing changes with patch.object( - arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp() + arecord_manager, + "aget_time", + return_value=datetime(2021, 1, 2).timestamp(), ): assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == { "num_added": 0, @@ -384,7 +407,8 @@ async def test_aindex_simple_delete_full( def 
test_incremental_fails_with_bad_source_ids( - record_manager: SQLRecordManager, vector_store: InMemoryVectorStore + record_manager: SQLRecordManager, + vector_store: InMemoryVectorStore, ) -> None: """Test indexing with incremental deletion strategy.""" loader = ToyLoader( @@ -401,7 +425,7 @@ def test_incremental_fails_with_bad_source_ids( page_content="This is yet another document.", metadata={"source": None}, ), - ] + ], ) with pytest.raises(ValueError): @@ -421,7 +445,8 @@ def test_incremental_fails_with_bad_source_ids( @pytest.mark.requires("aiosqlite") async def test_aincremental_fails_with_bad_source_ids( - arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore + arecord_manager: SQLRecordManager, + vector_store: InMemoryVectorStore, ) -> None: """Test indexing with incremental deletion strategy.""" loader = ToyLoader( @@ -438,7 +463,7 @@ async def test_aincremental_fails_with_bad_source_ids( page_content="This is yet another document.", metadata={"source": None}, ), - ] + ], ) with pytest.raises(ValueError): @@ -462,7 +487,8 @@ async def test_aincremental_fails_with_bad_source_ids( def test_no_delete( - record_manager: SQLRecordManager, vector_store: InMemoryVectorStore + record_manager: SQLRecordManager, + vector_store: InMemoryVectorStore, ) -> None: """Test indexing without a deletion strategy.""" loader = ToyLoader( @@ -475,11 +501,13 @@ def test_no_delete( page_content="This is another document.", metadata={"source": "2"}, ), - ] + ], ) with patch.object( - record_manager, "get_time", return_value=datetime(2021, 1, 2).timestamp() + record_manager, + "get_time", + return_value=datetime(2021, 1, 2).timestamp(), ): assert index( loader, @@ -496,7 +524,9 @@ def test_no_delete( # If we add the same content twice it should be skipped with patch.object( - record_manager, "get_time", return_value=datetime(2021, 1, 2).timestamp() + record_manager, + "get_time", + return_value=datetime(2021, 1, 2).timestamp(), ): assert index( loader, @@ -521,12 +551,14 @@ def test_no_delete( page_content="This is another document.", metadata={"source": "2"}, ), - ] + ], ) # Should result in no updates or deletions! with patch.object( - record_manager, "get_time", return_value=datetime(2021, 1, 2).timestamp() + record_manager, + "get_time", + return_value=datetime(2021, 1, 2).timestamp(), ): assert index( loader, @@ -544,7 +576,8 @@ def test_no_delete( @pytest.mark.requires("aiosqlite") async def test_ano_delete( - arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore + arecord_manager: SQLRecordManager, + vector_store: InMemoryVectorStore, ) -> None: """Test indexing without a deletion strategy.""" loader = ToyLoader( @@ -557,11 +590,13 @@ async def test_ano_delete( page_content="This is another document.", metadata={"source": "2"}, ), - ] + ], ) with patch.object( - arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp() + arecord_manager, + "aget_time", + return_value=datetime(2021, 1, 2).timestamp(), ): assert await aindex( loader, @@ -578,7 +613,9 @@ async def test_ano_delete( # If we add the same content twice it should be skipped with patch.object( - arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp() + arecord_manager, + "aget_time", + return_value=datetime(2021, 1, 2).timestamp(), ): assert await aindex( loader, @@ -603,12 +640,14 @@ async def test_ano_delete( page_content="This is another document.", metadata={"source": "2"}, ), - ] + ], ) # Should result in no updates or deletions! 
with patch.object( - arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp() + arecord_manager, + "aget_time", + return_value=datetime(2021, 1, 2).timestamp(), ): assert await aindex( loader, @@ -625,7 +664,8 @@ async def test_ano_delete( def test_incremental_delete( - record_manager: SQLRecordManager, vector_store: InMemoryVectorStore + record_manager: SQLRecordManager, + vector_store: InMemoryVectorStore, ) -> None: """Test indexing with incremental deletion strategy.""" loader = ToyLoader( @@ -638,11 +678,13 @@ def test_incremental_delete( page_content="This is another document.", metadata={"source": "2"}, ), - ] + ], ) with patch.object( - record_manager, "get_time", return_value=datetime(2021, 1, 2).timestamp() + record_manager, + "get_time", + return_value=datetime(2021, 1, 2).timestamp(), ): assert index( loader, @@ -666,7 +708,9 @@ def test_incremental_delete( # Attempt to index again verify that nothing changes with patch.object( - record_manager, "get_time", return_value=datetime(2021, 1, 2).timestamp() + record_manager, + "get_time", + return_value=datetime(2021, 1, 2).timestamp(), ): assert index( loader, @@ -696,12 +740,14 @@ def test_incremental_delete( page_content="This is another document.", # <-- Same as original metadata={"source": "2"}, ), - ] + ], ) # Attempt to index again verify that nothing changes with patch.object( - record_manager, "get_time", return_value=datetime(2021, 1, 3).timestamp() + record_manager, + "get_time", + return_value=datetime(2021, 1, 3).timestamp(), ): assert index( loader, @@ -729,7 +775,8 @@ def test_incremental_delete( def test_incremental_indexing_with_batch_size( - record_manager: SQLRecordManager, vector_store: InMemoryVectorStore + record_manager: SQLRecordManager, + vector_store: InMemoryVectorStore, ) -> None: """Test indexing with incremental indexing""" loader = ToyLoader( @@ -750,11 +797,13 @@ def test_incremental_indexing_with_batch_size( page_content="4", metadata={"source": "1"}, ), - ] + ], ) with patch.object( - record_manager, "get_time", return_value=datetime(2021, 1, 2).timestamp() + record_manager, + "get_time", + return_value=datetime(2021, 1, 2).timestamp(), ): assert index( loader, @@ -793,7 +842,8 @@ def test_incremental_indexing_with_batch_size( def test_incremental_delete_with_batch_size( - record_manager: SQLRecordManager, vector_store: InMemoryVectorStore + record_manager: SQLRecordManager, + vector_store: InMemoryVectorStore, ) -> None: """Test indexing with incremental deletion strategy and batch size.""" loader = ToyLoader( @@ -814,11 +864,13 @@ def test_incremental_delete_with_batch_size( page_content="4", metadata={"source": "4"}, ), - ] + ], ) with patch.object( - record_manager, "get_time", return_value=datetime(2021, 1, 2).timestamp() + record_manager, + "get_time", + return_value=datetime(2021, 1, 2).timestamp(), ): assert index( loader, @@ -843,7 +895,9 @@ def test_incremental_delete_with_batch_size( # Attempt to index again verify that nothing changes with patch.object( - record_manager, "get_time", return_value=datetime(2021, 1, 2).timestamp() + record_manager, + "get_time", + return_value=datetime(2021, 1, 2).timestamp(), ): assert index( loader, @@ -861,7 +915,9 @@ def test_incremental_delete_with_batch_size( # Attempt to index again verify that nothing changes with patch.object( - record_manager, "get_time", return_value=datetime(2022, 1, 3).timestamp() + record_manager, + "get_time", + return_value=datetime(2022, 1, 3).timestamp(), ): # Docs with same content docs = [ @@ -890,7 
+946,9 @@ def test_incremental_delete_with_batch_size( # Attempt to index again verify that nothing changes with patch.object( - record_manager, "get_time", return_value=datetime(2023, 1, 3).timestamp() + record_manager, + "get_time", + return_value=datetime(2023, 1, 3).timestamp(), ): # Docs with same content docs = [ @@ -919,7 +977,9 @@ def test_incremental_delete_with_batch_size( # Try to index with changed docs now with patch.object( - record_manager, "get_time", return_value=datetime(2024, 1, 3).timestamp() + record_manager, + "get_time", + return_value=datetime(2024, 1, 3).timestamp(), ): # Docs with same content docs = [ @@ -948,7 +1008,8 @@ def test_incremental_delete_with_batch_size( @pytest.mark.requires("aiosqlite") async def test_aincremental_delete( - arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore + arecord_manager: SQLRecordManager, + vector_store: InMemoryVectorStore, ) -> None: """Test indexing with incremental deletion strategy.""" loader = ToyLoader( @@ -961,11 +1022,13 @@ async def test_aincremental_delete( page_content="This is another document.", metadata={"source": "2"}, ), - ] + ], ) with patch.object( - arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp() + arecord_manager, + "aget_time", + return_value=datetime(2021, 1, 2).timestamp(), ): assert await aindex( loader.lazy_load(), @@ -989,7 +1052,9 @@ async def test_aincremental_delete( # Attempt to index again verify that nothing changes with patch.object( - arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp() + arecord_manager, + "aget_time", + return_value=datetime(2021, 1, 2).timestamp(), ): assert await aindex( loader.lazy_load(), @@ -1019,12 +1084,14 @@ async def test_aincremental_delete( page_content="This is another document.", # <-- Same as original metadata={"source": "2"}, ), - ] + ], ) # Attempt to index again verify that nothing changes with patch.object( - arecord_manager, "aget_time", return_value=datetime(2021, 1, 3).timestamp() + arecord_manager, + "aget_time", + return_value=datetime(2021, 1, 3).timestamp(), ): assert await aindex( loader.lazy_load(), @@ -1052,7 +1119,8 @@ async def test_aincremental_delete( def test_indexing_with_no_docs( - record_manager: SQLRecordManager, vector_store: VectorStore + record_manager: SQLRecordManager, + vector_store: VectorStore, ) -> None: """Check edge case when loader returns no new docs.""" loader = ToyLoader(documents=[]) @@ -1067,7 +1135,8 @@ def test_indexing_with_no_docs( @pytest.mark.requires("aiosqlite") async def test_aindexing_with_no_docs( - arecord_manager: SQLRecordManager, vector_store: VectorStore + arecord_manager: SQLRecordManager, + vector_store: VectorStore, ) -> None: """Check edge case when loader returns no new docs.""" loader = ToyLoader(documents=[]) @@ -1081,7 +1150,8 @@ async def test_aindexing_with_no_docs( def test_deduplication( - record_manager: SQLRecordManager, vector_store: VectorStore + record_manager: SQLRecordManager, + vector_store: VectorStore, ) -> None: """Check edge case when loader returns no new docs.""" docs = [ @@ -1106,7 +1176,8 @@ def test_deduplication( @pytest.mark.requires("aiosqlite") async def test_adeduplication( - arecord_manager: SQLRecordManager, vector_store: VectorStore + arecord_manager: SQLRecordManager, + vector_store: VectorStore, ) -> None: """Check edge case when loader returns no new docs.""" docs = [ @@ -1130,7 +1201,8 @@ async def test_adeduplication( def test_cleanup_with_different_batchsize( - record_manager: SQLRecordManager, 
vector_store: VectorStore + record_manager: SQLRecordManager, + vector_store: VectorStore, ) -> None: """Check that we can clean up with different batch size.""" docs = [ @@ -1157,7 +1229,11 @@ def test_cleanup_with_different_batchsize( ] assert index( - docs, record_manager, vector_store, cleanup="full", cleanup_batch_size=17 + docs, + record_manager, + vector_store, + cleanup="full", + cleanup_batch_size=17, ) == { "num_added": 1001, "num_deleted": 1000, @@ -1168,7 +1244,8 @@ def test_cleanup_with_different_batchsize( @pytest.mark.requires("aiosqlite") async def test_async_cleanup_with_different_batchsize( - arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore + arecord_manager: SQLRecordManager, + vector_store: InMemoryVectorStore, ) -> None: """Check that we can clean up with different batch size.""" docs = [ @@ -1195,7 +1272,11 @@ async def test_async_cleanup_with_different_batchsize( ] assert await aindex( - docs, arecord_manager, vector_store, cleanup="full", cleanup_batch_size=17 + docs, + arecord_manager, + vector_store, + cleanup="full", + cleanup_batch_size=17, ) == { "num_added": 1001, "num_deleted": 1000, @@ -1205,7 +1286,8 @@ async def test_async_cleanup_with_different_batchsize( def test_deduplication_v2( - record_manager: SQLRecordManager, vector_store: VectorStore + record_manager: SQLRecordManager, + vector_store: VectorStore, ) -> None: """Check edge case when loader returns no new docs.""" docs = [ @@ -1237,7 +1319,7 @@ def test_deduplication_v2( # using in memory implementation here assert isinstance(vector_store, InMemoryVectorStore) contents = sorted( - [document.page_content for document in vector_store.store.values()] + [document.page_content for document in vector_store.store.values()], ) assert contents == ["1", "2", "3"] @@ -1268,7 +1350,8 @@ async def test_abatch() -> None: def test_indexing_force_update( - record_manager: SQLRecordManager, upserting_vector_store: VectorStore + record_manager: SQLRecordManager, + upserting_vector_store: VectorStore, ) -> None: """Test indexing with force update.""" docs = [ @@ -1301,7 +1384,11 @@ def test_indexing_force_update( } assert index( - docs, record_manager, upserting_vector_store, cleanup="full", force_update=True + docs, + record_manager, + upserting_vector_store, + cleanup="full", + force_update=True, ) == { "num_added": 0, "num_deleted": 0, @@ -1312,7 +1399,8 @@ def test_indexing_force_update( @pytest.mark.requires("aiosqlite") async def test_aindexing_force_update( - arecord_manager: SQLRecordManager, upserting_vector_store: VectorStore + arecord_manager: SQLRecordManager, + upserting_vector_store: VectorStore, ) -> None: """Test indexing with force update.""" docs = [ @@ -1331,7 +1419,10 @@ async def test_aindexing_force_update( ] assert await aindex( - docs, arecord_manager, upserting_vector_store, cleanup="full" + docs, + arecord_manager, + upserting_vector_store, + cleanup="full", ) == { "num_added": 2, "num_deleted": 0, @@ -1340,7 +1431,10 @@ async def test_aindexing_force_update( } assert await aindex( - docs, arecord_manager, upserting_vector_store, cleanup="full" + docs, + arecord_manager, + upserting_vector_store, + cleanup="full", ) == { "num_added": 0, "num_deleted": 0, @@ -1363,7 +1457,8 @@ async def test_aindexing_force_update( def test_indexing_custom_batch_size( - record_manager: SQLRecordManager, vector_store: InMemoryVectorStore + record_manager: SQLRecordManager, + vector_store: InMemoryVectorStore, ) -> None: """Test indexing with a custom batch size.""" docs = [ @@ -1389,7 
+1484,7 @@ def test_indexing_custom_batch_size( page_content="This is a test document.", metadata={"source": "1"}, id=ids[0], - ) + ), ] assert args == (docs_with_id,) assert kwargs == {"ids": ids, "batch_size": batch_size} @@ -1397,7 +1492,8 @@ def test_indexing_custom_batch_size( @pytest.mark.requires("aiosqlite") async def test_aindexing_custom_batch_size( - arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore + arecord_manager: SQLRecordManager, + vector_store: InMemoryVectorStore, ) -> None: """Test indexing with a custom batch size.""" docs = [ @@ -1423,7 +1519,7 @@ async def test_aindexing_custom_batch_size( page_content="This is a test document.", metadata={"source": "1"}, id=ids[0], - ) + ), ] assert args == (docs_with_id,) assert kwargs == {"ids": ids, "batch_size": batch_size} diff --git a/libs/langchain/tests/unit_tests/llms/fake_chat_model.py b/libs/langchain/tests/unit_tests/llms/fake_chat_model.py index 957aaabedb0..d4d46d3b973 100644 --- a/libs/langchain/tests/unit_tests/llms/fake_chat_model.py +++ b/libs/langchain/tests/unit_tests/llms/fake_chat_model.py @@ -95,7 +95,10 @@ class GenericFakeChatModel(BaseChatModel): ) -> Iterator[ChatGenerationChunk]: """Stream the output of the model.""" chat_result = self._generate( - messages, stop=stop, run_manager=run_manager, **kwargs + messages, + stop=stop, + run_manager=run_manager, + **kwargs, ) if not isinstance(chat_result, ChatResult): msg = ( @@ -123,7 +126,7 @@ class GenericFakeChatModel(BaseChatModel): for token in content_chunks: chunk = ChatGenerationChunk( - message=AIMessageChunk(id=message.id, content=token) + message=AIMessageChunk(id=message.id, content=token), ) if run_manager: run_manager.on_llm_new_token(token, chunk=chunk) @@ -144,9 +147,9 @@ class GenericFakeChatModel(BaseChatModel): id=message.id, content="", additional_kwargs={ - "function_call": {fkey: fvalue_chunk} + "function_call": {fkey: fvalue_chunk}, }, - ) + ), ) if run_manager: run_manager.on_llm_new_token( @@ -160,7 +163,7 @@ class GenericFakeChatModel(BaseChatModel): id=message.id, content="", additional_kwargs={"function_call": {fkey: fvalue}}, - ) + ), ) if run_manager: run_manager.on_llm_new_token( @@ -171,8 +174,10 @@ class GenericFakeChatModel(BaseChatModel): else: chunk = ChatGenerationChunk( message=AIMessageChunk( - id=message.id, content="", additional_kwargs={key: value} - ) + id=message.id, + content="", + additional_kwargs={key: value}, + ), ) if run_manager: run_manager.on_llm_new_token( diff --git a/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py b/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py index a578a38b9df..9b27f3e1196 100644 --- a/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py +++ b/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py @@ -41,7 +41,7 @@ async def test_generic_fake_chat_model_stream() -> None: infinite_cycle = cycle( [ AIMessage(content="hello goodbye"), - ] + ], ) model = GenericFakeChatModel(messages=infinite_cycle) chunks = [chunk async for chunk in model.astream("meow")] @@ -76,7 +76,7 @@ async def test_generic_fake_chat_model_stream() -> None: "name": "move_file", "arguments": '{\n "source_path": "foo",\n "' 'destination_path": "bar"\n}', - } + }, }, ) model = GenericFakeChatModel(messages=cycle([message])) @@ -92,17 +92,19 @@ async def test_generic_fake_chat_model_stream() -> None: id="a1", content="", additional_kwargs={ - "function_call": {"arguments": '{\n "source_path": "foo"'} + "function_call": {"arguments": '{\n "source_path": "foo"'}, }, 
), AIMessageChunk( - id="a1", content="", additional_kwargs={"function_call": {"arguments": ","}} + id="a1", + content="", + additional_kwargs={"function_call": {"arguments": ","}}, ), AIMessageChunk( id="a1", content="", additional_kwargs={ - "function_call": {"arguments": '\n "destination_path": "bar"\n}'} + "function_call": {"arguments": '\n "destination_path": "bar"\n}'}, }, ), ] @@ -122,7 +124,7 @@ async def test_generic_fake_chat_model_stream() -> None: "name": "move_file", "arguments": '{\n "source_path": "foo",\n "' 'destination_path": "bar"\n}', - } + }, }, ) @@ -179,7 +181,7 @@ async def test_callback_handlers() -> None: infinite_cycle = cycle( [ AIMessage(content="hello goodbye"), - ] + ], ) model = GenericFakeChatModel(messages=infinite_cycle) tokens: list[str] = [] @@ -187,7 +189,8 @@ async def test_callback_handlers() -> None: results = [ chunk async for chunk in model.astream( - "meow", {"callbacks": [MyCustomAsyncHandler(tokens)]} + "meow", + {"callbacks": [MyCustomAsyncHandler(tokens)]}, ) ] assert results == [ diff --git a/libs/langchain/tests/unit_tests/load/test_dump.py b/libs/langchain/tests/unit_tests/load/test_dump.py index 767ff3f5a84..dfbfc9a8778 100644 --- a/libs/langchain/tests/unit_tests/load/test_dump.py +++ b/libs/langchain/tests/unit_tests/load/test_dump.py @@ -136,7 +136,8 @@ def test_aliases_hidden() -> None: assert dumped == expected_dump # Check while patching the os environment with patch.dict( - os.environ, {"MY_FAVORITE_SECRET": "hello", "MY_OTHER_SECRET": "world"} + os.environ, + {"MY_FAVORITE_SECRET": "hello", "MY_OTHER_SECRET": "world"}, ): test_class = TestClass() # type: ignore[call-arg] dumped = json.loads(dumps(test_class, pretty=True)) diff --git a/libs/langchain/tests/unit_tests/load/test_load.py b/libs/langchain/tests/unit_tests/load/test_load.py index d054115836a..aa3fb1065ba 100644 --- a/libs/langchain/tests/unit_tests/load/test_load.py +++ b/libs/langchain/tests/unit_tests/load/test_load.py @@ -24,7 +24,10 @@ def test_loads_openai_llm() -> None: from langchain_openai import OpenAI llm = CommunityOpenAI( - model="davinci", temperature=0.5, openai_api_key="hello", top_p=0.8 + model="davinci", + temperature=0.5, + openai_api_key="hello", + top_p=0.8, ) llm_string = dumps(llm) llm2 = loads(llm_string, secrets_map={"OPENAI_API_KEY": "hello"}) @@ -40,7 +43,10 @@ def test_loads_llmchain() -> None: from langchain_openai import OpenAI llm = CommunityOpenAI( - model="davinci", temperature=0.5, openai_api_key="hello", top_p=0.8 + model="davinci", + temperature=0.5, + openai_api_key="hello", + top_p=0.8, ) prompt = PromptTemplate.from_template("hello {name}!") chain = LLMChain(llm=llm, prompt=prompt) diff --git a/libs/langchain/tests/unit_tests/memory/test_combined_memory.py b/libs/langchain/tests/unit_tests/memory/test_combined_memory.py index 05c37a06131..0e5df20a9bf 100644 --- a/libs/langchain/tests/unit_tests/memory/test_combined_memory.py +++ b/libs/langchain/tests/unit_tests/memory/test_combined_memory.py @@ -7,7 +7,7 @@ import pytest from langchain.memory import CombinedMemory, ConversationBufferMemory -@pytest.fixture() +@pytest.fixture def example_memory() -> list[ConversationBufferMemory]: example_1 = ConversationBufferMemory(memory_key="foo") example_2 = ConversationBufferMemory(memory_key="bar") @@ -21,7 +21,8 @@ def test_basic_functionality(example_memory: list[ConversationBufferMemory]) -> assert combined_memory.memory_variables == ["foo", "bar"] assert combined_memory.load_memory_variables({}) == {"foo": "", "bar": ""} 
combined_memory.save_context( - {"input": "Hello there"}, {"output": "Hello, how can I help you?"} + {"input": "Hello there"}, + {"output": "Hello, how can I help you?"}, ) assert combined_memory.load_memory_variables({}) == { "foo": "Human: Hello there\nAI: Hello, how can I help you?", diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_combining_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_combining_parser.py index e6264c2379e..396ea0e58eb 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_combining_parser.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_combining_parser.py @@ -29,13 +29,14 @@ def test_combining_dict_result() -> None: StructuredOutputParser( response_schemas=[ ResponseSchema( - name="answer", description="answer to the user's question" + name="answer", + description="answer to the user's question", ), ResponseSchema( name="source", description="source used to answer the user's question", ), - ] + ], ), RegexParser( regex=r"Confidence: (A|B|C), Explanation: (.*)", @@ -54,13 +55,14 @@ def test_combining_output_parser_output_type() -> None: StructuredOutputParser( response_schemas=[ ResponseSchema( - name="answer", description="answer to the user's question" + name="answer", + description="answer to the user's question", ), ResponseSchema( name="source", description="source used to answer the user's question", ), - ] + ], ), RegexParser( regex=r"Confidence: (A|B|C), Explanation: (.*)", diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_fix.py b/libs/langchain/tests/unit_tests/output_parsers/test_fix.py index 287d5ecf79b..f566e413de2 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_fix.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_fix.py @@ -38,7 +38,7 @@ class SuccessfulParseAfterRetriesWithGetFormatInstructions(SuccessfulParseAfterR [ SuccessfulParseAfterRetries(attemp_count_before_success=5), SuccessfulParseAfterRetriesWithGetFormatInstructions( - attemp_count_before_success=5 + attemp_count_before_success=5, ), ], ) @@ -81,7 +81,7 @@ def test_output_fixing_parser_from_llm() -> None: [ SuccessfulParseAfterRetries(attemp_count_before_success=5), SuccessfulParseAfterRetriesWithGetFormatInstructions( - attemp_count_before_success=5 + attemp_count_before_success=5, ), ], ) @@ -140,7 +140,8 @@ def test_output_fixing_parser_output_type( base_parser: BaseOutputParser, ) -> None: parser = OutputFixingParser[str]( - parser=base_parser, retry_chain=RunnablePassthrough() + parser=base_parser, + retry_chain=RunnablePassthrough(), ) assert parser.OutputType is base_parser.OutputType diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_json.py b/libs/langchain/tests/unit_tests/output_parsers/test_json.py index 43f2bca2356..9db2b5d416e 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_json.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_json.py @@ -321,35 +321,35 @@ EXPECTED_STREAMED_JSON_DIFF = [ "op": "replace", "path": "/setup", "value": "Why did the bears start a band called", - } + }, ], [ { "op": "replace", "path": "/setup", "value": "Why did the bears start a band called Bears", - } + }, ], [ { "op": "replace", "path": "/setup", "value": "Why did the bears start a band called Bears Bears", - } + }, ], [ { "op": "replace", "path": "/setup", "value": "Why did the bears start a band called Bears Bears Bears", - } + }, ], [ { "op": "replace", "path": "/setup", "value": "Why did the bears start a band called Bears Bears Bears ?", - } + }, ], 
[{"op": "add", "path": "/punchline", "value": ""}], [{"op": "replace", "path": "/punchline", "value": "Because"}], @@ -362,35 +362,35 @@ EXPECTED_STREAMED_JSON_DIFF = [ "op": "replace", "path": "/punchline", "value": "Because they wanted to play bear", - } + }, ], [ { "op": "replace", "path": "/punchline", "value": "Because they wanted to play bear -y", - } + }, ], [ { "op": "replace", "path": "/punchline", "value": "Because they wanted to play bear -y good", - } + }, ], [ { "op": "replace", "path": "/punchline", "value": "Because they wanted to play bear -y good music", - } + }, ], [ { "op": "replace", "path": "/punchline", "value": "Because they wanted to play bear -y good music !", - } + }, ], [{"op": "add", "path": "/audience", "value": []}], [{"op": "add", "path": "/audience/0", "value": ""}], @@ -405,7 +405,8 @@ def test_partial_functions_json_output_parser() -> None: def input_iter(_: Any) -> Iterator[AIMessageChunk]: for token in STREAMED_TOKENS: yield AIMessageChunk( - content="", additional_kwargs={"function_call": {"arguments": token}} + content="", + additional_kwargs={"function_call": {"arguments": token}}, ) chain = input_iter | JsonOutputFunctionsParser() @@ -417,7 +418,8 @@ def test_partial_functions_json_output_parser_diff() -> None: def input_iter(_: Any) -> Iterator[AIMessageChunk]: for token in STREAMED_TOKENS: yield AIMessageChunk( - content="", additional_kwargs={"function_call": {"arguments": token}} + content="", + additional_kwargs={"function_call": {"arguments": token}}, ) chain = input_iter | JsonOutputFunctionsParser(diff=True) @@ -429,7 +431,8 @@ async def test_partial_functions_json_output_parser_async() -> None: async def input_iter(_: Any) -> AsyncIterator[AIMessageChunk]: for token in STREAMED_TOKENS: yield AIMessageChunk( - content="", additional_kwargs={"function_call": {"arguments": token}} + content="", + additional_kwargs={"function_call": {"arguments": token}}, ) chain = input_iter | JsonOutputFunctionsParser() @@ -441,7 +444,8 @@ async def test_partial_functions_json_output_parser_diff_async() -> None: async def input_iter(_: Any) -> AsyncIterator[AIMessageChunk]: for token in STREAMED_TOKENS: yield AIMessageChunk( - content="", additional_kwargs={"function_call": {"arguments": token}} + content="", + additional_kwargs={"function_call": {"arguments": token}}, ) chain = input_iter | JsonOutputFunctionsParser(diff=True) diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_pandas_dataframe_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_pandas_dataframe_parser.py index 8a0562e92e4..704373d85f0 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_pandas_dataframe_parser.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_pandas_dataframe_parser.py @@ -12,7 +12,7 @@ df = pd.DataFrame( "chicken": [1, 2, 3, 4], "veggies": [5, 4, 3, 2], "steak": [9, 8, 7, 6], - } + }, ) parser = PandasDataFrameOutputParser(dataframe=df) diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_regex_dict.py b/libs/langchain/tests/unit_tests/output_parsers/test_regex_dict.py index 88f1b5432df..1843bf2dd25 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_regex_dict.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_regex_dict.py @@ -31,7 +31,8 @@ we expect the result to only contain the following fields: def test_regex_dict_result() -> None: """Test regex dict result.""" regex_dict_parser = RegexDictParser( - output_key_to_format=DEF_OUTPUT_KEY_TO_FORMAT, no_update_value="N/A" + 
output_key_to_format=DEF_OUTPUT_KEY_TO_FORMAT, + no_update_value="N/A", ) result_dict = regex_dict_parser.parse(DEF_README) print("parse_result:", result_dict) # noqa: T201 @@ -41,6 +42,7 @@ def test_regex_dict_result() -> None: def test_regex_dict_output_type() -> None: """Test regex dict output type.""" regex_dict_parser = RegexDictParser( - output_key_to_format=DEF_OUTPUT_KEY_TO_FORMAT, no_update_value="N/A" + output_key_to_format=DEF_OUTPUT_KEY_TO_FORMAT, + no_update_value="N/A", ) assert regex_dict_parser.OutputType == dict[str, str] diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_retry.py b/libs/langchain/tests/unit_tests/output_parsers/test_retry.py index c1fff571780..d94debd8361 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_retry.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_retry.py @@ -69,7 +69,8 @@ async def test_retry_output_parser_aparse_with_prompt() -> None: legacy=False, ) actual = await parser.aparse_with_prompt( - "completion", StringPromptValue(text="dummy") + "completion", + StringPromptValue(text="dummy"), ) assert actual == "parsed" assert base_parser.parse_count == n + 1 @@ -153,7 +154,8 @@ async def test_retry_with_error_output_parser_aparse_with_prompt() -> None: legacy=False, ) actual = await parser.aparse_with_prompt( - "completion", StringPromptValue(text="dummy") + "completion", + StringPromptValue(text="dummy"), ) assert actual == "parsed" assert base_parser.parse_count == n + 1 @@ -211,7 +213,7 @@ def test_retry_with_error_output_parser_parse_is_not_implemented() -> None: NAIVE_RETRY_PROMPT | RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"), dt(2024, 7, 8), - ) + ), ], ) def test_retry_output_parser_parse_with_prompt_with_retry_chain( @@ -239,7 +241,7 @@ def test_retry_output_parser_parse_with_prompt_with_retry_chain( NAIVE_RETRY_PROMPT | RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"), dt(2024, 7, 8), - ) + ), ], ) async def test_retry_output_parser_aparse_with_prompt_with_retry_chain( @@ -268,7 +270,7 @@ async def test_retry_output_parser_aparse_with_prompt_with_retry_chain( NAIVE_RETRY_WITH_ERROR_PROMPT | RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"), dt(2024, 7, 8), - ) + ), ], ) def test_retry_with_error_output_parser_parse_with_prompt_with_retry_chain( @@ -297,7 +299,7 @@ def test_retry_with_error_output_parser_parse_with_prompt_with_retry_chain( NAIVE_RETRY_WITH_ERROR_PROMPT | RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"), dt(2024, 7, 8), - ) + ), ], ) async def test_retry_with_error_output_parser_aparse_with_prompt_with_retry_chain( diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_yaml_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_yaml_parser.py index 4e6cba09cad..e840d4311a2 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_yaml_parser.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_yaml_parser.py @@ -21,7 +21,8 @@ class TestModel(BaseModel): action: Actions = Field(description="Action to be performed") action_input: str = Field(description="Input to be used in the action") additional_fields: Optional[str] = Field( - description="Additional fields", default=None + description="Additional fields", + default=None, ) for_new_lines: str = Field(description="To be used to test newlines") @@ -71,7 +72,7 @@ def test_yaml_output_parser(result: str) -> None: """Test yamlOutputParser.""" yaml_parser: YamlOutputParser[TestModel] = YamlOutputParser( - pydantic_object=TestModel + pydantic_object=TestModel, ) model = 
yaml_parser.parse(result) @@ -83,7 +84,7 @@ def test_yaml_output_parser_fail() -> None: """Test YamlOutputParser where completion result fails schema validation.""" yaml_parser: YamlOutputParser[TestModel] = YamlOutputParser( - pydantic_object=TestModel + pydantic_object=TestModel, ) try: diff --git a/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_chain_extract.py b/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_chain_extract.py index 87b9898b9b5..9235c9cee59 100644 --- a/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_chain_extract.py +++ b/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_chain_extract.py @@ -26,11 +26,12 @@ def test_llm_chain_extractor() -> None: "Candlepin bowling is popular in New England.", "Candlepin bowling balls are smaller.", "NO_OUTPUT", - ] + ], ) doc_compressor = LLMChainExtractor.from_llm(llm) output = doc_compressor.compress_documents( - documents, "Tell me about Candlepin bowling." + documents, + "Tell me about Candlepin bowling.", ) expected = documents = [ Document( @@ -38,7 +39,8 @@ def test_llm_chain_extractor() -> None: metadata={"a": 1}, ), Document( - page_content="Candlepin bowling balls are smaller.", metadata={"b": 2} + page_content="Candlepin bowling balls are smaller.", + metadata={"b": 2}, ), ] assert output == expected @@ -66,11 +68,12 @@ async def test_llm_chain_extractor_async() -> None: "Candlepin bowling is popular in New England.", "Candlepin bowling balls are smaller.", "NO_OUTPUT", - ] + ], ) doc_compressor = LLMChainExtractor.from_llm(llm) output = await doc_compressor.acompress_documents( - documents, "Tell me about Candlepin bowling." + documents, + "Tell me about Candlepin bowling.", ) expected = [ Document( @@ -78,7 +81,8 @@ async def test_llm_chain_extractor_async() -> None: metadata={"a": 1}, ), Document( - page_content="Candlepin bowling balls are smaller.", metadata={"b": 2} + page_content="Candlepin bowling balls are smaller.", + metadata={"b": 2}, ), ] assert output == expected diff --git a/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_chain_filter.py b/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_chain_filter.py index 4020694afa6..669e067f2e8 100644 --- a/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_chain_filter.py +++ b/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_chain_filter.py @@ -19,7 +19,8 @@ def test_llm_chain_filter() -> None: llm = FakeListChatModel(responses=["YES", "YES", "NO"]) doc_compressor = LLMChainFilter.from_llm(llm) output = doc_compressor.compress_documents( - documents, "Tell me about Candlepin bowling." + documents, + "Tell me about Candlepin bowling.", ) expected = documents[:2] assert output == expected @@ -40,7 +41,8 @@ async def test_llm_chain_extractor_async() -> None: llm = FakeListChatModel(responses=["YES", "YES", "NO"]) doc_compressor = LLMChainFilter.from_llm(llm) output = await doc_compressor.acompress_documents( - documents, "Tell me about Candlepin bowling." 
+ documents, + "Tell me about Candlepin bowling.", ) expected = documents[:2] assert output == expected diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_base.py b/libs/langchain/tests/unit_tests/retrievers/self_query/test_base.py index 04e2c6e123f..17793c9fc1d 100644 --- a/libs/langchain/tests/unit_tests/retrievers/self_query/test_base.py +++ b/libs/langchain/tests/unit_tests/retrievers/self_query/test_base.py @@ -45,12 +45,13 @@ class FakeTranslator(Visitor): def visit_comparison(self, comparison: Comparison) -> dict: return { comparison.attribute: { - self._format_func(comparison.comparator): comparison.value - } + self._format_func(comparison.comparator): comparison.value, + }, } def visit_structured_query( - self, structured_query: StructuredQuery + self, + structured_query: StructuredQuery, ) -> tuple[str, dict]: if structured_query.filter is None: kwargs = {} @@ -61,7 +62,10 @@ class FakeTranslator(Visitor): class InMemoryVectorstoreWithSearch(InMemoryVectorStore): def similarity_search( - self, query: str, k: int = 4, **kwargs: Any + self, + query: str, + k: int = 4, + **kwargs: Any, ) -> list[Document]: res = self.store.get(query) if res is None: @@ -69,7 +73,7 @@ class InMemoryVectorstoreWithSearch(InMemoryVectorStore): return [res] -@pytest.fixture() +@pytest.fixture def fake_llm() -> FakeLLM: return FakeLLM( queries={ @@ -85,7 +89,7 @@ def fake_llm() -> FakeLLM: ) -@pytest.fixture() +@pytest.fixture def fake_vectorstore() -> InMemoryVectorstoreWithSearch: vectorstore = InMemoryVectorstoreWithSearch() vectorstore.add_documents( @@ -102,9 +106,10 @@ def fake_vectorstore() -> InMemoryVectorstoreWithSearch: return vectorstore -@pytest.fixture() +@pytest.fixture def fake_self_query_retriever( - fake_llm: FakeLLM, fake_vectorstore: InMemoryVectorstoreWithSearch + fake_llm: FakeLLM, + fake_vectorstore: InMemoryVectorstoreWithSearch, ) -> SelfQueryRetriever: return SelfQueryRetriever.from_llm( llm=fake_llm, diff --git a/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py b/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py index d3061ea56af..225fb469aab 100644 --- a/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py +++ b/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py @@ -32,7 +32,9 @@ def test_invoke() -> None: retriever2 = MockRetriever(docs=documents2) ensemble_retriever = EnsembleRetriever( - retrievers=[retriever1, retriever2], weights=[0.5, 0.5], id_key=None + retrievers=[retriever1, retriever2], + weights=[0.5, 0.5], + id_key=None, ) ranked_documents = ensemble_retriever.invoke("_") @@ -54,7 +56,9 @@ def test_invoke() -> None: retriever2 = MockRetriever(docs=documents2) ensemble_retriever = EnsembleRetriever( - retrievers=[retriever1, retriever2], weights=[0.5, 0.5], id_key=None + retrievers=[retriever1, retriever2], + weights=[0.5, 0.5], + id_key=None, ) ranked_documents = ensemble_retriever.invoke("_") @@ -77,7 +81,9 @@ def test_invoke() -> None: retriever2 = MockRetriever(docs=documents2) ensemble_retriever = EnsembleRetriever( - retrievers=[retriever1, retriever2], weights=[0.5, 0.5], id_key="id" + retrievers=[retriever1, retriever2], + weights=[0.5, 0.5], + id_key="id", ) ranked_documents = ensemble_retriever.invoke("_") diff --git a/libs/langchain/tests/unit_tests/retrievers/test_multi_vector.py b/libs/langchain/tests/unit_tests/retrievers/test_multi_vector.py index eb642022811..3c5d282e0ab 100644 --- a/libs/langchain/tests/unit_tests/retrievers/test_multi_vector.py +++ 
b/libs/langchain/tests/unit_tests/retrievers/test_multi_vector.py @@ -16,7 +16,10 @@ class InMemoryVectorstoreWithSearch(InMemoryVectorStore): return self._identity_fn def similarity_search( - self, query: str, k: int = 4, **kwargs: Any + self, + query: str, + k: int = 4, + **kwargs: Any, ) -> list[Document]: res = self.store.get(query) if res is None: @@ -24,7 +27,10 @@ class InMemoryVectorstoreWithSearch(InMemoryVectorStore): return [res] def similarity_search_with_score( - self, query: str, k: int = 4, **kwargs: Any + self, + query: str, + k: int = 4, + **kwargs: Any, ) -> list[tuple[Document, float]]: res = self.store.get(query) if res is None: @@ -35,7 +41,9 @@ class InMemoryVectorstoreWithSearch(InMemoryVectorStore): def test_multi_vector_retriever_initialization() -> None: vectorstore = InMemoryVectorstoreWithSearch() retriever = MultiVectorRetriever( # type: ignore[call-arg] - vectorstore=vectorstore, docstore=InMemoryStore(), doc_id="doc_id" + vectorstore=vectorstore, + docstore=InMemoryStore(), + doc_id="doc_id", ) documents = [Document(page_content="test document", metadata={"doc_id": "1"})] retriever.vectorstore.add_documents(documents, ids=["1"]) @@ -48,7 +56,9 @@ def test_multi_vector_retriever_initialization() -> None: async def test_multi_vector_retriever_initialization_async() -> None: vectorstore = InMemoryVectorstoreWithSearch() retriever = MultiVectorRetriever( # type: ignore[call-arg] - vectorstore=vectorstore, docstore=InMemoryStore(), doc_id="doc_id" + vectorstore=vectorstore, + docstore=InMemoryStore(), + doc_id="doc_id", ) documents = [Document(page_content="test document", metadata={"doc_id": "1"})] await retriever.vectorstore.aadd_documents(documents, ids=["1"]) diff --git a/libs/langchain/tests/unit_tests/retrievers/test_parent_document.py b/libs/langchain/tests/unit_tests/retrievers/test_parent_document.py index b9b8adc141c..a8626a4d248 100644 --- a/libs/langchain/tests/unit_tests/retrievers/test_parent_document.py +++ b/libs/langchain/tests/unit_tests/retrievers/test_parent_document.py @@ -11,7 +11,10 @@ from tests.unit_tests.indexes.test_indexing import InMemoryVectorStore class InMemoryVectorstoreWithSearch(InMemoryVectorStore): def similarity_search( - self, query: str, k: int = 4, **kwargs: Any + self, + query: str, + k: int = 4, + **kwargs: Any, ) -> list[Document]: res = self.store.get(query) if res is None: @@ -21,7 +24,8 @@ class InMemoryVectorstoreWithSearch(InMemoryVectorStore): def add_documents(self, documents: Sequence[Document], **kwargs: Any) -> list[str]: print(documents) # noqa: T201 return super().add_documents( - documents, ids=[f"{i}" for i in range(len(documents))] + documents, + ids=[f"{i}" for i in range(len(documents))], ) diff --git a/libs/langchain/tests/unit_tests/retrievers/test_time_weighted_retriever.py b/libs/langchain/tests/unit_tests/retrievers/test_time_weighted_retriever.py index ed8dbb32bee..f400b7a738e 100644 --- a/libs/langchain/tests/unit_tests/retrievers/test_time_weighted_retriever.py +++ b/libs/langchain/tests/unit_tests/retrievers/test_time_weighted_retriever.py @@ -40,7 +40,10 @@ class MockVectorStore(VectorStore): return list(texts) def similarity_search( - self, query: str, k: int = 4, **kwargs: Any + self, + query: str, + k: int = 4, + **kwargs: Any, ) -> list[Document]: return [] @@ -75,7 +78,8 @@ class MockVectorStore(VectorStore): def time_weighted_retriever() -> TimeWeightedVectorStoreRetriever: vectorstore = MockVectorStore() return TimeWeightedVectorStoreRetriever( - vectorstore=vectorstore, 
memory_stream=_get_example_memories() + vectorstore=vectorstore, + memory_stream=_get_example_memories(), ) @@ -98,7 +102,9 @@ def test_get_combined_score( expected_hours_passed = 2.5 current_time = datetime(2023, 4, 14, 14, 30) combined_score = time_weighted_retriever._get_combined_score( - document, vector_salience, current_time + document, + vector_salience, + current_time, ) expected_score = ( 1.0 - time_weighted_retriever.decay_rate @@ -114,7 +120,7 @@ def test_get_salient_docs( want = [(doc, 0.5) for doc in _get_example_memories()] assert isinstance(docs_and_scores, dict) assert len(docs_and_scores) == len(want) - for k, doc in docs_and_scores.items(): + for doc in docs_and_scores.values(): assert doc in want @@ -126,7 +132,7 @@ async def test_aget_salient_docs( want = [(doc, 0.5) for doc in _get_example_memories()] assert isinstance(docs_and_scores, dict) assert len(docs_and_scores) == len(want) - for k, doc in docs_and_scores.items(): + for doc in docs_and_scores.values(): assert doc in want diff --git a/libs/langchain/tests/unit_tests/runnables/test_hub.py b/libs/langchain/tests/unit_tests/runnables/test_hub.py index a3a8deabd32..43915bdeeff 100644 --- a/libs/langchain/tests/unit_tests/runnables/test_hub.py +++ b/libs/langchain/tests/unit_tests/runnables/test_hub.py @@ -13,7 +13,7 @@ def test_hub_runnable(mock_pull: Mock) -> None: [ ("system", "a"), ("user", "b"), - ] + ], ) basic: HubRunnable = HubRunnable("efriis/my-prompt") @@ -27,13 +27,13 @@ repo_dict = { [ ("system", "a"), ("user", "1"), - ] + ], ), "efriis/my-prompt-2": ChatPromptTemplate.from_messages( [ ("system", "a"), ("user", "2"), - ] + ], ), } @@ -77,6 +77,6 @@ def test_hub_runnable_configurable_fields(mock_pull: Mock) -> None: assert templated_1.messages[1].content == "1" templated_2 = obj_configurable.with_config( - configurable={"owner_repo_commit": "efriis/my-prompt-2"} + configurable={"owner_repo_commit": "efriis/my-prompt-2"}, ).invoke({}) assert templated_2.messages[1].content == "2" diff --git a/libs/langchain/tests/unit_tests/runnables/test_openai_functions.py b/libs/langchain/tests/unit_tests/runnables/test_openai_functions.py index 8906a0488c8..646594f0f26 100644 --- a/libs/langchain/tests/unit_tests/runnables/test_openai_functions.py +++ b/libs/langchain/tests/unit_tests/runnables/test_openai_functions.py @@ -31,19 +31,20 @@ class FakeChatOpenAI(BaseChatModel): "function_call": { "name": "accept", "arguments": '{\n "draft": "turtles"\n}', - } + }, }, - ) - ) - ] + ), + ), + ], ) def test_openai_functions_router( - snapshot: SnapshotAssertion, mocker: MockerFixture + snapshot: SnapshotAssertion, + mocker: MockerFixture, ) -> None: revise = mocker.Mock( - side_effect=lambda kw: f"Revised draft: no more {kw['notes']}!" 
+ side_effect=lambda kw: f"Revised draft: no more {kw['notes']}!", ) accept = mocker.Mock(side_effect=lambda kw: f"Accepted draft: {kw['draft']}!") diff --git a/libs/langchain/tests/unit_tests/smith/evaluation/test_runner_utils.py b/libs/langchain/tests/unit_tests/smith/evaluation/test_runner_utils.py index 4736ee6a6b3..fd32ca240b3 100644 --- a/libs/langchain/tests/unit_tests/smith/evaluation/test_runner_utils.py +++ b/libs/langchain/tests/unit_tests/smith/evaluation/test_runner_utils.py @@ -192,7 +192,9 @@ def test_run_llm_or_chain_with_input_mapper() -> None: ) assert result == {"output": "2", "the right input": "1"} bad_result = _run_llm_or_chain( - example, {"callbacks": [], "tags": []}, llm_or_chain_factory=lambda: mock_chain + example, + {"callbacks": [], "tags": []}, + llm_or_chain_factory=lambda: mock_chain, ) assert "Error" in bad_result @@ -339,13 +341,13 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None: expected = { str(example.id): { "output": { - "result": f"Result for example {uuid.UUID(str(example.id))}" + "result": f"Result for example {uuid.UUID(str(example.id))}", }, "input": {"input": (example.inputs or {}).get("input")}, "reference": { "output": example.outputs["output"] if example.outputs is not None - else None + else None, }, "feedback": [], # No run since we mock the call to the llm above diff --git a/libs/langchain/tests/unit_tests/storage/test_filesystem.py b/libs/langchain/tests/unit_tests/storage/test_filesystem.py index 94d211513cd..ab3bd2a3ac2 100644 --- a/libs/langchain/tests/unit_tests/storage/test_filesystem.py +++ b/libs/langchain/tests/unit_tests/storage/test_filesystem.py @@ -30,7 +30,8 @@ def test_mset_and_mget(file_store: LocalFileStore) -> None: @pytest.mark.parametrize( - "chmod_dir_s, chmod_file_s", [("777", "666"), ("770", "660"), ("700", "600")] + "chmod_dir_s, chmod_file_s", + [("777", "666"), ("770", "660"), ("700", "600")], ) def test_mset_chmod(chmod_dir_s: str, chmod_file_s: str) -> None: chmod_dir = int(chmod_dir_s, base=8) @@ -42,7 +43,9 @@ def test_mset_chmod(chmod_dir_s: str, chmod_file_s: str) -> None: # as the root path temp_dir = os.path.join(temp_dir, "store_dir") file_store = LocalFileStore( - temp_dir, chmod_dir=chmod_dir, chmod_file=chmod_file + temp_dir, + chmod_dir=chmod_dir, + chmod_file=chmod_file, ) # Set values for keys diff --git a/libs/langchain/tests/unit_tests/stubs.py b/libs/langchain/tests/unit_tests/stubs.py index b752364e3af..44450d946bd 100644 --- a/libs/langchain/tests/unit_tests/stubs.py +++ b/libs/langchain/tests/unit_tests/stubs.py @@ -5,7 +5,9 @@ from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage class AnyStr(str): - def __eq__(self, other: Any) -> bool: + __slots__ = () + + def __eq__(self, other: object) -> bool: return isinstance(other, str) diff --git a/libs/langchain/tests/unit_tests/test_dependencies.py b/libs/langchain/tests/unit_tests/test_dependencies.py index 94fab41b522..672910f04d2 100644 --- a/libs/langchain/tests/unit_tests/test_dependencies.py +++ b/libs/langchain/tests/unit_tests/test_dependencies.py @@ -13,7 +13,7 @@ HERE = Path(__file__).parent PYPROJECT_TOML = HERE / "../../pyproject.toml" -@pytest.fixture() +@pytest.fixture def uv_conf() -> dict[str, Any]: """Load the pyproject.toml file.""" with open(PYPROJECT_TOML) as f: @@ -40,7 +40,7 @@ def test_required_dependencies(uv_conf: Mapping[str, Any]) -> None: "langsmith", "pydantic", "requests", - ] + ], ) @@ -83,5 +83,5 @@ def test_test_group_dependencies(uv_conf: Mapping[str, Any]) -> None: # 
TODO: temporary hack since cffi 1.17.1 doesn't work with py 3.9. "cffi", "numpy", - ] + ], ) diff --git a/libs/langchain/tests/unit_tests/test_schema.py b/libs/langchain/tests/unit_tests/test_schema.py index aa0cc26dd9a..ff8f00d5615 100644 --- a/libs/langchain/tests/unit_tests/test_schema.py +++ b/libs/langchain/tests/unit_tests/test_schema.py @@ -77,7 +77,7 @@ def test_serialization_of_wellknown_objects() -> None: ChatPromptValueConcrete(messages=[AIMessage(content="foo")]), ChatPromptValueConcrete(messages=[HumanMessage(content="human")]), ChatPromptValueConcrete( - messages=[ToolMessage(content="foo", tool_call_id="bar")] + messages=[ToolMessage(content="foo", tool_call_id="bar")], ), ChatPromptValueConcrete(messages=[SystemMessage(content="foo")]), Document(page_content="hello"), diff --git a/libs/langchain/tests/unit_tests/utils/test_iter.py b/libs/langchain/tests/unit_tests/utils/test_iter.py index 99d1ab70760..85b577f10d1 100644 --- a/libs/langchain/tests/unit_tests/utils/test_iter.py +++ b/libs/langchain/tests/unit_tests/utils/test_iter.py @@ -12,7 +12,9 @@ from langchain_core.utils.iter import batch_iterate ], ) def test_batch_iterate( - input_size: int, input_iterable: list[str], expected_output: list[str] + input_size: int, + input_iterable: list[str], + expected_output: list[str], ) -> None: """Test batching function.""" assert list(batch_iterate(input_size, input_iterable)) == expected_output diff --git a/libs/langchain/tests/unit_tests/utils/test_openai_functions.py b/libs/langchain/tests/unit_tests/utils/test_openai_functions.py index ca66e1c64ae..ffe68e64ffa 100644 --- a/libs/langchain/tests/unit_tests/utils/test_openai_functions.py +++ b/libs/langchain/tests/unit_tests/utils/test_openai_functions.py @@ -63,7 +63,7 @@ def test_convert_pydantic_to_openai_function_nested() -> None: }, }, "required": ["key"], - } + }, }, "required": ["data"], }, diff --git a/libs/partners/anthropic/pyproject.toml b/libs/partners/anthropic/pyproject.toml index 793d3c0d382..42566e0bfd0 100644 --- a/libs/partners/anthropic/pyproject.toml +++ b/libs/partners/anthropic/pyproject.toml @@ -38,7 +38,7 @@ test = [ "langchain-tests", ] codespell = ["codespell<3.0.0,>=2.2.0"] -lint = ["ruff<1.0,>=0.5"] +lint = ["ruff<0.13,>=0.12.2"] dev = ["langchain-core"] test_integration = ["requests<3.0.0,>=2.32.3", "langchain-core"] typing = [ diff --git a/libs/partners/chroma/pyproject.toml b/libs/partners/chroma/pyproject.toml index 4537b4f1a6f..73825651b47 100644 --- a/libs/partners/chroma/pyproject.toml +++ b/libs/partners/chroma/pyproject.toml @@ -39,7 +39,7 @@ test = [ codespell = ["codespell<3.0.0,>=2.2.0"] test_integration = [] lint = [ - "ruff<1.0,>=0.5", + "ruff<0.13,>=0.12.2", "onnxruntime<1.20; python_version < \"3.10\"", "onnxruntime; python_version >= \"3.10\"", ] diff --git a/libs/partners/deepseek/pyproject.toml b/libs/partners/deepseek/pyproject.toml index d481c01db07..6a0b7cc1968 100644 --- a/libs/partners/deepseek/pyproject.toml +++ b/libs/partners/deepseek/pyproject.toml @@ -32,7 +32,7 @@ test = [ ] codespell = ["codespell<3.0.0,>=2.2.6"] test_integration = [] -lint = ["ruff<1.0,>=0.5"] +lint = ["ruff<0.13,>=0.12.2"] dev = [] typing = ["mypy<2.0,>=1.10"] diff --git a/libs/partners/exa/pyproject.toml b/libs/partners/exa/pyproject.toml index 42754ccd584..9dda05d792f 100644 --- a/libs/partners/exa/pyproject.toml +++ b/libs/partners/exa/pyproject.toml @@ -28,7 +28,7 @@ test = [ "langchain-core", ] codespell = ["codespell<3.0.0,>=2.2.0"] -lint = ["ruff<1.0,>=0.5"] +lint = ["ruff<0.13,>=0.12.2"] dev = 
["langchain-core"] test_integration = [] typing = [ diff --git a/libs/partners/fireworks/pyproject.toml b/libs/partners/fireworks/pyproject.toml index b0b67da4cd9..8cb7e4aae98 100644 --- a/libs/partners/fireworks/pyproject.toml +++ b/libs/partners/fireworks/pyproject.toml @@ -37,7 +37,7 @@ test = [ ] codespell = ["codespell<3.0.0,>=2.2.0"] test_integration = [] -lint = ["ruff<1.0,>=0.5"] +lint = ["ruff<0.13,>=0.12.2"] dev = ["langchain-core"] typing = ["mypy<2.0,>=1.10", "types-requests<3,>=2", "langchain-core"] diff --git a/libs/partners/groq/pyproject.toml b/libs/partners/groq/pyproject.toml index f7a4003e391..9c999104497 100644 --- a/libs/partners/groq/pyproject.toml +++ b/libs/partners/groq/pyproject.toml @@ -28,7 +28,7 @@ test = [ "langchain-tests", ] codespell = ["codespell<3.0.0,>=2.2.0"] -lint = ["ruff<1.0,>=0.5"] +lint = ["ruff<0.13,>=0.12.2"] dev = ["langchain-core"] test_integration = ["langchain-core"] typing = ["mypy<2.0,>=1.10", "langchain-core"] diff --git a/libs/partners/huggingface/pyproject.toml b/libs/partners/huggingface/pyproject.toml index abafb7d2932..4df97b0c303 100644 --- a/libs/partners/huggingface/pyproject.toml +++ b/libs/partners/huggingface/pyproject.toml @@ -42,7 +42,7 @@ test = [ "langchain-community", ] codespell = ["codespell<3.0.0,>=2.2.0"] -lint = ["ruff<1.0,>=0.5"] +lint = ["ruff<0.13,>=0.12.2"] dev = ["ipykernel<7.0.0,>=6.29.2", "langchain-core"] test_integration = [] typing = ["mypy<2.0,>=1.10", "langchain-core"] diff --git a/libs/partners/mistralai/pyproject.toml b/libs/partners/mistralai/pyproject.toml index b04acdc21ff..1b521b7742b 100644 --- a/libs/partners/mistralai/pyproject.toml +++ b/libs/partners/mistralai/pyproject.toml @@ -33,7 +33,7 @@ test = [ ] test_integration = [] codespell = ["codespell<3.0.0,>=2.2.0"] -lint = ["ruff<1.0,>=0.5"] +lint = ["ruff<0.13,>=0.12.2"] dev = ["langchain-core"] typing = ["mypy<2.0,>=1.10", "langchain-core"] diff --git a/libs/partners/nomic/pyproject.toml b/libs/partners/nomic/pyproject.toml index c079bf02442..c52ce04ea90 100644 --- a/libs/partners/nomic/pyproject.toml +++ b/libs/partners/nomic/pyproject.toml @@ -33,7 +33,7 @@ test = [ ] codespell = ["codespell<3.0.0,>=2.2.0"] test_integration = [] -lint = ["ruff<1.0.0,>=0.1.5"] +lint = ["ruff<0.13,>=0.12.2"] typing = ["mypy<1.0,>=0.991", "langchain-core"] dev = ["langchain-core"] diff --git a/libs/partners/ollama/pyproject.toml b/libs/partners/ollama/pyproject.toml index b5af560b547..7be57841516 100644 --- a/libs/partners/ollama/pyproject.toml +++ b/libs/partners/ollama/pyproject.toml @@ -32,7 +32,7 @@ test = [ ] codespell = ["codespell<3.0.0,>=2.2.6"] test_integration = [] -lint = ["ruff<1.0.0,>=0.1.8"] +lint = ["ruff<0.13,>=0.12.2"] dev = ["langchain-core"] typing = ["mypy<2.0.0,>=1.7.1", "langchain-core"] diff --git a/libs/partners/openai/pyproject.toml b/libs/partners/openai/pyproject.toml index f088188daf4..7b1192a95ce 100644 --- a/libs/partners/openai/pyproject.toml +++ b/libs/partners/openai/pyproject.toml @@ -40,7 +40,7 @@ test = [ "langchain-tests", ] codespell = ["codespell<3.0.0,>=2.2.0"] -lint = ["ruff<1.0,>=0.5"] +lint = ["ruff<0.13,>=0.12.2"] dev = ["langchain-core"] test_integration = [ "httpx<1.0.0,>=0.27.0", diff --git a/libs/partners/perplexity/pyproject.toml b/libs/partners/perplexity/pyproject.toml index 8ca6bf2b2ef..822a725aa19 100644 --- a/libs/partners/perplexity/pyproject.toml +++ b/libs/partners/perplexity/pyproject.toml @@ -36,7 +36,7 @@ test = [ "langchain-tests", ] codespell = ["codespell<3.0.0,>=2.2.0"] -lint = ["ruff<1.0,>=0.5"] 
+lint = ["ruff<0.13,>=0.12.2"] dev = ["langchain-core"] test_integration = [ "httpx<1.0.0,>=0.27.0", diff --git a/libs/partners/prompty/pyproject.toml b/libs/partners/prompty/pyproject.toml index c13801588ee..c04121a5a79 100644 --- a/libs/partners/prompty/pyproject.toml +++ b/libs/partners/prompty/pyproject.toml @@ -31,7 +31,7 @@ test = [ ] codespell = ["codespell<3.0.0,>=2.2.0"] test_integration = [] -lint = ["ruff<1.0.0,>=0.1.5"] +lint = ["ruff<0.13,>=0.12.2"] dev = ["types-pyyaml<7.0.0.0,>=6.0.12.20240311", "langchain-core"] typing = [ "mypy<1.0,>=0.991", diff --git a/libs/partners/qdrant/pyproject.toml b/libs/partners/qdrant/pyproject.toml index 14dca0ccd2b..8fbfe03677a 100644 --- a/libs/partners/qdrant/pyproject.toml +++ b/libs/partners/qdrant/pyproject.toml @@ -40,7 +40,7 @@ test = [ ] codespell = ["codespell<3.0.0,>=2.2.0"] test_integration = [] -lint = ["ruff<1.0,>=0.5"] +lint = ["ruff<0.13,>=0.12.2"] dev = ["langchain-core"] typing = ["mypy<2.0,>=1.10", "simsimd<7.0.0,>=6.0.0", "langchain-core"] diff --git a/libs/partners/xai/pyproject.toml b/libs/partners/xai/pyproject.toml index 9d5ade6ae6e..aa5f2f25acf 100644 --- a/libs/partners/xai/pyproject.toml +++ b/libs/partners/xai/pyproject.toml @@ -38,7 +38,7 @@ test = [ ] codespell = ["codespell<3.0.0,>=2.2.0"] test_integration = [] -lint = ["ruff<1.0,>=0.5"] +lint = ["ruff<0.13,>=0.12.2"] typing = ["mypy<2.0,>=1.10", "types-requests<3,>=2", "langchain-core"] dev = ["langchain-core"] diff --git a/pyproject.toml b/pyproject.toml index aa00c3090ee..7e55ea1e219 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ readme = "README.md" repository = "https://www.github.com/langchain-ai/langchain" [dependency-groups] -lint = ["ruff<1.0.0,>=0.5.0"] +lint = ["ruff<0.13,>=0.12.2"] dev = [ "langchain-core", "langchain-text-splitters", @@ -83,7 +83,7 @@ ignore-words-list = 'momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogy extend-include = ["*.ipynb"] [tool.ruff.lint] -select = ["D"] +select = ["D"] # global select for pydocstyle pydocstyle = { convention = "google" } [tool.ruff.lint.per-file-ignores]