mirror of
				https://github.com/hwchase17/langchain.git
				synced 2025-10-25 21:03:11 +00:00 
			
		
		
		
	Moved the following modules to new package langchain-community in a backwards compatible fashion: ``` mv langchain/langchain/adapters community/langchain_community mv langchain/langchain/callbacks community/langchain_community/callbacks mv langchain/langchain/chat_loaders community/langchain_community mv langchain/langchain/chat_models community/langchain_community mv langchain/langchain/document_loaders community/langchain_community mv langchain/langchain/docstore community/langchain_community mv langchain/langchain/document_transformers community/langchain_community mv langchain/langchain/embeddings community/langchain_community mv langchain/langchain/graphs community/langchain_community mv langchain/langchain/llms community/langchain_community mv langchain/langchain/memory/chat_message_histories community/langchain_community mv langchain/langchain/retrievers community/langchain_community mv langchain/langchain/storage community/langchain_community mv langchain/langchain/tools community/langchain_community mv langchain/langchain/utilities community/langchain_community mv langchain/langchain/vectorstores community/langchain_community mv langchain/langchain/agents/agent_toolkits community/langchain_community mv langchain/langchain/cache.py community/langchain_community mv langchain/langchain/adapters community/langchain_community mv langchain/langchain/callbacks community/langchain_community/callbacks mv langchain/langchain/chat_loaders community/langchain_community mv langchain/langchain/chat_models community/langchain_community mv langchain/langchain/document_loaders community/langchain_community mv langchain/langchain/docstore community/langchain_community mv langchain/langchain/document_transformers community/langchain_community mv langchain/langchain/embeddings community/langchain_community mv langchain/langchain/graphs community/langchain_community mv langchain/langchain/llms community/langchain_community mv langchain/langchain/memory/chat_message_histories 
community/langchain_community mv langchain/langchain/retrievers community/langchain_community mv langchain/langchain/storage community/langchain_community mv langchain/langchain/tools community/langchain_community mv langchain/langchain/utilities community/langchain_community mv langchain/langchain/vectorstores community/langchain_community mv langchain/langchain/agents/agent_toolkits community/langchain_community mv langchain/langchain/cache.py community/langchain_community ``` Moved the following to core ``` mv langchain/langchain/utils/json_schema.py core/langchain_core/utils mv langchain/langchain/utils/html.py core/langchain_core/utils mv langchain/langchain/utils/strings.py core/langchain_core/utils cat langchain/langchain/utils/env.py >> core/langchain_core/utils/env.py rm langchain/langchain/utils/env.py ``` See .scripts/community_split/script_integrations.sh for all changes
		
			
				
	
	
		
			132 lines
		
	
	
		
			4.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			132 lines
		
	
	
		
			4.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # flake8: noqa
 | |
| """Tools for interacting with Spark SQL."""
 | |
| from typing import Any, Dict, Optional
 | |
| 
 | |
| from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
 | |
| 
 | |
| from langchain_core.language_models import BaseLanguageModel
 | |
| from langchain_core.callbacks import (
 | |
|     AsyncCallbackManagerForToolRun,
 | |
|     CallbackManagerForToolRun,
 | |
| )
 | |
| from langchain_core.prompts import PromptTemplate
 | |
| from langchain_community.utilities.spark_sql import SparkSQL
 | |
| from langchain_core.tools import BaseTool
 | |
| from langchain_community.tools.spark_sql.prompt import QUERY_CHECKER
 | |
| 
 | |
| 
 | |
class BaseSparkSQLTool(BaseModel):
    """Base tool for interacting with Spark SQL.

    Mixin holding the ``SparkSQL`` handle that the concrete Spark SQL tools
    in this module share. Subclasses combine it with ``BaseTool``.
    """

    # The database handle; ``exclude=True`` keeps it out of the model's
    # serialized output (``dict()`` / ``json()``).
    db: SparkSQL = Field(exclude=True)

    class Config(BaseTool.Config):
        """Inherit BaseTool's pydantic configuration without changes."""
 | |
| 
 | |
| 
 | |
