Mirror of https://github.com/hwchase17/langchain.git (synced 2025-10-25 21:03:11 +00:00)
			
		
		
		
Moved the following modules to the new package `langchain-community` in a backwards-compatible fashion:

```
mv langchain/langchain/adapters community/langchain_community
mv langchain/langchain/callbacks community/langchain_community/callbacks
mv langchain/langchain/chat_loaders community/langchain_community
mv langchain/langchain/chat_models community/langchain_community
mv langchain/langchain/document_loaders community/langchain_community
mv langchain/langchain/docstore community/langchain_community
mv langchain/langchain/document_transformers community/langchain_community
mv langchain/langchain/embeddings community/langchain_community
mv langchain/langchain/graphs community/langchain_community
mv langchain/langchain/llms community/langchain_community
mv langchain/langchain/memory/chat_message_histories community/langchain_community
mv langchain/langchain/retrievers community/langchain_community
mv langchain/langchain/storage community/langchain_community
mv langchain/langchain/tools community/langchain_community
mv langchain/langchain/utilities community/langchain_community
mv langchain/langchain/vectorstores community/langchain_community
mv langchain/langchain/agents/agent_toolkits community/langchain_community
mv langchain/langchain/cache.py community/langchain_community
mv langchain/langchain/adapters community/langchain_community
mv langchain/langchain/callbacks community/langchain_community/callbacks
mv langchain/langchain/chat_loaders community/langchain_community
mv langchain/langchain/chat_models community/langchain_community
mv langchain/langchain/document_loaders community/langchain_community
mv langchain/langchain/docstore community/langchain_community
mv langchain/langchain/document_transformers community/langchain_community
mv langchain/langchain/embeddings community/langchain_community
mv langchain/langchain/graphs community/langchain_community
mv langchain/langchain/llms community/langchain_community
mv langchain/langchain/memory/chat_message_histories community/langchain_community
mv langchain/langchain/retrievers community/langchain_community
mv langchain/langchain/storage community/langchain_community
mv langchain/langchain/tools community/langchain_community
mv langchain/langchain/utilities community/langchain_community
mv langchain/langchain/vectorstores community/langchain_community
mv langchain/langchain/agents/agent_toolkits community/langchain_community
mv langchain/langchain/cache.py community/langchain_community
```

Moved the following to core:

```
mv langchain/langchain/utils/json_schema.py core/langchain_core/utils
mv langchain/langchain/utils/html.py core/langchain_core/utils
mv langchain/langchain/utils/strings.py core/langchain_core/utils
cat langchain/langchain/utils/env.py >> core/langchain_core/utils/env.py
rm langchain/langchain/utils/env.py
```

See `.scripts/community_split/script_integrations.sh` for all changes.
		
			
				
	
	
		
			187 lines
		
	
	
		
			7.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			187 lines
		
	
	
		
			7.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from __future__ import annotations
 | |
| 
 | |
| from typing import TYPE_CHECKING, Any, Iterable, List, Optional
 | |
| 
 | |
| if TYPE_CHECKING:
 | |
|     from pyspark.sql import DataFrame, Row, SparkSession
 | |
| 
 | |
| 
 | |
