diff --git a/libs/langchain/langchain/agents/agent_toolkits/vectorstore/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/vectorstore/toolkit.py index 50b60390f84..4f6004df6d7 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/vectorstore/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/vectorstore/toolkit.py @@ -8,10 +8,9 @@ from langchain_community.tools.vectorstore.tool import ( ) from langchain_core.language_models import BaseLanguageModel from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.tools import BaseTool from langchain_core.vectorstores import VectorStore -from langchain.tools import BaseTool - class VectorStoreInfo(BaseModel): """Information about a VectorStore.""" diff --git a/libs/langchain/langchain/chains/flare/base.py b/libs/langchain/langchain/chains/flare/base.py index 8beb5b82e2a..8070fc12374 100644 --- a/libs/langchain/langchain/chains/flare/base.py +++ b/libs/langchain/langchain/chains/flare/base.py @@ -5,7 +5,6 @@ from abc import abstractmethod from typing import Any, Dict, List, Optional, Sequence, Tuple import numpy as np -from langchain_community.llms.openai import OpenAI from langchain_core.callbacks import ( CallbackManagerForChainRun, ) @@ -56,11 +55,7 @@ class _ResponseChain(LLMChain): class _OpenAIResponseChain(_ResponseChain): """Chain that generates responses from user input and context.""" - llm: OpenAI = Field( - default_factory=lambda: OpenAI( - max_tokens=32, model_kwargs={"logprobs": 1}, temperature=0 - ) - ) + llm: BaseLanguageModel def _extract_tokens_and_log_probs( self, generations: List[Generation] @@ -118,7 +113,7 @@ class FlareChain(Chain): question_generator_chain: QuestionGeneratorChain """Chain that generates questions from uncertain spans.""" - response_chain: _ResponseChain = Field(default_factory=_OpenAIResponseChain) + response_chain: _ResponseChain """Chain that generates responses from user input and context.""" output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser) """Parser that determines whether the chain is finished.""" @@ -255,6 +250,14 @@ class FlareChain(Chain): Returns: FlareChain class with the given language model. """ + try: + from langchain_openai import OpenAI + except ImportError: + raise ImportError( + "OpenAI is required for FlareChain. " + "Please install langchain-openai." + "pip install langchain-openai" + ) question_gen_chain = QuestionGeneratorChain(llm=llm) response_llm = OpenAI( max_tokens=max_generation_len, model_kwargs={"logprobs": 1}, temperature=0