Mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-04 12:18:24 +00:00
[nit] Simplify Spark Creation Validation Check A Little Bit (#4761)
- Simplify the validation check a little bit.
- Re-tested in a Jupyter notebook.

Reviewer: @hwchase17
parent e027a38f33
commit db6f7ed0ba
@@ -1,6 +1,7 @@
 {
  "cells": [
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -17,7 +18,6 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from langchain.agents import create_spark_dataframe_agent\n",
     "import os\n",
     "\n",
     "os.environ[\"OPENAI_API_KEY\"] = \"...input your openai api key here...\""
@@ -25,9 +25,20 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "23/05/15 20:33:10 WARN Utils: Your hostname, Mikes-Mac-mini.local resolves to a loopback address: 127.0.0.1; using 192.168.68.115 instead (on interface en1)\n",
+      "23/05/15 20:33:10 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
+      "Setting default log level to \"WARN\".\n",
+      "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
+      "23/05/15 20:33:10 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
+     ]
+    },
     {
      "name": "stdout",
      "output_type": "stream",
@@ -64,6 +75,7 @@
    "source": [
     "from langchain.llms import OpenAI\n",
     "from pyspark.sql import SparkSession\n",
+    "from langchain.agents import create_spark_dataframe_agent\n",
     "\n",
     "spark = SparkSession.builder.getOrCreate()\n",
     "csv_file_path = \"titanic.csv\"\n",
@@ -92,7 +104,7 @@
     "\n",
     "\n",
     "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
-    "\u001b[32;1m\u001b[1;3mThought: I need to find out the size of the dataframe\n",
+    "\u001b[32;1m\u001b[1;3mThought: I need to find out how many rows are in the dataframe\n",
     "Action: python_repl_ast\n",
     "Action Input: df.count()\u001b[0m\n",
     "Observation: \u001b[36;1m\u001b[1;3m891\u001b[0m\n",
@@ -205,7 +217,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -213,6 +225,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
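Pieced together, the notebook hunks above move the create_spark_dataframe_agent import into the cell that builds the agent. A minimal sketch of that cell's flow follows; the read options and agent arguments (header/inferSchema, llm, verbose) are assumptions, since this diff view does not show them:

from langchain.llms import OpenAI
from pyspark.sql import SparkSession
from langchain.agents import create_spark_dataframe_agent

spark = SparkSession.builder.getOrCreate()
csv_file_path = "titanic.csv"
# Assumed read options; the diff only shows the csv_file_path line.
df = spark.read.csv(csv_file_path, header=True, inferSchema=True)

# Assumed constructor arguments; the llm/verbose settings are not in this hunk.
agent = create_spark_dataframe_agent(llm=OpenAI(temperature=0), df=df, verbose=True)

# Matches the agent trace above: the agent runs df.count() and observes 891 rows.
agent.run("how many rows are there?")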
@@ -14,9 +14,7 @@ def _validate_spark_df(df: Any) -> bool:
     try:
         from pyspark.sql import DataFrame as SparkLocalDataFrame

-        if not isinstance(df, SparkLocalDataFrame):
-            return False
-        return True
+        return isinstance(df, SparkLocalDataFrame)
     except ImportError:
         return False

@@ -25,9 +23,7 @@ def _validate_spark_connect_df(df: Any) -> bool:
     try:
         from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataFrame

-        if not isinstance(df, SparkConnectDataFrame):
-            return False
-        return True
+        return isinstance(df, SparkConnectDataFrame)
     except ImportError:
         return False
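For reference, this is what the two validation helpers look like after the simplification. The bodies follow the diff directly; the typing import and docstrings are added here only to make the sketch self-contained:

from typing import Any


def _validate_spark_df(df: Any) -> bool:
    """Check whether df is a classic pyspark.sql DataFrame."""
    try:
        from pyspark.sql import DataFrame as SparkLocalDataFrame

        # isinstance() already returns a bool, so the old
        # "if not isinstance(...): return False / return True"
        # collapses into a single return statement.
        return isinstance(df, SparkLocalDataFrame)
    except ImportError:
        # pyspark is an optional dependency; "not installed" is
        # treated the same as "not a Spark DataFrame".
        return False


def _validate_spark_connect_df(df: Any) -> bool:
    """Check whether df is a Spark Connect DataFrame."""
    try:
        from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataFrame

        return isinstance(df, SparkConnectDataFrame)
    except ImportError:
        return False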