Compare commits

...

1 Commits

Author SHA1 Message Date
Harrison Chase
eed6494904 pydantic bridge 2023-08-16 18:10:38 -07:00
4 changed files with 36 additions and 5 deletions

View File

@@ -0,0 +1,33 @@
import pydantic
PYDANTIC_V2 = pydantic.VERSION.startswith("2.")
if PYDANTIC_V2:
from pydantic import model_validator, ConfigDict
from pydantic import BaseModel as BM
class BaseModel(BM):
model_config = ConfigDict(arbitrary_types_allowed=True)
else:
from pydantic import root_validator as old_root_validator
from pydantic import BaseModel as BM
class BaseModel(BM):
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def root_validator(*args, **kwargs):
if PYDANTIC_V2:
decorator = model_validator
else:
decorator = old_root_validator
# Check if it's being called as @root_validator without ()
if args and callable(args[0]):
return decorator(args[0])
# Otherwise, it's being called with arguments as @root_validator(...)
return decorator

View File

@@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
from typing import List
from pydantic_v1 import BaseModel
from langchain._pydantic_bridge import BaseModel
from langchain.tools import BaseTool

View File

@@ -39,10 +39,7 @@ class PowerBIToolkit(BaseToolkit):
output_token_limit: Optional[int] = None
tiktoken_model_name: Optional[str] = None
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""

View File

@@ -56,7 +56,8 @@ class ChatGeneration(Generation):
message: BaseMessage
"""The message output by the chat model."""
@root_validator
@root_validator()
@classmethod
def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Set the text attribute to be the contents of the message."""
values["text"] = values["message"].content