diff --git a/libs/langchain/langchain/chains/flare/base.py b/libs/langchain/langchain/chains/flare/base.py
index 8070fc12374..8beb5b82e2a 100644
--- a/libs/langchain/langchain/chains/flare/base.py
+++ b/libs/langchain/langchain/chains/flare/base.py
@@ -5,6 +5,7 @@ from abc import abstractmethod
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 import numpy as np
+from langchain_community.llms.openai import OpenAI
 from langchain_core.callbacks import (
     CallbackManagerForChainRun,
 )
@@ -55,7 +56,11 @@ class _ResponseChain(LLMChain):
 class _OpenAIResponseChain(_ResponseChain):
     """Chain that generates responses from user input and context."""
 
-    llm: BaseLanguageModel
+    llm: OpenAI = Field(
+        default_factory=lambda: OpenAI(
+            max_tokens=32, model_kwargs={"logprobs": 1}, temperature=0
+        )
+    )
 
     def _extract_tokens_and_log_probs(
         self, generations: List[Generation]
@@ -113,7 +118,7 @@ class FlareChain(Chain):
 
     question_generator_chain: QuestionGeneratorChain
     """Chain that generates questions from uncertain spans."""
-    response_chain: _ResponseChain
+    response_chain: _ResponseChain = Field(default_factory=_OpenAIResponseChain)
     """Chain that generates responses from user input and context."""
     output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser)
     """Parser that determines whether the chain is finished."""
@@ -250,14 +255,6 @@ class FlareChain(Chain):
         Returns:
             FlareChain class with the given language model.
         """
-        try:
-            from langchain_openai import OpenAI
-        except ImportError:
-            raise ImportError(
-                "OpenAI is required for FlareChain. "
-                "Please install langchain-openai."
-                "pip install langchain-openai"
-            )
         question_gen_chain = QuestionGeneratorChain(llm=llm)
         response_llm = OpenAI(
             max_tokens=max_generation_len, model_kwargs={"logprobs": 1}, temperature=0
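
For reference, a minimal usage sketch of what this change enables (not part of the diff): with the module-level langchain_community OpenAI import replacing the lazy langchain_openai import, and response_chain gaining a default factory, FlareChain can be built via from_llm or constructed directly without passing a response chain. The retriever below is a placeholder assumption, not something introduced by this patch.

# Illustrative sketch only; assumes langchain-community is installed and
# OPENAI_API_KEY is set in the environment.
from langchain_community.llms.openai import OpenAI
from langchain.chains import FlareChain

retriever = ...  # any BaseRetriever, e.g. a vector-store retriever (placeholder)

flare = FlareChain.from_llm(
    OpenAI(temperature=0),     # used to build the question generator chain
    retriever=retriever,
    max_generation_len=164,    # forwarded as max_tokens of the logprobs-enabled response LLM
    min_prob=0.3,              # confidence threshold for flagging uncertain spans
)

# Constructing FlareChain directly and omitting response_chain now falls back to
# _OpenAIResponseChain's default OpenAI(max_tokens=32, model_kwargs={"logprobs": 1},
# temperature=0) via the new default_factory.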