class QuerySparkSQLTool(BaseSparkSQLTool, BaseTool):
    """Tool for querying a Spark SQL."""

    name: str = "query_sql_db"
    description: str = """
    Input to this tool is a detailed and correct SQL query, output is a result from the Spark SQL.
    If the query is not correct, an error message will be returned.
    If an error is returned, rewrite the query, check the query, and try again.
    """

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Execute ``query`` on the Spark SQL database.

        Args:
            query: The SQL statement to execute.
            run_manager: Callback manager supplied by ``BaseTool``; not used.

        Returns:
            The text produced by ``SparkSQL.run_no_throw``: the query result,
            or (presumably, per the tool description) an error message
            instead of a raised exception.
        """
        outcome = self.db.run_no_throw(query)
        return outcome
 | |
| 
 | |
| 
 | |
class InfoSparkSQLTool(BaseSparkSQLTool, BaseTool):
    """Tool for getting metadata about a Spark SQL."""

    name: str = "schema_sql_db"
    description: str = """
    Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables.
    Be sure that the tables actually exist by calling list_tables_sql_db first!

    Example Input: "table1, table2, table3"
    """

    def _run(
        self,
        table_names: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Get the schema for tables in a comma-separated list.

        Args:
            table_names: Comma-separated table names, e.g.
                ``"table1, table2"``. Whitespace around each name is ignored,
                so ``"table1,table2"`` and ``"table1 ,  table2"`` are
                accepted as well.
            run_manager: Callback manager supplied by ``BaseTool``; not used.

        Returns:
            The text produced by ``SparkSQL.get_table_info_no_throw`` for the
            requested tables.
        """
        # Split on the bare comma and strip each piece. Splitting on ", "
        # only works when the input has exactly one space after each comma;
        # LLM-generated input may omit the space or add extra whitespace or
        # a trailing newline, which would otherwise be passed through as part
        # of a table name.
        tables = [table.strip() for table in table_names.split(",")]
        return self.db.get_table_info_no_throw(tables)
 | |
| 
 | |
| 
 | |
class ListSparkSQLTool(BaseSparkSQLTool, BaseTool):
    """Tool for getting tables names."""

    name: str = "list_tables_sql_db"
    description: str = "Input is an empty string, output is a comma separated list of tables in the Spark SQL."

    def _run(
        self,
        tool_input: str = "",
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """List the usable table names as one comma-separated string.

        Args:
            tool_input: Ignored. It exists so the tool accepts the empty
                string input described in ``description``.
            run_manager: Callback manager supplied by ``BaseTool``; not used.

        Returns:
            The names from ``SparkSQL.get_usable_table_names`` joined with
            ``", "``.
        """
        return ", ".join(self.db.get_usable_table_names())
 | |
| 
 | |
| 
 | |
class QueryCheckerTool(BaseSparkSQLTool, BaseTool):
    """Use an LLM to check if a query is correct.
    Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""

    # Prompt template for the checker chain; must reference only ``{query}``.
    template: str = QUERY_CHECKER
    llm: BaseLanguageModel
    # Built by ``initialize_llm_chain`` when not supplied by the caller.
    llm_chain: Any = Field(init=False)
    name: str = "query_checker_sql_db"
    description: str = """
    Use this tool to double check if your query is correct before executing it.
    Always use this tool before executing a query with query_sql_db!
    """

    @root_validator(pre=True)
    def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Create the default ``llm_chain`` if none was given, then validate it.

        When ``llm_chain`` is absent, an ``LLMChain`` is built from ``llm``
        and a ``PromptTemplate`` over the ``template`` field (falling back to
        ``QUERY_CHECKER`` when ``template`` is not supplied).

        Raises:
            ValueError: If the chain's prompt input variables are not exactly
                ``["query"]``.
        """
        if "llm_chain" not in values:
            # Local import: LLMChain lives in the ``langchain`` package rather
            # than in ``langchain_core``/``langchain_community``.
            from langchain.chains.llm import LLMChain

            # Honour a caller-supplied ``template`` so the declared field
            # actually controls the checker prompt; default is QUERY_CHECKER.
            template = values.get("template", QUERY_CHECKER)
            values["llm_chain"] = LLMChain(
                llm=values.get("llm"),
                prompt=PromptTemplate(template=template, input_variables=["query"]),
            )

        if values["llm_chain"].prompt.input_variables != ["query"]:
            raise ValueError(
                "LLM chain for QueryCheckerTool need to use ['query'] as input_variables "
                "for the embedded prompt"
            )

        return values

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the LLM to check the query."""
        return self.llm_chain.predict(
            query=query, callbacks=run_manager.get_child() if run_manager else None
        )

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Asynchronously use the LLM to check the query."""
        return await self.llm_chain.apredict(
            query=query, callbacks=run_manager.get_child() if run_manager else None
        )
 |