langchain[patch]: Revert 20794 until 0.2 release (#21257)

PR of 2079 was already released as part of 0.1.17rc.


Issue for 0.2 release:
https://github.com/langchain-ai/langchain/issues/21080
This commit is contained in:
Eugene Yurtsev 2024-05-03 13:02:48 -04:00 committed by GitHub
parent ba4a309d98
commit 487aff7e46
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -5,6 +5,7 @@ from abc import abstractmethod
from typing import Any, Dict, List, Optional, Sequence, Tuple from typing import Any, Dict, List, Optional, Sequence, Tuple
import numpy as np import numpy as np
from langchain_community.llms.openai import OpenAI
from langchain_core.callbacks import ( from langchain_core.callbacks import (
CallbackManagerForChainRun, CallbackManagerForChainRun,
) )
@ -55,7 +56,11 @@ class _ResponseChain(LLMChain):
class _OpenAIResponseChain(_ResponseChain): class _OpenAIResponseChain(_ResponseChain):
"""Chain that generates responses from user input and context.""" """Chain that generates responses from user input and context."""
llm: BaseLanguageModel llm: OpenAI = Field(
default_factory=lambda: OpenAI(
max_tokens=32, model_kwargs={"logprobs": 1}, temperature=0
)
)
def _extract_tokens_and_log_probs( def _extract_tokens_and_log_probs(
self, generations: List[Generation] self, generations: List[Generation]
@ -113,7 +118,7 @@ class FlareChain(Chain):
question_generator_chain: QuestionGeneratorChain question_generator_chain: QuestionGeneratorChain
"""Chain that generates questions from uncertain spans.""" """Chain that generates questions from uncertain spans."""
response_chain: _ResponseChain response_chain: _ResponseChain = Field(default_factory=_OpenAIResponseChain)
"""Chain that generates responses from user input and context.""" """Chain that generates responses from user input and context."""
output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser) output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser)
"""Parser that determines whether the chain is finished.""" """Parser that determines whether the chain is finished."""
@ -250,14 +255,6 @@ class FlareChain(Chain):
Returns: Returns:
FlareChain class with the given language model. FlareChain class with the given language model.
""" """
try:
from langchain_openai import OpenAI
except ImportError:
raise ImportError(
"OpenAI is required for FlareChain. "
"Please install langchain-openai."
"pip install langchain-openai"
)
question_gen_chain = QuestionGeneratorChain(llm=llm) question_gen_chain = QuestionGeneratorChain(llm=llm)
response_llm = OpenAI( response_llm = OpenAI(
max_tokens=max_generation_len, model_kwargs={"logprobs": 1}, temperature=0 max_tokens=max_generation_len, model_kwargs={"logprobs": 1}, temperature=0