This commit is contained in:
Nuno Campos 2023-10-02 10:11:26 +01:00
parent 52e5a8b43e
commit 040bb2983d
2 changed files with 6 additions and 2 deletions

View File

@ -17,6 +17,7 @@ from langchain.callbacks.manager import (
CallbackManagerForToolRun, CallbackManagerForToolRun,
Callbacks, Callbacks,
) )
from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import ( from langchain.pydantic_v1 import (
BaseModel, BaseModel,
Extra, Extra,
@ -165,7 +166,7 @@ class ChildTool(BaseTool):
] = False ] = False
"""Handle the content of the ToolException thrown.""" """Handle the content of the ToolException thrown."""
class Config: class Config(Serializable.Config):
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
arbitrary_types_allowed = True arbitrary_types_allowed = True

View File

@ -2,7 +2,7 @@
"""Tools for interacting with Spark SQL.""" """Tools for interacting with Spark SQL."""
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator from langchain.pydantic_v1 import BaseModel, Field, root_validator
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
@ -21,6 +21,9 @@ class BaseSparkSQLTool(BaseModel):
db: SparkSQL = Field(exclude=True) db: SparkSQL = Field(exclude=True)
class Config(BaseTool.Config):
pass
class QuerySparkSQLTool(BaseSparkSQLTool, BaseTool): class QuerySparkSQLTool(BaseSparkSQLTool, BaseTool):
"""Tool for querying a Spark SQL.""" """Tool for querying a Spark SQL."""