mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-14 08:56:27 +00:00
langchain[major]: Remove default instantations of LLMs from VectorstoreToolkit (#20794)
Remove default instantiation from vectorstore toolkit.
This commit is contained in:
parent
42de5168b1
commit
72f720fa38
@ -8,10 +8,9 @@ from langchain_community.tools.vectorstore.tool import (
|
|||||||
)
|
)
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
from langchain_core.vectorstores import VectorStore
|
from langchain_core.vectorstores import VectorStore
|
||||||
|
|
||||||
from langchain.tools import BaseTool
|
|
||||||
|
|
||||||
|
|
||||||
class VectorStoreInfo(BaseModel):
|
class VectorStoreInfo(BaseModel):
|
||||||
"""Information about a VectorStore."""
|
"""Information about a VectorStore."""
|
||||||
|
@ -5,7 +5,6 @@ from abc import abstractmethod
|
|||||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from langchain_community.llms.openai import OpenAI
|
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
CallbackManagerForChainRun,
|
CallbackManagerForChainRun,
|
||||||
)
|
)
|
||||||
@ -56,11 +55,7 @@ class _ResponseChain(LLMChain):
|
|||||||
class _OpenAIResponseChain(_ResponseChain):
|
class _OpenAIResponseChain(_ResponseChain):
|
||||||
"""Chain that generates responses from user input and context."""
|
"""Chain that generates responses from user input and context."""
|
||||||
|
|
||||||
llm: OpenAI = Field(
|
llm: BaseLanguageModel
|
||||||
default_factory=lambda: OpenAI(
|
|
||||||
max_tokens=32, model_kwargs={"logprobs": 1}, temperature=0
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def _extract_tokens_and_log_probs(
|
def _extract_tokens_and_log_probs(
|
||||||
self, generations: List[Generation]
|
self, generations: List[Generation]
|
||||||
@ -118,7 +113,7 @@ class FlareChain(Chain):
|
|||||||
|
|
||||||
question_generator_chain: QuestionGeneratorChain
|
question_generator_chain: QuestionGeneratorChain
|
||||||
"""Chain that generates questions from uncertain spans."""
|
"""Chain that generates questions from uncertain spans."""
|
||||||
response_chain: _ResponseChain = Field(default_factory=_OpenAIResponseChain)
|
response_chain: _ResponseChain
|
||||||
"""Chain that generates responses from user input and context."""
|
"""Chain that generates responses from user input and context."""
|
||||||
output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser)
|
output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser)
|
||||||
"""Parser that determines whether the chain is finished."""
|
"""Parser that determines whether the chain is finished."""
|
||||||
@ -255,6 +250,14 @@ class FlareChain(Chain):
|
|||||||
Returns:
|
Returns:
|
||||||
FlareChain class with the given language model.
|
FlareChain class with the given language model.
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
|
from langchain_openai import OpenAI
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"OpenAI is required for FlareChain. "
|
||||||
|
"Please install langchain-openai."
|
||||||
|
"pip install langchain-openai"
|
||||||
|
)
|
||||||
question_gen_chain = QuestionGeneratorChain(llm=llm)
|
question_gen_chain = QuestionGeneratorChain(llm=llm)
|
||||||
response_llm = OpenAI(
|
response_llm = OpenAI(
|
||||||
max_tokens=max_generation_len, model_kwargs={"logprobs": 1}, temperature=0
|
max_tokens=max_generation_len, model_kwargs={"logprobs": 1}, temperature=0
|
||||||
|
Loading…
Reference in New Issue
Block a user