From fd69cc7e42655e89f74b68800c0f042c92621bf3 Mon Sep 17 00:00:00 2001 From: leo-gan Date: Thu, 6 Apr 2023 12:45:16 -0700 Subject: [PATCH] Removed duplicate BaseModel dependencies (#2471) Removed duplicate BaseModel dependencies in class inheritances. Also, sorted imports by `isort`. --- langchain/agents/agent.py | 2 +- langchain/agents/agent_toolkits/openapi/planner.py | 4 +--- .../agent_toolkits/openapi/planner_prompt.py | 1 - langchain/agents/load_tools.py | 14 +++++++------- langchain/agents/react/base.py | 6 ++---- langchain/callbacks/arize_callback.py | 6 +----- langchain/chains/api/base.py | 4 ++-- langchain/chains/combine_documents/base.py | 6 +++--- langchain/chains/combine_documents/map_reduce.py | 4 ++-- langchain/chains/combine_documents/map_rerank.py | 4 ++-- langchain/chains/combine_documents/refine.py | 4 ++-- langchain/chains/combine_documents/stuff.py | 4 ++-- langchain/chains/constitutional_ai/principles.py | 1 + langchain/chains/conversation/base.py | 4 ++-- langchain/chains/conversational_retrieval/base.py | 8 ++++---- langchain/chains/hyde/base.py | 4 ++-- langchain/chains/llm.py | 4 ++-- langchain/chains/llm_bash/base.py | 4 ++-- langchain/chains/llm_checker/base.py | 4 ++-- langchain/chains/llm_math/base.py | 4 ++-- langchain/chains/llm_requests.py | 4 ++-- langchain/chains/llm_summarization_checker/base.py | 4 ++-- langchain/chains/mapreduce.py | 4 ++-- langchain/chains/moderation.py | 4 ++-- langchain/chains/natbot/base.py | 4 ++-- langchain/chains/pal/base.py | 4 ++-- langchain/chains/qa_with_sources/base.py | 6 +++--- langchain/chains/qa_with_sources/retrieval.py | 4 ++-- langchain/chains/qa_with_sources/vector_db.py | 4 ++-- .../chains/question_answering/map_reduce_prompt.py | 11 ++++------- .../chains/question_answering/map_rerank_prompt.py | 2 +- .../chains/question_answering/refine_prompts.py | 14 +++++--------- .../chains/question_answering/stuff_prompt.py | 8 ++------ langchain/chains/retrieval_qa/base.py | 8 ++++---- langchain/chains/sequential.py | 6 +++--- langchain/chains/sql_database/base.py | 6 +++--- langchain/chains/transform.py | 4 +--- langchain/chat_models/azure_openai.py | 4 +--- langchain/chat_models/base.py | 4 ++-- langchain/chat_models/openai.py | 4 ++-- langchain/chat_models/promptlayer_openai.py | 4 +--- langchain/document_loaders/__init__.py | 8 ++------ langchain/embeddings/aleph_alpha.py | 4 +--- langchain/embeddings/self_hosted.py | 4 ++-- langchain/embeddings/self_hosted_hugging_face.py | 4 +--- langchain/evaluation/qa/generate_prompt.py | 2 +- langchain/llms/ai21.py | 2 +- langchain/llms/aleph_alpha.py | 4 ++-- langchain/llms/anthropic.py | 4 ++-- langchain/llms/bananadev.py | 4 ++-- langchain/llms/base.py | 4 ++-- langchain/llms/cerebriumai.py | 4 ++-- langchain/llms/cohere.py | 4 ++-- langchain/llms/deepinfra.py | 4 ++-- langchain/llms/fake.py | 4 +--- langchain/llms/forefrontai.py | 4 ++-- langchain/llms/gooseai.py | 4 ++-- langchain/llms/gpt4all.py | 4 ++-- langchain/llms/huggingface_endpoint.py | 4 ++-- langchain/llms/huggingface_hub.py | 4 ++-- langchain/llms/huggingface_pipeline.py | 4 ++-- langchain/llms/llamacpp.py | 4 ++-- langchain/llms/manifest.py | 4 ++-- langchain/llms/modal.py | 4 ++-- langchain/llms/nlpcloud.py | 4 ++-- langchain/llms/openai.py | 6 +++--- langchain/llms/petals.py | 4 ++-- langchain/llms/promptlayer_openai.py | 6 ++---- langchain/llms/replicate.py | 4 ++-- langchain/llms/sagemaker_endpoint.py | 4 ++-- langchain/llms/self_hosted.py | 4 ++-- langchain/llms/self_hosted_hugging_face.py | 4 ++-- langchain/llms/stochasticai.py | 4 ++-- langchain/llms/writer.py | 4 ++-- langchain/memory/buffer.py | 6 +++--- langchain/memory/buffer_window.py | 4 +--- langchain/memory/chat_memory.py | 5 +---- langchain/memory/combined.py | 4 +--- langchain/memory/entity.py | 4 +--- langchain/memory/kg.py | 4 ++-- langchain/memory/simple.py | 4 +--- langchain/memory/summary.py | 2 +- langchain/memory/summary_buffer.py | 4 ++-- langchain/memory/token_buffer.py | 4 +--- langchain/output_parsers/regex.py | 4 +--- langchain/output_parsers/regex_dict.py | 4 +--- .../prompts/example_selector/ngram_overlap.py | 4 ++-- .../example_selector/semantic_similarity.py | 2 +- langchain/prompts/few_shot.py | 4 ++-- langchain/prompts/few_shot_with_templates.py | 9 +++------ langchain/prompts/prompt.py | 4 ++-- .../vectorstores/test_pgvector.py | 4 +--- tests/unit_tests/agents/test_agent.py | 4 +--- tests/unit_tests/agents/test_react.py | 4 +--- tests/unit_tests/chains/test_base.py | 5 ++--- tests/unit_tests/chains/test_hyde.py | 3 +-- tests/unit_tests/chains/test_natbot.py | 4 +--- tests/unit_tests/chains/test_sequential.py | 3 +-- tests/unit_tests/llms/fake_llm.py | 4 +--- 99 files changed, 187 insertions(+), 257 deletions(-) diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index e6c5ce58a61..3c948642b1f 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -548,7 +548,7 @@ class Agent(BaseSingleActionAgent): } -class AgentExecutor(Chain, BaseModel): +class AgentExecutor(Chain): """Consists of an agent using tools.""" agent: Union[BaseSingleActionAgent, BaseMultiActionAgent] diff --git a/langchain/agents/agent_toolkits/openapi/planner.py b/langchain/agents/agent_toolkits/openapi/planner.py index c08b33629fe..b9ae8d52cab 100644 --- a/langchain/agents/agent_toolkits/openapi/planner.py +++ b/langchain/agents/agent_toolkits/openapi/planner.py @@ -19,9 +19,7 @@ from langchain.agents.agent_toolkits.openapi.planner_prompt import ( REQUESTS_GET_TOOL_DESCRIPTION, REQUESTS_POST_TOOL_DESCRIPTION, ) -from langchain.agents.agent_toolkits.openapi.spec import ( - ReducedOpenAPISpec, -) +from langchain.agents.agent_toolkits.openapi.spec import ReducedOpenAPISpec from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.tools import Tool from langchain.chains.llm import LLMChain diff --git a/langchain/agents/agent_toolkits/openapi/planner_prompt.py b/langchain/agents/agent_toolkits/openapi/planner_prompt.py index e1fe9fec094..9940a7919ea 100644 --- a/langchain/agents/agent_toolkits/openapi/planner_prompt.py +++ b/langchain/agents/agent_toolkits/openapi/planner_prompt.py @@ -2,7 +2,6 @@ from langchain.prompts.prompt import PromptTemplate - API_PLANNER_PROMPT = """You are a planner that plans a sequence of API calls to assist with user queries against an API. You should: diff --git a/langchain/agents/load_tools.py b/langchain/agents/load_tools.py index bcfa945abca..df2ad484a3e 100644 --- a/langchain/agents/load_tools.py +++ b/langchain/agents/load_tools.py @@ -1,11 +1,11 @@ # flake8: noqa """Load tools.""" -from typing import Any, List, Optional import warnings +from typing import Any, List, Optional from langchain.agents.tools import Tool from langchain.callbacks.base import BaseCallbackManager -from langchain.chains.api import news_docs, open_meteo_docs, tmdb_docs, podcast_docs +from langchain.chains.api import news_docs, open_meteo_docs, podcast_docs, tmdb_docs from langchain.chains.api.base import APIChain from langchain.chains.llm_math.base import LLMMathChain from langchain.chains.pal.base import PALChain @@ -14,16 +14,16 @@ from langchain.requests import TextRequestsWrapper from langchain.tools.base import BaseTool from langchain.tools.bing_search.tool import BingSearchRun from langchain.tools.google_search.tool import GoogleSearchResults, GoogleSearchRun -from langchain.tools.searx_search.tool import SearxSearchResults, SearxSearchRun from langchain.tools.human.tool import HumanInputRun from langchain.tools.python.tool import PythonREPLTool from langchain.tools.requests.tool import ( - RequestsGetTool, - RequestsPostTool, - RequestsPatchTool, - RequestsPutTool, RequestsDeleteTool, + RequestsGetTool, + RequestsPatchTool, + RequestsPostTool, + RequestsPutTool, ) +from langchain.tools.searx_search.tool import SearxSearchResults, SearxSearchRun from langchain.tools.wikipedia.tool import WikipediaQueryRun from langchain.tools.wolfram_alpha.tool import WolframAlphaQueryRun from langchain.utilities.apify import ApifyWrapper diff --git a/langchain/agents/react/base.py b/langchain/agents/react/base.py index 68ba94768a6..a16fbedec03 100644 --- a/langchain/agents/react/base.py +++ b/langchain/agents/react/base.py @@ -2,8 +2,6 @@ import re from typing import Any, List, Optional, Sequence, Tuple -from pydantic import BaseModel - from langchain.agents.agent import Agent, AgentExecutor from langchain.agents.agent_types import AgentType from langchain.agents.react.textworld_prompt import TEXTWORLD_PROMPT @@ -16,7 +14,7 @@ from langchain.prompts.base import BasePromptTemplate from langchain.tools.base import BaseTool -class ReActDocstoreAgent(Agent, BaseModel): +class ReActDocstoreAgent(Agent): """Agent for the ReAct chain.""" @property @@ -124,7 +122,7 @@ class DocstoreExplorer: return self.document.page_content.split("\n\n") -class ReActTextWorldAgent(ReActDocstoreAgent, BaseModel): +class ReActTextWorldAgent(ReActDocstoreAgent): """Agent for the ReAct TextWorld chain.""" @classmethod diff --git a/langchain/callbacks/arize_callback.py b/langchain/callbacks/arize_callback.py index 13f6ab85034..c1b4106f40a 100644 --- a/langchain/callbacks/arize_callback.py +++ b/langchain/callbacks/arize_callback.py @@ -77,11 +77,7 @@ class ArizeCallbackHandler(BaseCallbackHandler): def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: """Log data to Arize when an LLM ends.""" - from arize.utils.types import ( - Embedding, - Environments, - ModelTypes, - ) + from arize.utils.types import Embedding, Environments, ModelTypes # Record token usage of the LLM if response.llm_output is not None: diff --git a/langchain/chains/api/base.py b/langchain/chains/api/base.py index 27e093b051b..212982a58d9 100644 --- a/langchain/chains/api/base.py +++ b/langchain/chains/api/base.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Field, root_validator +from pydantic import Field, root_validator from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT from langchain.chains.base import Chain @@ -13,7 +13,7 @@ from langchain.requests import TextRequestsWrapper from langchain.schema import BaseLanguageModel -class APIChain(Chain, BaseModel): +class APIChain(Chain): """Chain that makes API calls and summarizes the responses to answer a question.""" api_request_chain: LLMChain diff --git a/langchain/chains/combine_documents/base.py b/langchain/chains/combine_documents/base.py index dde16d4531a..dbf03a438e7 100644 --- a/langchain/chains/combine_documents/base.py +++ b/langchain/chains/combine_documents/base.py @@ -3,14 +3,14 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Tuple -from pydantic import BaseModel, Field +from pydantic import Field from langchain.chains.base import Chain from langchain.docstore.document import Document from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter -class BaseCombineDocumentsChain(Chain, BaseModel, ABC): +class BaseCombineDocumentsChain(Chain, ABC): """Base interface for chains combining documents.""" input_key: str = "input_documents" #: :meta private: @@ -66,7 +66,7 @@ class BaseCombineDocumentsChain(Chain, BaseModel, ABC): return extra_return_dict -class AnalyzeDocumentChain(Chain, BaseModel): +class AnalyzeDocumentChain(Chain): """Chain that splits documents, then analyzes it in pieces.""" input_key: str = "input_document" #: :meta private: diff --git a/langchain/chains/combine_documents/map_reduce.py b/langchain/chains/combine_documents/map_reduce.py index 9f6d4678dee..19d5478ba28 100644 --- a/langchain/chains/combine_documents/map_reduce.py +++ b/langchain/chains/combine_documents/map_reduce.py @@ -4,7 +4,7 @@ from __future__ import annotations from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple -from pydantic import BaseModel, Extra, root_validator +from pydantic import Extra, root_validator from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.llm import LLMChain @@ -59,7 +59,7 @@ def _collapse_docs( return Document(page_content=result, metadata=combined_metadata) -class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel): +class MapReduceDocumentsChain(BaseCombineDocumentsChain): """Combining documents by mapping a chain over them, then combining results.""" llm_chain: LLMChain diff --git a/langchain/chains/combine_documents/map_rerank.py b/langchain/chains/combine_documents/map_rerank.py index 2eb67e4c52f..35f198a967a 100644 --- a/langchain/chains/combine_documents/map_rerank.py +++ b/langchain/chains/combine_documents/map_rerank.py @@ -4,7 +4,7 @@ from __future__ import annotations from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast -from pydantic import BaseModel, Extra, root_validator +from pydantic import Extra, root_validator from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.llm import LLMChain @@ -12,7 +12,7 @@ from langchain.docstore.document import Document from langchain.output_parsers.regex import RegexParser -class MapRerankDocumentsChain(BaseCombineDocumentsChain, BaseModel): +class MapRerankDocumentsChain(BaseCombineDocumentsChain): """Combining documents by mapping a chain over them, then reranking results.""" llm_chain: LLMChain diff --git a/langchain/chains/combine_documents/refine.py b/langchain/chains/combine_documents/refine.py index e20ab1474f8..6aba3bc7929 100644 --- a/langchain/chains/combine_documents/refine.py +++ b/langchain/chains/combine_documents/refine.py @@ -4,7 +4,7 @@ from __future__ import annotations from typing import Any, Dict, List, Tuple -from pydantic import BaseModel, Extra, Field, root_validator +from pydantic import Extra, Field, root_validator from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.llm import LLMChain @@ -17,7 +17,7 @@ def _get_default_document_prompt() -> PromptTemplate: return PromptTemplate(input_variables=["page_content"], template="{page_content}") -class RefineDocumentsChain(BaseCombineDocumentsChain, BaseModel): +class RefineDocumentsChain(BaseCombineDocumentsChain): """Combine documents by doing a first pass and then refining on more documents.""" initial_llm_chain: LLMChain diff --git a/langchain/chains/combine_documents/stuff.py b/langchain/chains/combine_documents/stuff.py index 4f47a9a4999..efb56862903 100644 --- a/langchain/chains/combine_documents/stuff.py +++ b/langchain/chains/combine_documents/stuff.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Tuple -from pydantic import BaseModel, Extra, Field, root_validator +from pydantic import Extra, Field, root_validator from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.llm import LLMChain @@ -15,7 +15,7 @@ def _get_default_document_prompt() -> PromptTemplate: return PromptTemplate(input_variables=["page_content"], template="{page_content}") -class StuffDocumentsChain(BaseCombineDocumentsChain, BaseModel): +class StuffDocumentsChain(BaseCombineDocumentsChain): """Chain that combines documents by stuffing into context.""" llm_chain: LLMChain diff --git a/langchain/chains/constitutional_ai/principles.py b/langchain/chains/constitutional_ai/principles.py index 536bf32c2fc..ce05b565c0e 100644 --- a/langchain/chains/constitutional_ai/principles.py +++ b/langchain/chains/constitutional_ai/principles.py @@ -1,5 +1,6 @@ # flake8: noqa from typing import Dict + from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple PRINCIPLES: Dict[str, ConstitutionalPrinciple] = {} diff --git a/langchain/chains/conversation/base.py b/langchain/chains/conversation/base.py index a25a7fb2a2e..7d06ab9a568 100644 --- a/langchain/chains/conversation/base.py +++ b/langchain/chains/conversation/base.py @@ -1,7 +1,7 @@ """Chain that carries on a conversation and calls an LLM.""" from typing import Dict, List -from pydantic import BaseModel, Extra, Field, root_validator +from pydantic import Extra, Field, root_validator from langchain.chains.conversation.prompt import PROMPT from langchain.chains.llm import LLMChain @@ -10,7 +10,7 @@ from langchain.prompts.base import BasePromptTemplate from langchain.schema import BaseMemory -class ConversationChain(LLMChain, BaseModel): +class ConversationChain(LLMChain): """Chain to have a conversation and load context from memory. Example: diff --git a/langchain/chains/conversational_retrieval/base.py b/langchain/chains/conversational_retrieval/base.py index e58382171e6..b1df5fb1f85 100644 --- a/langchain/chains/conversational_retrieval/base.py +++ b/langchain/chains/conversational_retrieval/base.py @@ -6,7 +6,7 @@ from abc import abstractmethod from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from pydantic import BaseModel, Extra, Field, root_validator +from pydantic import Extra, Field, root_validator from langchain.chains.base import Chain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain @@ -28,7 +28,7 @@ def _get_chat_history(chat_history: List[Tuple[str, str]]) -> str: return buffer -class BaseConversationalRetrievalChain(Chain, BaseModel): +class BaseConversationalRetrievalChain(Chain): """Chain for chatting with an index.""" combine_docs_chain: BaseCombineDocumentsChain @@ -116,7 +116,7 @@ class BaseConversationalRetrievalChain(Chain, BaseModel): super().save(file_path) -class ConversationalRetrievalChain(BaseConversationalRetrievalChain, BaseModel): +class ConversationalRetrievalChain(BaseConversationalRetrievalChain): """Chain for chatting with an index.""" retriever: BaseRetriever @@ -175,7 +175,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain, BaseModel): ) -class ChatVectorDBChain(BaseConversationalRetrievalChain, BaseModel): +class ChatVectorDBChain(BaseConversationalRetrievalChain): """Chain for chatting with a vector database.""" vectorstore: VectorStore = Field(alias="vectorstore") diff --git a/langchain/chains/hyde/base.py b/langchain/chains/hyde/base.py index 29ee31de99a..f2f9747032c 100644 --- a/langchain/chains/hyde/base.py +++ b/langchain/chains/hyde/base.py @@ -7,7 +7,7 @@ from __future__ import annotations from typing import Dict, List import numpy as np -from pydantic import BaseModel, Extra +from pydantic import Extra from langchain.chains.base import Chain from langchain.chains.hyde.prompts import PROMPT_MAP @@ -16,7 +16,7 @@ from langchain.embeddings.base import Embeddings from langchain.llms.base import BaseLLM -class HypotheticalDocumentEmbedder(Chain, Embeddings, BaseModel): +class HypotheticalDocumentEmbedder(Chain, Embeddings): """Generate hypothetical document for query, and then embed that. Based on https://arxiv.org/abs/2212.10496 diff --git a/langchain/chains/llm.py b/langchain/chains/llm.py index 62cc7a9112a..eb3963222c9 100644 --- a/langchain/chains/llm.py +++ b/langchain/chains/llm.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import Any, Dict, List, Optional, Sequence, Tuple, Union -from pydantic import BaseModel, Extra +from pydantic import Extra from langchain.chains.base import Chain from langchain.input import get_colored_text @@ -12,7 +12,7 @@ from langchain.prompts.prompt import PromptTemplate from langchain.schema import BaseLanguageModel, LLMResult, PromptValue -class LLMChain(Chain, BaseModel): +class LLMChain(Chain): """Chain to run queries against LLMs. Example: diff --git a/langchain/chains/llm_bash/base.py b/langchain/chains/llm_bash/base.py index 994df302188..9a9f44b7587 100644 --- a/langchain/chains/llm_bash/base.py +++ b/langchain/chains/llm_bash/base.py @@ -1,7 +1,7 @@ """Chain that interprets a prompt and executes bash code to perform bash operations.""" from typing import Dict, List -from pydantic import BaseModel, Extra +from pydantic import Extra from langchain.chains.base import Chain from langchain.chains.llm import LLMChain @@ -11,7 +11,7 @@ from langchain.schema import BaseLanguageModel from langchain.utilities.bash import BashProcess -class LLMBashChain(Chain, BaseModel): +class LLMBashChain(Chain): """Chain that interprets a prompt and executes bash code to perform bash operations. Example: diff --git a/langchain/chains/llm_checker/base.py b/langchain/chains/llm_checker/base.py index cd2f0eeca2b..0702818ae93 100644 --- a/langchain/chains/llm_checker/base.py +++ b/langchain/chains/llm_checker/base.py @@ -3,7 +3,7 @@ from typing import Dict, List -from pydantic import BaseModel, Extra +from pydantic import Extra from langchain.chains.base import Chain from langchain.chains.llm import LLMChain @@ -18,7 +18,7 @@ from langchain.llms.base import BaseLLM from langchain.prompts import PromptTemplate -class LLMCheckerChain(Chain, BaseModel): +class LLMCheckerChain(Chain): """Chain for question-answering with self-verification. Example: diff --git a/langchain/chains/llm_math/base.py b/langchain/chains/llm_math/base.py index 3f26254b04d..2f8acc5f35a 100644 --- a/langchain/chains/llm_math/base.py +++ b/langchain/chains/llm_math/base.py @@ -1,7 +1,7 @@ """Chain that interprets a prompt and executes python code to do math.""" from typing import Dict, List -from pydantic import BaseModel, Extra +from pydantic import Extra from langchain.chains.base import Chain from langchain.chains.llm import LLMChain @@ -11,7 +11,7 @@ from langchain.python import PythonREPL from langchain.schema import BaseLanguageModel -class LLMMathChain(Chain, BaseModel): +class LLMMathChain(Chain): """Chain that interprets a prompt and executes python code to do math. Example: diff --git a/langchain/chains/llm_requests.py b/langchain/chains/llm_requests.py index f3f7fb316c4..2ea34cc7ecd 100644 --- a/langchain/chains/llm_requests.py +++ b/langchain/chains/llm_requests.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import Dict, List -from pydantic import BaseModel, Extra, Field, root_validator +from pydantic import Extra, Field, root_validator from langchain.chains import LLMChain from langchain.chains.base import Chain @@ -14,7 +14,7 @@ DEFAULT_HEADERS = { } -class LLMRequestsChain(Chain, BaseModel): +class LLMRequestsChain(Chain): """Chain that hits a URL and then uses an LLM to parse results.""" llm_chain: LLMChain diff --git a/langchain/chains/llm_summarization_checker/base.py b/langchain/chains/llm_summarization_checker/base.py index 656b649fbc5..d69eecb8ae5 100644 --- a/langchain/chains/llm_summarization_checker/base.py +++ b/langchain/chains/llm_summarization_checker/base.py @@ -3,7 +3,7 @@ from pathlib import Path from typing import Dict, List -from pydantic import BaseModel, Extra +from pydantic import Extra from langchain.chains.base import Chain from langchain.chains.llm import LLMChain @@ -27,7 +27,7 @@ ARE_ALL_TRUE_PROMPT = PromptTemplate.from_file( ) -class LLMSummarizationCheckerChain(Chain, BaseModel): +class LLMSummarizationCheckerChain(Chain): """Chain for question-answering with self-verification. Example: diff --git a/langchain/chains/mapreduce.py b/langchain/chains/mapreduce.py index 583e484badd..bcaccd3a792 100644 --- a/langchain/chains/mapreduce.py +++ b/langchain/chains/mapreduce.py @@ -7,7 +7,7 @@ from __future__ import annotations from typing import Dict, List -from pydantic import BaseModel, Extra +from pydantic import Extra from langchain.chains.base import Chain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain @@ -20,7 +20,7 @@ from langchain.prompts.base import BasePromptTemplate from langchain.text_splitter import TextSplitter -class MapReduceChain(Chain, BaseModel): +class MapReduceChain(Chain): """Map-reduce chain.""" combine_documents_chain: BaseCombineDocumentsChain diff --git a/langchain/chains/moderation.py b/langchain/chains/moderation.py index a288124c382..de02ee0b1db 100644 --- a/langchain/chains/moderation.py +++ b/langchain/chains/moderation.py @@ -1,13 +1,13 @@ """Pass input through a moderation endpoint.""" from typing import Any, Dict, List, Optional -from pydantic import BaseModel, root_validator +from pydantic import root_validator from langchain.chains.base import Chain from langchain.utils import get_from_dict_or_env -class OpenAIModerationChain(Chain, BaseModel): +class OpenAIModerationChain(Chain): """Pass input through a moderation endpoint. To use, you should have the ``openai`` python package installed, and the diff --git a/langchain/chains/natbot/base.py b/langchain/chains/natbot/base.py index d688c6b7dce..369f0f45bfb 100644 --- a/langchain/chains/natbot/base.py +++ b/langchain/chains/natbot/base.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import Dict, List -from pydantic import BaseModel, Extra +from pydantic import Extra from langchain.chains.base import Chain from langchain.chains.llm import LLMChain @@ -12,7 +12,7 @@ from langchain.llms.base import BaseLLM from langchain.llms.openai import OpenAI -class NatBotChain(Chain, BaseModel): +class NatBotChain(Chain): """Implement an LLM driven browser. Example: diff --git a/langchain/chains/pal/base.py b/langchain/chains/pal/base.py index 443dd137de4..1bcdef3ae7f 100644 --- a/langchain/chains/pal/base.py +++ b/langchain/chains/pal/base.py @@ -6,7 +6,7 @@ from __future__ import annotations from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Extra +from pydantic import Extra from langchain.chains.base import Chain from langchain.chains.llm import LLMChain @@ -17,7 +17,7 @@ from langchain.python import PythonREPL from langchain.schema import BaseLanguageModel -class PALChain(Chain, BaseModel): +class PALChain(Chain): """Implements Program-Aided Language Models.""" llm: BaseLanguageModel diff --git a/langchain/chains/qa_with_sources/base.py b/langchain/chains/qa_with_sources/base.py index ae5efd11d8f..fd3d23723ae 100644 --- a/langchain/chains/qa_with_sources/base.py +++ b/langchain/chains/qa_with_sources/base.py @@ -6,7 +6,7 @@ import re from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Extra, root_validator +from pydantic import Extra, root_validator from langchain.chains.base import Chain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain @@ -24,7 +24,7 @@ from langchain.prompts.base import BasePromptTemplate from langchain.schema import BaseLanguageModel -class BaseQAWithSourcesChain(Chain, BaseModel, ABC): +class BaseQAWithSourcesChain(Chain, ABC): """Question answering with sources over documents.""" combine_documents_chain: BaseCombineDocumentsChain @@ -149,7 +149,7 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC): return result -class QAWithSourcesChain(BaseQAWithSourcesChain, BaseModel): +class QAWithSourcesChain(BaseQAWithSourcesChain): """Question answering with sources over documents.""" input_docs_key: str = "docs" #: :meta private: diff --git a/langchain/chains/qa_with_sources/retrieval.py b/langchain/chains/qa_with_sources/retrieval.py index 6d1944f0eec..1253da94508 100644 --- a/langchain/chains/qa_with_sources/retrieval.py +++ b/langchain/chains/qa_with_sources/retrieval.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List -from pydantic import BaseModel, Field +from pydantic import Field from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain @@ -10,7 +10,7 @@ from langchain.docstore.document import Document from langchain.schema import BaseRetriever -class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain, BaseModel): +class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain): """Question-answering with sources over an index.""" retriever: BaseRetriever = Field(exclude=True) diff --git a/langchain/chains/qa_with_sources/vector_db.py b/langchain/chains/qa_with_sources/vector_db.py index 4f53393964c..47439611193 100644 --- a/langchain/chains/qa_with_sources/vector_db.py +++ b/langchain/chains/qa_with_sources/vector_db.py @@ -3,7 +3,7 @@ import warnings from typing import Any, Dict, List -from pydantic import BaseModel, Field, root_validator +from pydantic import Field, root_validator from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain @@ -11,7 +11,7 @@ from langchain.docstore.document import Document from langchain.vectorstores.base import VectorStore -class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain, BaseModel): +class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain): """Question-answering with sources over a vector database.""" vectorstore: VectorStore = Field(exclude=True) diff --git a/langchain/chains/question_answering/map_reduce_prompt.py b/langchain/chains/question_answering/map_reduce_prompt.py index 7c0efd77777..9b6153f9e80 100644 --- a/langchain/chains/question_answering/map_reduce_prompt.py +++ b/langchain/chains/question_answering/map_reduce_prompt.py @@ -1,14 +1,11 @@ # flake8: noqa -from langchain.prompts.prompt import PromptTemplate +from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model from langchain.prompts.chat import ( - SystemMessagePromptTemplate, - HumanMessagePromptTemplate, ChatPromptTemplate, + HumanMessagePromptTemplate, + SystemMessagePromptTemplate, ) -from langchain.chains.prompt_selector import ( - ConditionalPromptSelector, - is_chat_model, -) +from langchain.prompts.prompt import PromptTemplate question_prompt_template = """Use the following portion of a long document to see if any of the text is relevant to answer the question. Return any relevant text verbatim. diff --git a/langchain/chains/question_answering/map_rerank_prompt.py b/langchain/chains/question_answering/map_rerank_prompt.py index 0fd945c4bdf..e73439541d1 100644 --- a/langchain/chains/question_answering/map_rerank_prompt.py +++ b/langchain/chains/question_answering/map_rerank_prompt.py @@ -1,6 +1,6 @@ # flake8: noqa -from langchain.prompts import PromptTemplate from langchain.output_parsers.regex import RegexParser +from langchain.prompts import PromptTemplate output_parser = RegexParser( regex=r"(.*?)\nScore: (.*)", diff --git a/langchain/chains/question_answering/refine_prompts.py b/langchain/chains/question_answering/refine_prompts.py index 78c6dd77d14..87b4863d433 100644 --- a/langchain/chains/question_answering/refine_prompts.py +++ b/langchain/chains/question_answering/refine_prompts.py @@ -1,16 +1,12 @@ # flake8: noqa -from langchain.prompts.prompt import PromptTemplate +from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model from langchain.prompts.chat import ( - SystemMessagePromptTemplate, - HumanMessagePromptTemplate, - ChatPromptTemplate, AIMessagePromptTemplate, + ChatPromptTemplate, + HumanMessagePromptTemplate, + SystemMessagePromptTemplate, ) -from langchain.chains.prompt_selector import ( - ConditionalPromptSelector, - is_chat_model, -) - +from langchain.prompts.prompt import PromptTemplate DEFAULT_REFINE_PROMPT_TMPL = ( "The original question is as follows: {question}\n" diff --git a/langchain/chains/question_answering/stuff_prompt.py b/langchain/chains/question_answering/stuff_prompt.py index 968d2950b69..856907f63ed 100644 --- a/langchain/chains/question_answering/stuff_prompt.py +++ b/langchain/chains/question_answering/stuff_prompt.py @@ -1,16 +1,12 @@ # flake8: noqa +from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model from langchain.prompts import PromptTemplate -from langchain.chains.prompt_selector import ( - ConditionalPromptSelector, - is_chat_model, -) from langchain.prompts.chat import ( ChatPromptTemplate, - SystemMessagePromptTemplate, HumanMessagePromptTemplate, + SystemMessagePromptTemplate, ) - prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. {context} diff --git a/langchain/chains/retrieval_qa/base.py b/langchain/chains/retrieval_qa/base.py index a6bf675d559..cf89c99d452 100644 --- a/langchain/chains/retrieval_qa/base.py +++ b/langchain/chains/retrieval_qa/base.py @@ -5,7 +5,7 @@ import warnings from abc import abstractmethod from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Extra, Field, root_validator +from pydantic import Extra, Field, root_validator from langchain.chains.base import Chain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain @@ -18,7 +18,7 @@ from langchain.schema import BaseLanguageModel, BaseRetriever, Document from langchain.vectorstores.base import VectorStore -class BaseRetrievalQA(Chain, BaseModel): +class BaseRetrievalQA(Chain): combine_documents_chain: BaseCombineDocumentsChain """Chain to use to combine the documents.""" input_key: str = "query" #: :meta private: @@ -143,7 +143,7 @@ class BaseRetrievalQA(Chain, BaseModel): return {self.output_key: answer} -class RetrievalQA(BaseRetrievalQA, BaseModel): +class RetrievalQA(BaseRetrievalQA): """Chain for question-answering against an index. Example: @@ -166,7 +166,7 @@ class RetrievalQA(BaseRetrievalQA, BaseModel): return await self.retriever.aget_relevant_documents(question) -class VectorDBQA(BaseRetrievalQA, BaseModel): +class VectorDBQA(BaseRetrievalQA): """Chain for question-answering against a vector database.""" vectorstore: VectorStore = Field(exclude=True, alias="vectorstore") diff --git a/langchain/chains/sequential.py b/langchain/chains/sequential.py index 9b2d2dd41a2..76d699461b5 100644 --- a/langchain/chains/sequential.py +++ b/langchain/chains/sequential.py @@ -1,13 +1,13 @@ """Chain pipeline where the outputs of one step feed directly into next.""" from typing import Dict, List -from pydantic import BaseModel, Extra, root_validator +from pydantic import Extra, root_validator from langchain.chains.base import Chain from langchain.input import get_color_mapping -class SequentialChain(Chain, BaseModel): +class SequentialChain(Chain): """Chain where the outputs of one chain feed directly into next.""" chains: List[Chain] @@ -94,7 +94,7 @@ class SequentialChain(Chain, BaseModel): return {k: known_values[k] for k in self.output_variables} -class SimpleSequentialChain(Chain, BaseModel): +class SimpleSequentialChain(Chain): """Simple chain where the outputs of one step feed directly into next.""" chains: List[Chain] diff --git a/langchain/chains/sql_database/base.py b/langchain/chains/sql_database/base.py index a91c6e91855..1885ad3eb8e 100644 --- a/langchain/chains/sql_database/base.py +++ b/langchain/chains/sql_database/base.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import Any, Dict, List -from pydantic import BaseModel, Extra, Field +from pydantic import Extra, Field from langchain.chains.base import Chain from langchain.chains.llm import LLMChain @@ -13,7 +13,7 @@ from langchain.schema import BaseLanguageModel from langchain.sql_database import SQLDatabase -class SQLDatabaseChain(Chain, BaseModel): +class SQLDatabaseChain(Chain): """Chain for interacting with SQL Database. Example: @@ -107,7 +107,7 @@ class SQLDatabaseChain(Chain, BaseModel): return "sql_database_chain" -class SQLDatabaseSequentialChain(Chain, BaseModel): +class SQLDatabaseSequentialChain(Chain): """Chain for querying SQL database that is a sequential chain. The chain is as follows: diff --git a/langchain/chains/transform.py b/langchain/chains/transform.py index f363567163e..eb5cb314a89 100644 --- a/langchain/chains/transform.py +++ b/langchain/chains/transform.py @@ -1,12 +1,10 @@ """Chain that runs an arbitrary python function.""" from typing import Callable, Dict, List -from pydantic import BaseModel - from langchain.chains.base import Chain -class TransformChain(Chain, BaseModel): +class TransformChain(Chain): """Chain transform chain output. Example: diff --git a/langchain/chat_models/azure_openai.py b/langchain/chat_models/azure_openai.py index 37f00d5017e..f4765db09f7 100644 --- a/langchain/chat_models/azure_openai.py +++ b/langchain/chat_models/azure_openai.py @@ -6,9 +6,7 @@ from typing import Any, Dict from pydantic import root_validator -from langchain.chat_models.openai import ( - ChatOpenAI, -) +from langchain.chat_models.openai import ChatOpenAI from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__file__) diff --git a/langchain/chat_models/base.py b/langchain/chat_models/base.py index a9a0a6d673e..91816de1648 100644 --- a/langchain/chat_models/base.py +++ b/langchain/chat_models/base.py @@ -2,7 +2,7 @@ import asyncio from abc import ABC, abstractmethod from typing import List, Optional -from pydantic import BaseModel, Extra, Field, validator +from pydantic import Extra, Field, validator import langchain from langchain.callbacks import get_callback_manager @@ -23,7 +23,7 @@ def _get_verbosity() -> bool: return langchain.verbose -class BaseChatModel(BaseLanguageModel, BaseModel, ABC): +class BaseChatModel(BaseLanguageModel, ABC): verbose: bool = Field(default_factory=_get_verbosity) """Whether to print out response text.""" callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager) diff --git a/langchain/chat_models/openai.py b/langchain/chat_models/openai.py index c7ee4bd12cc..2d1a0282aac 100644 --- a/langchain/chat_models/openai.py +++ b/langchain/chat_models/openai.py @@ -5,7 +5,7 @@ import logging import sys from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple -from pydantic import BaseModel, Extra, Field, root_validator +from pydantic import Extra, Field, root_validator from tenacity import ( before_sleep_log, retry, @@ -91,7 +91,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: return message_dict -class ChatOpenAI(BaseChatModel, BaseModel): +class ChatOpenAI(BaseChatModel): """Wrapper around OpenAI Chat large language models. To use, you should have the ``openai`` python package installed, and the diff --git a/langchain/chat_models/promptlayer_openai.py b/langchain/chat_models/promptlayer_openai.py index faf37269346..38b664162d3 100644 --- a/langchain/chat_models/promptlayer_openai.py +++ b/langchain/chat_models/promptlayer_openai.py @@ -2,13 +2,11 @@ import datetime from typing import List, Optional -from pydantic import BaseModel - from langchain.chat_models import ChatOpenAI from langchain.schema import BaseMessage, ChatResult -class PromptLayerChatOpenAI(ChatOpenAI, BaseModel): +class PromptLayerChatOpenAI(ChatOpenAI): """Wrapper around OpenAI Chat large language models and PromptLayer. To use, you should have the ``openai`` and ``promptlayer`` python diff --git a/langchain/document_loaders/__init__.py b/langchain/document_loaders/__init__.py index d2ca8d8e367..c73afdfeb4a 100644 --- a/langchain/document_loaders/__init__.py +++ b/langchain/document_loaders/__init__.py @@ -11,9 +11,7 @@ from langchain.document_loaders.azure_blob_storage_file import ( ) from langchain.document_loaders.bigquery import BigQueryLoader from langchain.document_loaders.blackboard import BlackboardLoader -from langchain.document_loaders.college_confidential import ( - CollegeConfidentialLoader, -) +from langchain.document_loaders.college_confidential import CollegeConfidentialLoader from langchain.document_loaders.conllu import CoNLLULoader from langchain.document_loaders.csv_loader import CSVLoader from langchain.document_loaders.dataframe import DataFrameLoader @@ -66,9 +64,7 @@ from langchain.document_loaders.url import UnstructuredURLLoader from langchain.document_loaders.url_selenium import SeleniumURLLoader from langchain.document_loaders.web_base import WebBaseLoader from langchain.document_loaders.whatsapp_chat import WhatsAppChatLoader -from langchain.document_loaders.word_document import ( - UnstructuredWordDocumentLoader, -) +from langchain.document_loaders.word_document import UnstructuredWordDocumentLoader from langchain.document_loaders.youtube import ( GoogleApiClient, GoogleApiYoutubeLoader, diff --git a/langchain/embeddings/aleph_alpha.py b/langchain/embeddings/aleph_alpha.py index 51fbaad2add..97da5ff3052 100644 --- a/langchain/embeddings/aleph_alpha.py +++ b/langchain/embeddings/aleph_alpha.py @@ -54,9 +54,7 @@ class AlephAlphaAsymmetricSemanticEmbedding(BaseModel, Embeddings): values, "aleph_alpha_api_key", "ALEPH_ALPHA_API_KEY" ) try: - from aleph_alpha_client import ( - Client, - ) + from aleph_alpha_client import Client except ImportError: raise ValueError( "Could not import aleph_alpha_client python package. " diff --git a/langchain/embeddings/self_hosted.py b/langchain/embeddings/self_hosted.py index 7e05617e25d..c010d5d500a 100644 --- a/langchain/embeddings/self_hosted.py +++ b/langchain/embeddings/self_hosted.py @@ -1,7 +1,7 @@ """Running custom embedding models on self-hosted remote hardware.""" from typing import Any, Callable, List -from pydantic import BaseModel, Extra +from pydantic import Extra from langchain.embeddings.base import Embeddings from langchain.llms import SelfHostedPipeline @@ -16,7 +16,7 @@ def _embed_documents(pipeline: Any, *args: Any, **kwargs: Any) -> List[List[floa return pipeline(*args, **kwargs) -class SelfHostedEmbeddings(SelfHostedPipeline, Embeddings, BaseModel): +class SelfHostedEmbeddings(SelfHostedPipeline, Embeddings): """Runs custom embedding models on self-hosted remote hardware. Supported hardware includes auto-launched instances on AWS, GCP, Azure, diff --git a/langchain/embeddings/self_hosted_hugging_face.py b/langchain/embeddings/self_hosted_hugging_face.py index 7675d1e4a04..346f0791672 100644 --- a/langchain/embeddings/self_hosted_hugging_face.py +++ b/langchain/embeddings/self_hosted_hugging_face.py @@ -3,8 +3,6 @@ import importlib import logging from typing import Any, Callable, List, Optional -from pydantic import BaseModel - from langchain.embeddings.self_hosted import SelfHostedEmbeddings DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2" @@ -59,7 +57,7 @@ def load_embedding_model(model_id: str, instruct: bool = False, device: int = 0) return client -class SelfHostedHuggingFaceEmbeddings(SelfHostedEmbeddings, BaseModel): +class SelfHostedHuggingFaceEmbeddings(SelfHostedEmbeddings): """Runs sentence_transformers embedding models on self-hosted remote hardware. Supported hardware includes auto-launched instances on AWS, GCP, Azure, diff --git a/langchain/evaluation/qa/generate_prompt.py b/langchain/evaluation/qa/generate_prompt.py index 2fe278cfea7..26d80b2153d 100644 --- a/langchain/evaluation/qa/generate_prompt.py +++ b/langchain/evaluation/qa/generate_prompt.py @@ -1,6 +1,6 @@ # flake8: noqa -from langchain.prompts import PromptTemplate from langchain.output_parsers.regex import RegexParser +from langchain.prompts import PromptTemplate template = """You are a teacher coming up with questions to ask on a quiz. Given the following document, please generate a question and answer based on that document. diff --git a/langchain/llms/ai21.py b/langchain/llms/ai21.py index 1f7736c7887..4ec0326a1bd 100644 --- a/langchain/llms/ai21.py +++ b/langchain/llms/ai21.py @@ -19,7 +19,7 @@ class AI21PenaltyData(BaseModel): applyToEmojis: bool = True -class AI21(LLM, BaseModel): +class AI21(LLM): """Wrapper around AI21 large language models. To use, you should have the environment variable ``AI21_API_KEY`` diff --git a/langchain/llms/aleph_alpha.py b/langchain/llms/aleph_alpha.py index 810a8c5891d..238622753ab 100644 --- a/langchain/llms/aleph_alpha.py +++ b/langchain/llms/aleph_alpha.py @@ -1,14 +1,14 @@ """Wrapper around Aleph Alpha APIs.""" from typing import Any, Dict, List, Optional, Sequence -from pydantic import BaseModel, Extra, root_validator +from pydantic import Extra, root_validator from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env -class AlephAlpha(LLM, BaseModel): +class AlephAlpha(LLM): """Wrapper around Aleph Alpha large language models. To use, you should have the ``aleph_alpha_client`` python package installed, and the diff --git a/langchain/llms/anthropic.py b/langchain/llms/anthropic.py index d877da52719..f6b7dee801a 100644 --- a/langchain/llms/anthropic.py +++ b/langchain/llms/anthropic.py @@ -2,13 +2,13 @@ import re from typing import Any, Dict, Generator, List, Mapping, Optional -from pydantic import BaseModel, Extra, root_validator +from pydantic import Extra, root_validator from langchain.llms.base import LLM from langchain.utils import get_from_dict_or_env -class Anthropic(LLM, BaseModel): +class Anthropic(LLM): r"""Wrapper around Anthropic large language models. To use, you should have the ``anthropic`` python package installed, and the diff --git a/langchain/llms/bananadev.py b/langchain/llms/bananadev.py index ae8fa262a0a..697ebcc79af 100644 --- a/langchain/llms/bananadev.py +++ b/langchain/llms/bananadev.py @@ -2,7 +2,7 @@ import logging from typing import Any, Dict, List, Mapping, Optional -from pydantic import BaseModel, Extra, Field, root_validator +from pydantic import Extra, Field, root_validator from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens @@ -11,7 +11,7 @@ from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) -class Banana(LLM, BaseModel): +class Banana(LLM): """Wrapper around Banana large language models. To use, you should have the ``banana-dev`` python package installed, diff --git a/langchain/llms/base.py b/langchain/llms/base.py index 45bc2db4fa6..dd2397928f0 100644 --- a/langchain/llms/base.py +++ b/langchain/llms/base.py @@ -5,7 +5,7 @@ from pathlib import Path from typing import Any, Dict, List, Mapping, Optional, Tuple, Union import yaml -from pydantic import BaseModel, Extra, Field, validator +from pydantic import Extra, Field, validator import langchain from langchain.callbacks import get_callback_manager @@ -53,7 +53,7 @@ def update_cache( return llm_output -class BaseLLM(BaseLanguageModel, BaseModel, ABC): +class BaseLLM(BaseLanguageModel, ABC): """LLM wrapper should take in a prompt and return a string.""" cache: Optional[bool] = None diff --git a/langchain/llms/cerebriumai.py b/langchain/llms/cerebriumai.py index 29f0d2fc219..2937d7ffc93 100644 --- a/langchain/llms/cerebriumai.py +++ b/langchain/llms/cerebriumai.py @@ -2,7 +2,7 @@ import logging from typing import Any, Dict, List, Mapping, Optional -from pydantic import BaseModel, Extra, Field, root_validator +from pydantic import Extra, Field, root_validator from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens @@ -11,7 +11,7 @@ from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) -class CerebriumAI(LLM, BaseModel): +class CerebriumAI(LLM): """Wrapper around CerebriumAI large language models. To use, you should have the ``cerebrium`` python package installed, and the diff --git a/langchain/llms/cohere.py b/langchain/llms/cohere.py index 2335dba7648..e69aac1d093 100644 --- a/langchain/llms/cohere.py +++ b/langchain/llms/cohere.py @@ -2,7 +2,7 @@ import logging from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Extra, root_validator +from pydantic import Extra, root_validator from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens @@ -11,7 +11,7 @@ from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) -class Cohere(LLM, BaseModel): +class Cohere(LLM): """Wrapper around Cohere large language models. To use, you should have the ``cohere`` python package installed, and the diff --git a/langchain/llms/deepinfra.py b/langchain/llms/deepinfra.py index 8993a4bf3b9..55b4c98bb30 100644 --- a/langchain/llms/deepinfra.py +++ b/langchain/llms/deepinfra.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Mapping, Optional import requests -from pydantic import BaseModel, Extra, root_validator +from pydantic import Extra, root_validator from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens @@ -11,7 +11,7 @@ from langchain.utils import get_from_dict_or_env DEFAULT_MODEL_ID = "google/flan-t5-xl" -class DeepInfra(LLM, BaseModel): +class DeepInfra(LLM): """Wrapper around DeepInfra deployed models. To use, you should have the ``requests`` python package installed, and the diff --git a/langchain/llms/fake.py b/langchain/llms/fake.py index 96f766f9934..aec4abb9766 100644 --- a/langchain/llms/fake.py +++ b/langchain/llms/fake.py @@ -1,12 +1,10 @@ """Fake LLM wrapper for testing purposes.""" from typing import Any, List, Mapping, Optional -from pydantic import BaseModel - from langchain.llms.base import LLM -class FakeListLLM(LLM, BaseModel): +class FakeListLLM(LLM): """Fake LLM wrapper for testing purposes.""" responses: List diff --git a/langchain/llms/forefrontai.py b/langchain/llms/forefrontai.py index 806bcd85454..1e34377a5ef 100644 --- a/langchain/llms/forefrontai.py +++ b/langchain/llms/forefrontai.py @@ -2,14 +2,14 @@ from typing import Any, Dict, List, Mapping, Optional import requests -from pydantic import BaseModel, Extra, root_validator +from pydantic import Extra, root_validator from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env -class ForefrontAI(LLM, BaseModel): +class ForefrontAI(LLM): """Wrapper around ForefrontAI large language models. To use, you should have the environment variable ``FOREFRONTAI_API_KEY`` diff --git a/langchain/llms/gooseai.py b/langchain/llms/gooseai.py index 89f17f18d32..ec7ca28dc80 100644 --- a/langchain/llms/gooseai.py +++ b/langchain/llms/gooseai.py @@ -2,7 +2,7 @@ import logging from typing import Any, Dict, List, Mapping, Optional -from pydantic import BaseModel, Extra, Field, root_validator +from pydantic import Extra, Field, root_validator from langchain.llms.base import LLM from langchain.utils import get_from_dict_or_env @@ -10,7 +10,7 @@ from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) -class GooseAI(LLM, BaseModel): +class GooseAI(LLM): """Wrapper around OpenAI large language models. To use, you should have the ``openai`` python package installed, and the diff --git a/langchain/llms/gpt4all.py b/langchain/llms/gpt4all.py index ed927289d15..fa6b5fd5f71 100644 --- a/langchain/llms/gpt4all.py +++ b/langchain/llms/gpt4all.py @@ -1,13 +1,13 @@ """Wrapper for the GPT4All model.""" from typing import Any, Dict, List, Mapping, Optional, Set -from pydantic import BaseModel, Extra, Field, root_validator +from pydantic import Extra, Field, root_validator from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -class GPT4All(LLM, BaseModel): +class GPT4All(LLM): r"""Wrapper around GPT4All language models. To use, you should have the ``pyllamacpp`` python package installed, the diff --git a/langchain/llms/huggingface_endpoint.py b/langchain/llms/huggingface_endpoint.py index 027ff917604..d2682bc0b06 100644 --- a/langchain/llms/huggingface_endpoint.py +++ b/langchain/llms/huggingface_endpoint.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Mapping, Optional import requests -from pydantic import BaseModel, Extra, root_validator +from pydantic import Extra, root_validator from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens @@ -11,7 +11,7 @@ from langchain.utils import get_from_dict_or_env VALID_TASKS = ("text2text-generation", "text-generation") -class HuggingFaceEndpoint(LLM, BaseModel): +class HuggingFaceEndpoint(LLM): """Wrapper around HuggingFaceHub Inference Endpoints. To use, you should have the ``huggingface_hub`` python package installed, and the diff --git a/langchain/llms/huggingface_hub.py b/langchain/llms/huggingface_hub.py index b9c4098879a..4a8c259e777 100644 --- a/langchain/llms/huggingface_hub.py +++ b/langchain/llms/huggingface_hub.py @@ -1,7 +1,7 @@ """Wrapper around HuggingFace APIs.""" from typing import Any, Dict, List, Mapping, Optional -from pydantic import BaseModel, Extra, root_validator +from pydantic import Extra, root_validator from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens @@ -11,7 +11,7 @@ DEFAULT_REPO_ID = "gpt2" VALID_TASKS = ("text2text-generation", "text-generation") -class HuggingFaceHub(LLM, BaseModel): +class HuggingFaceHub(LLM): """Wrapper around HuggingFaceHub models. To use, you should have the ``huggingface_hub`` python package installed, and the diff --git a/langchain/llms/huggingface_pipeline.py b/langchain/llms/huggingface_pipeline.py index 5382b31e655..1f3e40c2256 100644 --- a/langchain/llms/huggingface_pipeline.py +++ b/langchain/llms/huggingface_pipeline.py @@ -3,7 +3,7 @@ import importlib.util import logging from typing import Any, List, Mapping, Optional -from pydantic import BaseModel, Extra +from pydantic import Extra from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens @@ -15,7 +15,7 @@ VALID_TASKS = ("text2text-generation", "text-generation") logger = logging.getLogger() -class HuggingFacePipeline(LLM, BaseModel): +class HuggingFacePipeline(LLM): """Wrapper around HuggingFace Pipeline API. To use, you should have the ``transformers`` python package installed. diff --git a/langchain/llms/llamacpp.py b/langchain/llms/llamacpp.py index 878078f6558..af9d9f29f29 100644 --- a/langchain/llms/llamacpp.py +++ b/langchain/llms/llamacpp.py @@ -2,14 +2,14 @@ import logging from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Field, root_validator +from pydantic import Field, root_validator from langchain.llms.base import LLM logger = logging.getLogger(__name__) -class LlamaCpp(LLM, BaseModel): +class LlamaCpp(LLM): """Wrapper around the llama.cpp model. To use, you should have the llama-cpp-python library installed, and provide the diff --git a/langchain/llms/manifest.py b/langchain/llms/manifest.py index b9a4ce145c4..f44635e1692 100644 --- a/langchain/llms/manifest.py +++ b/langchain/llms/manifest.py @@ -1,12 +1,12 @@ """Wrapper around HazyResearch's Manifest library.""" from typing import Any, Dict, List, Mapping, Optional -from pydantic import BaseModel, Extra, root_validator +from pydantic import Extra, root_validator from langchain.llms.base import LLM -class ManifestWrapper(LLM, BaseModel): +class ManifestWrapper(LLM): """Wrapper around HazyResearch's Manifest library.""" client: Any #: :meta private: diff --git a/langchain/llms/modal.py b/langchain/llms/modal.py index 5037858a31e..4c159a3953a 100644 --- a/langchain/llms/modal.py +++ b/langchain/llms/modal.py @@ -3,7 +3,7 @@ import logging from typing import Any, Dict, List, Mapping, Optional import requests -from pydantic import BaseModel, Extra, Field, root_validator +from pydantic import Extra, Field, root_validator from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens @@ -11,7 +11,7 @@ from langchain.llms.utils import enforce_stop_tokens logger = logging.getLogger(__name__) -class Modal(LLM, BaseModel): +class Modal(LLM): """Wrapper around Modal large language models. To use, you should have the ``modal-client`` python package installed. diff --git a/langchain/llms/nlpcloud.py b/langchain/llms/nlpcloud.py index 2c04c41960d..74451d7f929 100644 --- a/langchain/llms/nlpcloud.py +++ b/langchain/llms/nlpcloud.py @@ -1,13 +1,13 @@ """Wrapper around NLPCloud APIs.""" from typing import Any, Dict, List, Mapping, Optional -from pydantic import BaseModel, Extra, root_validator +from pydantic import Extra, root_validator from langchain.llms.base import LLM from langchain.utils import get_from_dict_or_env -class NLPCloud(LLM, BaseModel): +class NLPCloud(LLM): """Wrapper around NLPCloud large language models. To use, you should have the ``nlpcloud`` python package installed, and the diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py index 761d19a5ac4..d4d24bda18c 100644 --- a/langchain/llms/openai.py +++ b/langchain/llms/openai.py @@ -17,7 +17,7 @@ from typing import ( Union, ) -from pydantic import BaseModel, Extra, Field, root_validator +from pydantic import Extra, Field, root_validator from tenacity import ( before_sleep_log, retry, @@ -113,7 +113,7 @@ async def acompletion_with_retry( return await _completion_with_retry(**kwargs) -class BaseOpenAI(BaseLLM, BaseModel): +class BaseOpenAI(BaseLLM): """Wrapper around OpenAI large language models. To use, you should have the ``openai`` python package installed, and the @@ -534,7 +534,7 @@ class AzureOpenAI(BaseOpenAI): return {**{"engine": self.deployment_name}, **super()._invocation_params} -class OpenAIChat(BaseLLM, BaseModel): +class OpenAIChat(BaseLLM): """Wrapper around OpenAI Chat large language models. To use, you should have the ``openai`` python package installed, and the diff --git a/langchain/llms/petals.py b/langchain/llms/petals.py index bffe59ba817..ed30a28f59b 100644 --- a/langchain/llms/petals.py +++ b/langchain/llms/petals.py @@ -2,7 +2,7 @@ import logging from typing import Any, Dict, List, Mapping, Optional -from pydantic import BaseModel, Extra, Field, root_validator +from pydantic import Extra, Field, root_validator from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens @@ -11,7 +11,7 @@ from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) -class Petals(LLM, BaseModel): +class Petals(LLM): """Wrapper around Petals Bloom models. To use, you should have the ``petals`` python package installed, and the diff --git a/langchain/llms/promptlayer_openai.py b/langchain/llms/promptlayer_openai.py index 8a3ae4c1558..c7dd9cf3e01 100644 --- a/langchain/llms/promptlayer_openai.py +++ b/langchain/llms/promptlayer_openai.py @@ -2,13 +2,11 @@ import datetime from typing import List, Optional -from pydantic import BaseModel - from langchain.llms import OpenAI, OpenAIChat from langchain.schema import LLMResult -class PromptLayerOpenAI(OpenAI, BaseModel): +class PromptLayerOpenAI(OpenAI): """Wrapper around OpenAI large language models. To use, you should have the ``openai`` and ``promptlayer`` python @@ -106,7 +104,7 @@ class PromptLayerOpenAI(OpenAI, BaseModel): return generated_responses -class PromptLayerOpenAIChat(OpenAIChat, BaseModel): +class PromptLayerOpenAIChat(OpenAIChat): """Wrapper around OpenAI large language models. To use, you should have the ``openai`` and ``promptlayer`` python diff --git a/langchain/llms/replicate.py b/langchain/llms/replicate.py index 71a6817bba8..42213a49741 100644 --- a/langchain/llms/replicate.py +++ b/langchain/llms/replicate.py @@ -2,7 +2,7 @@ import logging from typing import Any, Dict, List, Mapping, Optional -from pydantic import BaseModel, Extra, Field, root_validator +from pydantic import Extra, Field, root_validator from langchain.llms.base import LLM from langchain.utils import get_from_dict_or_env @@ -10,7 +10,7 @@ from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) -class Replicate(LLM, BaseModel): +class Replicate(LLM): """Wrapper around Replicate models. To use, you should have the ``replicate`` python package installed, diff --git a/langchain/llms/sagemaker_endpoint.py b/langchain/llms/sagemaker_endpoint.py index 926e17184b4..401b1c88b03 100644 --- a/langchain/llms/sagemaker_endpoint.py +++ b/langchain/llms/sagemaker_endpoint.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Mapping, Optional, Union -from pydantic import BaseModel, Extra, root_validator +from pydantic import Extra, root_validator from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens @@ -55,7 +55,7 @@ class ContentHandlerBase(ABC): """ -class SagemakerEndpoint(LLM, BaseModel): +class SagemakerEndpoint(LLM): """Wrapper around custom Sagemaker Inference Endpoints. To use, you must supply the endpoint name from your deployed diff --git a/langchain/llms/self_hosted.py b/langchain/llms/self_hosted.py index 3054329f017..68397da2391 100644 --- a/langchain/llms/self_hosted.py +++ b/langchain/llms/self_hosted.py @@ -4,7 +4,7 @@ import logging import pickle from typing import Any, Callable, List, Mapping, Optional -from pydantic import BaseModel, Extra +from pydantic import Extra from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens @@ -61,7 +61,7 @@ def _send_pipeline_to_device(pipeline: Any, device: int) -> Any: return pipeline -class SelfHostedPipeline(LLM, BaseModel): +class SelfHostedPipeline(LLM): """Run model inference on self-hosted remote hardware. Supported hardware includes auto-launched instances on AWS, GCP, Azure, diff --git a/langchain/llms/self_hosted_hugging_face.py b/langchain/llms/self_hosted_hugging_face.py index 9415b6ca5cb..8138ffbbefa 100644 --- a/langchain/llms/self_hosted_hugging_face.py +++ b/langchain/llms/self_hosted_hugging_face.py @@ -3,7 +3,7 @@ import importlib.util import logging from typing import Any, Callable, List, Mapping, Optional -from pydantic import BaseModel, Extra +from pydantic import Extra from langchain.llms.self_hosted import SelfHostedPipeline from langchain.llms.utils import enforce_stop_tokens @@ -108,7 +108,7 @@ def _load_transformer( return pipeline -class SelfHostedHuggingFaceLLM(SelfHostedPipeline, BaseModel): +class SelfHostedHuggingFaceLLM(SelfHostedPipeline): """Wrapper around HuggingFace Pipeline API to run on self-hosted remote hardware. Supported hardware includes auto-launched instances on AWS, GCP, Azure, diff --git a/langchain/llms/stochasticai.py b/langchain/llms/stochasticai.py index 21c32b21674..052e6efc840 100644 --- a/langchain/llms/stochasticai.py +++ b/langchain/llms/stochasticai.py @@ -4,7 +4,7 @@ import time from typing import Any, Dict, List, Mapping, Optional import requests -from pydantic import BaseModel, Extra, Field, root_validator +from pydantic import Extra, Field, root_validator from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens @@ -13,7 +13,7 @@ from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) -class StochasticAI(LLM, BaseModel): +class StochasticAI(LLM): """Wrapper around StochasticAI large language models. To use, you should have the environment variable ``STOCHASTICAI_API_KEY`` diff --git a/langchain/llms/writer.py b/langchain/llms/writer.py index 7959bac6f33..a3a74f5905e 100644 --- a/langchain/llms/writer.py +++ b/langchain/llms/writer.py @@ -2,14 +2,14 @@ from typing import Any, Dict, List, Mapping, Optional import requests -from pydantic import BaseModel, Extra, root_validator +from pydantic import Extra, root_validator from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env -class Writer(LLM, BaseModel): +class Writer(LLM): """Wrapper around Writer large language models. To use, you should have the environment variable ``WRITER_API_KEY`` diff --git a/langchain/memory/buffer.py b/langchain/memory/buffer.py index 0e197f84e68..f3623aaf215 100644 --- a/langchain/memory/buffer.py +++ b/langchain/memory/buffer.py @@ -1,13 +1,13 @@ from typing import Any, Dict, List, Optional -from pydantic import BaseModel, root_validator +from pydantic import root_validator from langchain.memory.chat_memory import BaseChatMemory, BaseMemory from langchain.memory.utils import get_prompt_input_key from langchain.schema import get_buffer_string -class ConversationBufferMemory(BaseChatMemory, BaseModel): +class ConversationBufferMemory(BaseChatMemory): """Buffer for storing conversation memory.""" human_prefix: str = "Human" @@ -39,7 +39,7 @@ class ConversationBufferMemory(BaseChatMemory, BaseModel): return {self.memory_key: self.buffer} -class ConversationStringBufferMemory(BaseMemory, BaseModel): +class ConversationStringBufferMemory(BaseMemory): """Buffer for storing conversation memory.""" human_prefix: str = "Human" diff --git a/langchain/memory/buffer_window.py b/langchain/memory/buffer_window.py index d76faaddcbf..aaa69b0967b 100644 --- a/langchain/memory/buffer_window.py +++ b/langchain/memory/buffer_window.py @@ -1,12 +1,10 @@ from typing import Any, Dict, List -from pydantic import BaseModel - from langchain.memory.chat_memory import BaseChatMemory from langchain.schema import BaseMessage, get_buffer_string -class ConversationBufferWindowMemory(BaseChatMemory, BaseModel): +class ConversationBufferWindowMemory(BaseChatMemory): """Buffer for storing conversation memory.""" human_prefix: str = "Human" diff --git a/langchain/memory/chat_memory.py b/langchain/memory/chat_memory.py index 3fbf35e7514..cee7d6b93c5 100644 --- a/langchain/memory/chat_memory.py +++ b/langchain/memory/chat_memory.py @@ -5,10 +5,7 @@ from pydantic import Field from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory from langchain.memory.utils import get_prompt_input_key -from langchain.schema import ( - BaseChatMessageHistory, - BaseMemory, -) +from langchain.schema import BaseChatMessageHistory, BaseMemory class BaseChatMemory(BaseMemory, ABC): diff --git a/langchain/memory/combined.py b/langchain/memory/combined.py index eaee9c36383..7969ca4689e 100644 --- a/langchain/memory/combined.py +++ b/langchain/memory/combined.py @@ -1,11 +1,9 @@ from typing import Any, Dict, List -from pydantic import BaseModel - from langchain.schema import BaseMemory -class CombinedMemory(BaseMemory, BaseModel): +class CombinedMemory(BaseMemory): """Class for combining multiple memories' data together.""" memories: List[BaseMemory] diff --git a/langchain/memory/entity.py b/langchain/memory/entity.py index 779a5e1ea22..95aac811a26 100644 --- a/langchain/memory/entity.py +++ b/langchain/memory/entity.py @@ -1,7 +1,5 @@ from typing import Any, Dict, List, Optional -from pydantic import BaseModel - from langchain.chains.llm import LLMChain from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.prompt import ( @@ -13,7 +11,7 @@ from langchain.prompts.base import BasePromptTemplate from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string -class ConversationEntityMemory(BaseChatMemory, BaseModel): +class ConversationEntityMemory(BaseChatMemory): """Entity extractor & summarizer to memory.""" human_prefix: str = "Human" diff --git a/langchain/memory/kg.py b/langchain/memory/kg.py index ddf7ff33daf..8b2b5f6ba47 100644 --- a/langchain/memory/kg.py +++ b/langchain/memory/kg.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Type, Union -from pydantic import BaseModel, Field +from pydantic import Field from langchain.chains.llm import LLMChain from langchain.graphs import NetworkxEntityGraph @@ -20,7 +20,7 @@ from langchain.schema import ( ) -class ConversationKGMemory(BaseChatMemory, BaseModel): +class ConversationKGMemory(BaseChatMemory): """Knowledge graph memory for storing conversation memory. Integrates with external knowledge graph to store and retrieve diff --git a/langchain/memory/simple.py b/langchain/memory/simple.py index c5be80ea02e..c30f70240da 100644 --- a/langchain/memory/simple.py +++ b/langchain/memory/simple.py @@ -1,11 +1,9 @@ from typing import Any, Dict, List -from pydantic import BaseModel - from langchain.schema import BaseMemory -class SimpleMemory(BaseMemory, BaseModel): +class SimpleMemory(BaseMemory): """Simple memory for storing context or other bits of information that shouldn't ever change between prompts. """ diff --git a/langchain/memory/summary.py b/langchain/memory/summary.py index 4f9b27e1663..4873b824b4e 100644 --- a/langchain/memory/summary.py +++ b/langchain/memory/summary.py @@ -34,7 +34,7 @@ class SummarizerMixin(BaseModel): return chain.predict(summary=existing_summary, new_lines=new_lines) -class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin, BaseModel): +class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin): """Conversation summarizer to memory.""" buffer: str = "" diff --git a/langchain/memory/summary_buffer.py b/langchain/memory/summary_buffer.py index 0e5b4734ee1..ac0d9287345 100644 --- a/langchain/memory/summary_buffer.py +++ b/langchain/memory/summary_buffer.py @@ -1,13 +1,13 @@ from typing import Any, Dict, List -from pydantic import BaseModel, root_validator +from pydantic import root_validator from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.summary import SummarizerMixin from langchain.schema import BaseMessage, get_buffer_string -class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin, BaseModel): +class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin): """Buffer with summarizer for storing conversation memory.""" max_token_limit: int = 2000 diff --git a/langchain/memory/token_buffer.py b/langchain/memory/token_buffer.py index 3bd9b68410e..bb4da209d92 100644 --- a/langchain/memory/token_buffer.py +++ b/langchain/memory/token_buffer.py @@ -1,12 +1,10 @@ from typing import Any, Dict, List -from pydantic import BaseModel - from langchain.memory.chat_memory import BaseChatMemory from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string -class ConversationTokenBufferMemory(BaseChatMemory, BaseModel): +class ConversationTokenBufferMemory(BaseChatMemory): """Buffer for storing conversation memory.""" human_prefix: str = "Human" diff --git a/langchain/output_parsers/regex.py b/langchain/output_parsers/regex.py index dd03556bb62..c7760cbf23a 100644 --- a/langchain/output_parsers/regex.py +++ b/langchain/output_parsers/regex.py @@ -3,12 +3,10 @@ from __future__ import annotations import re from typing import Dict, List, Optional -from pydantic import BaseModel - from langchain.schema import BaseOutputParser -class RegexParser(BaseOutputParser, BaseModel): +class RegexParser(BaseOutputParser): """Class to parse the output into a dictionary.""" regex: str diff --git a/langchain/output_parsers/regex_dict.py b/langchain/output_parsers/regex_dict.py index d37f25640f4..fc1271a7deb 100644 --- a/langchain/output_parsers/regex_dict.py +++ b/langchain/output_parsers/regex_dict.py @@ -3,12 +3,10 @@ from __future__ import annotations import re from typing import Dict, Optional -from pydantic import BaseModel - from langchain.schema import BaseOutputParser -class RegexDictParser(BaseOutputParser, BaseModel): +class RegexDictParser(BaseOutputParser): """Class to parse the output into a dictionary.""" regex_pattern: str = r"{}:\s?([^.'\n']*)\.?" # : :meta private: diff --git a/langchain/prompts/example_selector/ngram_overlap.py b/langchain/prompts/example_selector/ngram_overlap.py index 335331ec1bf..cfe198d251f 100644 --- a/langchain/prompts/example_selector/ngram_overlap.py +++ b/langchain/prompts/example_selector/ngram_overlap.py @@ -20,8 +20,8 @@ def ngram_overlap_score(source: List[str], example: List[str]) -> float: https://www.nltk.org/_modules/nltk/translate/bleu_score.html https://aclanthology.org/P02-1040.pdf """ - from nltk.translate.bleu_score import ( # type: ignore - SmoothingFunction, + from nltk.translate.bleu_score import ( + SmoothingFunction, # type: ignore sentence_bleu, ) diff --git a/langchain/prompts/example_selector/semantic_similarity.py b/langchain/prompts/example_selector/semantic_similarity.py index 604a04e69c3..b80a9cd6f19 100644 --- a/langchain/prompts/example_selector/semantic_similarity.py +++ b/langchain/prompts/example_selector/semantic_similarity.py @@ -99,7 +99,7 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel): return cls(vectorstore=vectorstore, k=k, input_keys=input_keys) -class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector, BaseModel): +class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector): """ExampleSelector that selects examples based on Max Marginal Relevance. This was shown to improve performance in this paper: diff --git a/langchain/prompts/few_shot.py b/langchain/prompts/few_shot.py index 3b0656b05dd..1f3e60cb5b5 100644 --- a/langchain/prompts/few_shot.py +++ b/langchain/prompts/few_shot.py @@ -1,7 +1,7 @@ """Prompt template that contains few shot examples.""" from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Extra, root_validator +from pydantic import Extra, root_validator from langchain.prompts.base import ( DEFAULT_FORMATTER_MAPPING, @@ -12,7 +12,7 @@ from langchain.prompts.example_selector.base import BaseExampleSelector from langchain.prompts.prompt import PromptTemplate -class FewShotPromptTemplate(StringPromptTemplate, BaseModel): +class FewShotPromptTemplate(StringPromptTemplate): """Prompt template that contains few shot examples.""" examples: Optional[List[dict]] = None diff --git a/langchain/prompts/few_shot_with_templates.py b/langchain/prompts/few_shot_with_templates.py index c37dd19dc99..c305f17182d 100644 --- a/langchain/prompts/few_shot_with_templates.py +++ b/langchain/prompts/few_shot_with_templates.py @@ -1,17 +1,14 @@ """Prompt template that contains few shot examples.""" from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Extra, root_validator +from pydantic import Extra, root_validator -from langchain.prompts.base import ( - DEFAULT_FORMATTER_MAPPING, - StringPromptTemplate, -) +from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING, StringPromptTemplate from langchain.prompts.example_selector.base import BaseExampleSelector from langchain.prompts.prompt import PromptTemplate -class FewShotPromptWithTemplates(StringPromptTemplate, BaseModel): +class FewShotPromptWithTemplates(StringPromptTemplate): """Prompt template that contains few shot examples.""" examples: Optional[List[dict]] = None diff --git a/langchain/prompts/prompt.py b/langchain/prompts/prompt.py index 7210c52f5af..2222dada0c0 100644 --- a/langchain/prompts/prompt.py +++ b/langchain/prompts/prompt.py @@ -5,7 +5,7 @@ from pathlib import Path from string import Formatter from typing import Any, Dict, List, Union -from pydantic import BaseModel, Extra, root_validator +from pydantic import Extra, root_validator from langchain.prompts.base import ( DEFAULT_FORMATTER_MAPPING, @@ -14,7 +14,7 @@ from langchain.prompts.base import ( ) -class PromptTemplate(StringPromptTemplate, BaseModel): +class PromptTemplate(StringPromptTemplate): """Schema to represent a prompt for an LLM. Example: diff --git a/tests/integration_tests/vectorstores/test_pgvector.py b/tests/integration_tests/vectorstores/test_pgvector.py index 023d04d9ecb..3560bb591ef 100644 --- a/tests/integration_tests/vectorstores/test_pgvector.py +++ b/tests/integration_tests/vectorstores/test_pgvector.py @@ -6,9 +6,7 @@ from sqlalchemy.orm import Session from langchain.docstore.document import Document from langchain.vectorstores.pgvector import PGVector -from tests.integration_tests.vectorstores.fake_embeddings import ( - FakeEmbeddings, -) +from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings CONNECTION_STRING = PGVector.connection_string_from_db_params( driver=os.environ.get("TEST_PGVECTOR_DRIVER", "psycopg2"), diff --git a/tests/unit_tests/agents/test_agent.py b/tests/unit_tests/agents/test_agent.py index 9a693e7322a..f33573b44f3 100644 --- a/tests/unit_tests/agents/test_agent.py +++ b/tests/unit_tests/agents/test_agent.py @@ -2,8 +2,6 @@ from typing import Any, List, Mapping, Optional -from pydantic import BaseModel - from langchain.agents import AgentExecutor, AgentType, initialize_agent from langchain.agents.tools import Tool from langchain.callbacks.base import CallbackManager @@ -11,7 +9,7 @@ from langchain.llms.base import LLM from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler -class FakeListLLM(LLM, BaseModel): +class FakeListLLM(LLM): """Fake LLM for testing that outputs elements of a list.""" responses: List[str] diff --git a/tests/unit_tests/agents/test_react.py b/tests/unit_tests/agents/test_react.py index a0a486bcb94..0c54a9ca927 100644 --- a/tests/unit_tests/agents/test_react.py +++ b/tests/unit_tests/agents/test_react.py @@ -2,8 +2,6 @@ from typing import Any, List, Mapping, Optional, Union -from pydantic import BaseModel - from langchain.agents.react.base import ReActChain, ReActDocstoreAgent from langchain.agents.tools import Tool from langchain.docstore.base import Docstore @@ -23,7 +21,7 @@ Made in 2022.""" _FAKE_PROMPT = PromptTemplate(input_variables=["input"], template="{input}") -class FakeListLLM(LLM, BaseModel): +class FakeListLLM(LLM): """Fake LLM for testing that outputs elements of a list.""" responses: List[str] diff --git a/tests/unit_tests/chains/test_base.py b/tests/unit_tests/chains/test_base.py index f24fcaf0577..0b0aebf760f 100644 --- a/tests/unit_tests/chains/test_base.py +++ b/tests/unit_tests/chains/test_base.py @@ -2,7 +2,6 @@ from typing import Any, Dict, List, Optional import pytest -from pydantic import BaseModel from langchain.callbacks.base import CallbackManager from langchain.chains.base import Chain @@ -10,7 +9,7 @@ from langchain.schema import BaseMemory from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler -class FakeMemory(BaseMemory, BaseModel): +class FakeMemory(BaseMemory): """Fake memory class for testing purposes.""" @property @@ -33,7 +32,7 @@ class FakeMemory(BaseMemory, BaseModel): pass -class FakeChain(Chain, BaseModel): +class FakeChain(Chain): """Fake chain class for testing purposes.""" be_correct: bool = True diff --git a/tests/unit_tests/chains/test_hyde.py b/tests/unit_tests/chains/test_hyde.py index fd7f3d61893..cc3e6ae42f8 100644 --- a/tests/unit_tests/chains/test_hyde.py +++ b/tests/unit_tests/chains/test_hyde.py @@ -2,7 +2,6 @@ from typing import List, Optional import numpy as np -from pydantic import BaseModel from langchain.chains.hyde.base import HypotheticalDocumentEmbedder from langchain.chains.hyde.prompts import PROMPT_MAP @@ -23,7 +22,7 @@ class FakeEmbeddings(Embeddings): return list(np.random.uniform(0, 1, 10)) -class FakeLLM(BaseLLM, BaseModel): +class FakeLLM(BaseLLM): """Fake LLM wrapper for testing purposes.""" n: int = 1 diff --git a/tests/unit_tests/chains/test_natbot.py b/tests/unit_tests/chains/test_natbot.py index 0beaa409ced..fd30901af3f 100644 --- a/tests/unit_tests/chains/test_natbot.py +++ b/tests/unit_tests/chains/test_natbot.py @@ -2,13 +2,11 @@ from typing import Any, List, Mapping, Optional -from pydantic import BaseModel - from langchain.chains.natbot.base import NatBotChain from langchain.llms.base import LLM -class FakeLLM(LLM, BaseModel): +class FakeLLM(LLM): """Fake LLM wrapper for testing purposes.""" def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: diff --git a/tests/unit_tests/chains/test_sequential.py b/tests/unit_tests/chains/test_sequential.py index c6021a1cc1f..2ef0e7d4eea 100644 --- a/tests/unit_tests/chains/test_sequential.py +++ b/tests/unit_tests/chains/test_sequential.py @@ -2,14 +2,13 @@ from typing import Dict, List import pytest -from pydantic import BaseModel from langchain.chains.base import Chain from langchain.chains.sequential import SequentialChain, SimpleSequentialChain from langchain.memory.simple import SimpleMemory -class FakeChain(Chain, BaseModel): +class FakeChain(Chain): """Fake Chain for testing purposes.""" input_variables: List[str] diff --git a/tests/unit_tests/llms/fake_llm.py b/tests/unit_tests/llms/fake_llm.py index dd8b3462f00..263bc2b6308 100644 --- a/tests/unit_tests/llms/fake_llm.py +++ b/tests/unit_tests/llms/fake_llm.py @@ -1,12 +1,10 @@ """Fake LLM wrapper for testing purposes.""" from typing import Any, List, Mapping, Optional -from pydantic import BaseModel - from langchain.llms.base import LLM -class FakeLLM(LLM, BaseModel): +class FakeLLM(LLM): """Fake LLM wrapper for testing purposes.""" queries: Optional[Mapping] = None