From 487aff7e46ac8b4e83211fc4c5f4a51ced9f46ed Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Fri, 3 May 2024 13:02:48 -0400 Subject: [PATCH] langchain[patch]: Revert 20794 until 0.2 release (#21257) PR of 2079 was already released as part of 0.1.17rc. Issue for 0.2 release: https://github.com/langchain-ai/langchain/issues/21080 --- libs/langchain/langchain/chains/flare/base.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/libs/langchain/langchain/chains/flare/base.py b/libs/langchain/langchain/chains/flare/base.py index 8070fc12374..8beb5b82e2a 100644 --- a/libs/langchain/langchain/chains/flare/base.py +++ b/libs/langchain/langchain/chains/flare/base.py @@ -5,6 +5,7 @@ from abc import abstractmethod from typing import Any, Dict, List, Optional, Sequence, Tuple import numpy as np +from langchain_community.llms.openai import OpenAI from langchain_core.callbacks import ( CallbackManagerForChainRun, ) @@ -55,7 +56,11 @@ class _ResponseChain(LLMChain): class _OpenAIResponseChain(_ResponseChain): """Chain that generates responses from user input and context.""" - llm: BaseLanguageModel + llm: OpenAI = Field( + default_factory=lambda: OpenAI( + max_tokens=32, model_kwargs={"logprobs": 1}, temperature=0 + ) + ) def _extract_tokens_and_log_probs( self, generations: List[Generation] @@ -113,7 +118,7 @@ class FlareChain(Chain): question_generator_chain: QuestionGeneratorChain """Chain that generates questions from uncertain spans.""" - response_chain: _ResponseChain + response_chain: _ResponseChain = Field(default_factory=_OpenAIResponseChain) """Chain that generates responses from user input and context.""" output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser) """Parser that determines whether the chain is finished.""" @@ -250,14 +255,6 @@ class FlareChain(Chain): Returns: FlareChain class with the given language model. """ - try: - from langchain_openai import OpenAI - except ImportError: - raise ImportError( - "OpenAI is required for FlareChain. " - "Please install langchain-openai." - "pip install langchain-openai" - ) question_gen_chain = QuestionGeneratorChain(llm=llm) response_llm = OpenAI( max_tokens=max_generation_len, model_kwargs={"logprobs": 1}, temperature=0