class SparkSQL:
    """SparkSQL is a utility class for interacting with Spark SQL.

    It wraps a ``SparkSession`` and offers helpers to describe the tables of
    the current catalog/schema (``get_table_info``) and to execute SQL
    (``run``). Each helper has a ``*_no_throw`` variant that returns the
    error message as a string instead of raising.
    """

    def __init__(
        self,
        spark_session: Optional[SparkSession] = None,
        catalog: Optional[str] = None,
        schema: Optional[str] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        sample_rows_in_table_info: int = 3,
    ):
        """Initialize a SparkSQL object.

        Args:
            spark_session: A SparkSession object.
              If not provided, one is obtained via
              ``SparkSession.builder.getOrCreate()``.
            catalog: The catalog to use.
              If not provided, the default catalog will be used.
            schema: The schema to use.
              If not provided, the default schema will be used.
            ignore_tables: A list of tables to ignore.
              If not provided, all tables will be used.
            include_tables: A list of tables to include.
              If not provided, all tables will be used. When non-empty it
              takes precedence over ``ignore_tables``.
            sample_rows_in_table_info: The number of rows to include in the table info.
              Defaults to 3. ``0`` disables the sample rows.

        Raises:
            TypeError: If ``sample_rows_in_table_info`` is not an integer.
            ImportError: If pyspark is not installed.
            ValueError: If a name in ``include_tables`` or ``ignore_tables`` is
              not among the tables reported by ``SHOW TABLES``.
        """
        # Validate cheap arguments first so misuse fails fast, before a Spark
        # session is created or any catalog query is run.
        if not isinstance(sample_rows_in_table_info, int):
            raise TypeError("sample_rows_in_table_info must be an integer")

        try:
            from pyspark.sql import SparkSession
        except ImportError as e:
            raise ImportError(
                "pyspark is not installed. Please install it with `pip install pyspark`"
            ) from e

        self._spark = (
            spark_session
            if spark_session is not None
            else SparkSession.builder.getOrCreate()
        )
        if catalog is not None:
            self._spark.catalog.setCurrentCatalog(catalog)
        if schema is not None:
            self._spark.catalog.setCurrentDatabase(schema)

        self._all_tables = set(self._get_all_table_names())
        self._include_tables = set(include_tables) if include_tables else set()
        if self._include_tables:
            missing_tables = self._include_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"include_tables {missing_tables} not found in database"
                )
        self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        if self._ignore_tables:
            missing_tables = self._ignore_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"ignore_tables {missing_tables} not found in database"
                )
        usable_tables = self.get_usable_table_names()
        self._usable_tables = set(usable_tables) if usable_tables else self._all_tables

        self._sample_rows_in_table_info = sample_rows_in_table_info

    @classmethod
    def from_uri(
        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> SparkSQL:
        """Create a SparkSQL backed by a remote Spark Connect session.

        For example: SparkSQL.from_uri("sc://localhost:15002")

        Args:
            database_uri: URI passed to ``SparkSession.builder.remote``.
            engine_args: Currently ignored by this method.
            **kwargs: Forwarded to ``SparkSQL.__init__``.

        Raises:
            ValueError: If pyspark is not installed.
        """
        try:
            from pyspark.sql import SparkSession
        except ImportError as e:
            raise ValueError(
                "pyspark is not installed. Please install it with `pip install pyspark`"
            ) from e

        spark = SparkSession.builder.remote(database_uri).getOrCreate()
        return cls(spark, **kwargs)

    def get_usable_table_names(self) -> Iterable[str]:
        """Get names of tables available, sorted alphabetically.

        If ``include_tables`` was given, exactly those tables are returned;
        otherwise all tables except those in ``ignore_tables``.
        """
        # Sorting the result can help an LLM understand it, and it makes the
        # order deterministic in both branches (sets have no stable order).
        if self._include_tables:
            return sorted(self._include_tables)
        return sorted(self._all_tables - self._ignore_tables)

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Backtick-quote a single table name for interpolation into Spark SQL.

        Embedded backticks are escaped by doubling them, so names containing
        characters such as ``-`` or spaces are parsed as one identifier.
        """
        return "`" + name.replace("`", "``") + "`"

    def _get_all_table_names(self) -> Iterable[str]:
        """Return the ``tableName`` column of ``SHOW TABLES``."""
        rows = self._spark.sql("SHOW TABLES").select("tableName").collect()
        return [row.tableName for row in rows]

    def _get_create_table_stmt(self, table: str) -> str:
        """Return the ``SHOW CREATE TABLE`` statement for ``table``, trimmed."""
        statement = (
            self._spark.sql(f"SHOW CREATE TABLE {self._quote_identifier(table)}")
            .collect()[0]
            .createtab_stmt
        )
        # Ignore the data source provider and options to reduce the number of tokens.
        using_clause_index = statement.find("USING")
        if using_clause_index == -1:
            # No USING clause: keep the whole statement. Slicing with the -1
            # returned by find() would silently drop its last character.
            return statement.rstrip() + ";"
        return statement[:using_clause_index] + ";"

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        """Describe tables as trimmed CREATE statements, optionally with samples.

        Args:
            table_names: Tables to describe. Defaults to all usable tables.

        Returns:
            One entry per table, separated by blank lines. When
            ``sample_rows_in_table_info`` is truthy, each entry is followed by a
            ``/* ... */`` comment holding sample rows.

        Raises:
            ValueError: If any of ``table_names`` is not a usable table.
        """
        all_table_names = self.get_usable_table_names()
        if table_names is not None:
            missing_tables = set(table_names).difference(all_table_names)
            if missing_tables:
                raise ValueError(f"table_names {missing_tables} not found in database")
            all_table_names = table_names
        tables = []
        for table_name in all_table_names:
            table_info = self._get_create_table_stmt(table_name)
            if self._sample_rows_in_table_info:
                table_info += "\n\n/*"
                table_info += f"\n{self._get_sample_spark_rows(table_name)}\n"
                table_info += "*/"
            tables.append(table_info)
        return "\n\n".join(tables)

    def _get_sample_spark_rows(self, table: str) -> str:
        """Return a tab-separated header plus up to N sample rows of ``table``."""
        query = (
            f"SELECT * FROM {self._quote_identifier(table)} "
            f"LIMIT {self._sample_rows_in_table_info}"
        )
        df = self._spark.sql(query)
        columns_str = "\t".join(field.name for field in df.schema.fields)
        try:
            sample_rows = self._get_dataframe_results(df)
            # save the sample rows in string format
            sample_rows_str = "\n".join("\t".join(row) for row in sample_rows)
        except Exception:
            # Deliberate best effort: if fetching the sample fails, still
            # return the column header so the table description is usable.
            sample_rows_str = ""

        return (
            f"{self._sample_rows_in_table_info} rows from {table} table:\n"
            f"{columns_str}\n"
            f"{sample_rows_str}"
        )

    def _convert_row_as_tuple(self, row: Row) -> tuple:
        """Return every column value of ``row`` converted to ``str``.

        ``Row`` is a tuple, so iterating it yields values in column order.
        Going through ``row.asDict()`` would collapse columns sharing a name
        (e.g. ``SELECT a.id, b.id ...``) and silently drop data.
        """
        return tuple(str(value) for value in row)

    def _get_dataframe_results(self, df: DataFrame) -> list:
        """Collect ``df`` and return its rows as tuples of strings."""
        return [self._convert_row_as_tuple(row) for row in df.collect()]

    def run(self, command: str, fetch: str = "all") -> str:
        """Execute ``command`` and return its rows as a string.

        Args:
            command: The SQL statement to execute.
            fetch: ``"one"`` limits the result to one row; any other value
              returns all rows.

        Returns:
            ``str`` of a list of tuples whose values are converted to ``str``.
        """
        df = self._spark.sql(command)
        if fetch == "one":
            df = df.limit(1)
        return str(self._get_dataframe_results(df))

    def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        If `sample_rows_in_table_info`, the specified number of sample rows will be
        appended to each table description. This can increase performance as
        demonstrated in the paper.

        A ``ValueError`` (e.g. an unknown table name) is returned as
        ``"Error: <message>"`` instead of being raised.
        """
        try:
            return self.get_table_info(table_names)
        except ValueError as e:
            # Format the error message for the caller instead of raising.
            return f"Error: {e}"

    def run_no_throw(self, command: str, fetch: str = "all") -> str:
        """Execute a SQL command and return a string representing the results.

        If the statement returns rows, a string of the results is returned.
        If the statement returns no rows, an empty list rendered as a string
        (``"[]"``) is returned.

        If the statement throws an error, the error message is returned.
        """
        try:
            return self.run(command, fetch)
        except Exception as e:
            # Format the error message for the caller instead of raising.
            return f"Error: {e}"
 |