langchain[major]: Remove default instantations of LLMs from VectorstoreToolkit (#20794)

Remove default instantiation from vectorstore toolkit.
This commit is contained in:
Eugene Yurtsev 2024-04-23 16:09:14 -04:00 committed by GitHub
parent 42de5168b1
commit 72f720fa38
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 11 additions and 9 deletions

View File

@ -8,10 +8,9 @@ from langchain_community.tools.vectorstore.tool import (
)
from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import BaseTool
from langchain_core.vectorstores import VectorStore
from langchain.tools import BaseTool
class VectorStoreInfo(BaseModel):
"""Information about a VectorStore."""

View File

@ -5,7 +5,6 @@ from abc import abstractmethod
from typing import Any, Dict, List, Optional, Sequence, Tuple
import numpy as np
from langchain_community.llms.openai import OpenAI
from langchain_core.callbacks import (
CallbackManagerForChainRun,
)
@ -56,11 +55,7 @@ class _ResponseChain(LLMChain):
class _OpenAIResponseChain(_ResponseChain):
"""Chain that generates responses from user input and context."""
llm: OpenAI = Field(
default_factory=lambda: OpenAI(
max_tokens=32, model_kwargs={"logprobs": 1}, temperature=0
)
)
llm: BaseLanguageModel
def _extract_tokens_and_log_probs(
self, generations: List[Generation]
@ -118,7 +113,7 @@ class FlareChain(Chain):
question_generator_chain: QuestionGeneratorChain
"""Chain that generates questions from uncertain spans."""
response_chain: _ResponseChain = Field(default_factory=_OpenAIResponseChain)
response_chain: _ResponseChain
"""Chain that generates responses from user input and context."""
output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser)
"""Parser that determines whether the chain is finished."""
@ -255,6 +250,14 @@ class FlareChain(Chain):
Returns:
FlareChain class with the given language model.
"""
try:
from langchain_openai import OpenAI
except ImportError:
raise ImportError(
"OpenAI is required for FlareChain. "
"Please install langchain-openai."
"pip install langchain-openai"
)
question_gen_chain = QuestionGeneratorChain(llm=llm)
response_llm = OpenAI(
max_tokens=max_generation_len, model_kwargs={"logprobs": 1}, temperature=0