langchain[major]: Remove default instantiations of LLMs from VectorstoreToolkit (#20794)

Remove default instantiation from vectorstore toolkit.
Eugene Yurtsev 2024-04-23 16:09:14 -04:00 committed by GitHub
parent 42de5168b1
commit 72f720fa38
2 changed files with 11 additions and 9 deletions


@@ -8,10 +8,9 @@ from langchain_community.tools.vectorstore.tool import (
 )
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.tools import BaseTool
 from langchain_core.vectorstores import VectorStore
-
-from langchain.tools import BaseTool
 
 
 class VectorStoreInfo(BaseModel):
     """Information about a VectorStore."""

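With the default instantiation gone, code that relied on VectorStoreToolkit building an OpenAI LLM on its own now has to pass a model explicitly. A minimal sketch of the new usage, not part of this diff: the FAISS store, the OpenAI model choice, and the names below are illustrative assumptions, and langchain-openai plus faiss-cpu are assumed to be installed.

# Sketch of the new usage: the toolkit no longer constructs an LLM for you.
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAI, OpenAIEmbeddings

from langchain.agents.agent_toolkits.vectorstore.toolkit import (
    VectorStoreInfo,
    VectorStoreToolkit,
)

# Illustrative vector store backed by FAISS; any VectorStore implementation works.
store = FAISS.from_texts(["hello world"], OpenAIEmbeddings())

info = VectorStoreInfo(
    name="demo_store",
    description="toy store used to show the now-required llm argument",
    vectorstore=store,
)

# The toolkit previously fell back to a default OpenAI LLM when llm was omitted;
# after this change the caller chooses and constructs the model explicitly.
toolkit = VectorStoreToolkit(vectorstore_info=info, llm=OpenAI(temperature=0))
tools = toolkit.get_tools()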

@@ -5,7 +5,6 @@ from abc import abstractmethod
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 import numpy as np
-from langchain_community.llms.openai import OpenAI
 from langchain_core.callbacks import (
     CallbackManagerForChainRun,
 )
@@ -56,11 +55,7 @@ class _ResponseChain(LLMChain):
 class _OpenAIResponseChain(_ResponseChain):
     """Chain that generates responses from user input and context."""
 
-    llm: OpenAI = Field(
-        default_factory=lambda: OpenAI(
-            max_tokens=32, model_kwargs={"logprobs": 1}, temperature=0
-        )
-    )
+    llm: BaseLanguageModel
 
     def _extract_tokens_and_log_probs(
         self, generations: List[Generation]
@@ -118,7 +113,7 @@ class FlareChain(Chain):
     question_generator_chain: QuestionGeneratorChain
     """Chain that generates questions from uncertain spans."""
-    response_chain: _ResponseChain = Field(default_factory=_OpenAIResponseChain)
+    response_chain: _ResponseChain
     """Chain that generates responses from user input and context."""
     output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser)
     """Parser that determines whether the chain is finished."""
@@ -255,6 +250,14 @@
         Returns:
             FlareChain class with the given language model.
         """
+        try:
+            from langchain_openai import OpenAI
+        except ImportError:
+            raise ImportError(
+                "OpenAI is required for FlareChain. "
+                "Please install langchain-openai. "
+                "pip install langchain-openai"
+            )
         question_gen_chain = QuestionGeneratorChain(llm=llm)
         response_llm = OpenAI(
             max_tokens=max_generation_len, model_kwargs={"logprobs": 1}, temperature=0