Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-16 09:48:04 +00:00)
langchain[major]: breaks some chains to remove hidden defaults (#20759)
Breaks some chains in langchain so that they no longer silently instantiate a default chat model or LLM; callers must now pass a language model explicitly.
commit 1c89e45c14 (parent ad6b5f84e5)
@@ -2,7 +2,6 @@
 from typing import List

 from langchain_community.agent_toolkits.base import BaseToolkit
-from langchain_community.llms.openai import OpenAI
 from langchain_community.tools.vectorstore.tool import (
     VectorStoreQATool,
     VectorStoreQAWithSourcesTool,
@@ -31,7 +30,7 @@ class VectorStoreToolkit(BaseToolkit):
     """Toolkit for interacting with a Vector Store."""

     vectorstore_info: VectorStoreInfo = Field(exclude=True)
-    llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0))
+    llm: BaseLanguageModel

     class Config:
         """Configuration for this pydantic object."""
@@ -65,7 +64,7 @@ class VectorStoreRouterToolkit(BaseToolkit):
     """Toolkit for routing between Vector Stores."""

     vectorstores: List[VectorStoreInfo] = Field(exclude=True)
-    llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0))
+    llm: BaseLanguageModel

     class Config:
         """Configuration for this pydantic object."""
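With the hidden default removed, both VectorStoreToolkit and VectorStoreRouterToolkit require an llm at construction time. A minimal sketch of the new call site, assuming langchain_openai and the import paths below (neither appears in this diff):

```python
from langchain.agents.agent_toolkits.vectorstore.toolkit import (
    VectorStoreInfo,
    VectorStoreToolkit,
)
from langchain_community.vectorstores.inmemory import InMemoryVectorStore
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

# The llm must now be passed explicitly; there is no hidden OpenAI default.
vectorstore = InMemoryVectorStore(embedding=OpenAIEmbeddings())
info = VectorStoreInfo(
    name="docs",
    description="project documentation",
    vectorstore=vectorstore,
)
toolkit = VectorStoreToolkit(vectorstore_info=info, llm=ChatOpenAI())
```

VectorStoreRouterToolkit takes the same explicit llm alongside its vectorstores list.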
@@ -4,7 +4,6 @@ from __future__ import annotations
 import warnings
 from typing import Any, Dict, List, Optional

-from langchain_community.llms.openai import OpenAI
 from langchain_core.callbacks import CallbackManagerForChainRun
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.pydantic_v1 import Extra, root_validator
@@ -68,8 +67,11 @@ class NatBotChain(Chain):
     @classmethod
     def from_default(cls, objective: str, **kwargs: Any) -> NatBotChain:
         """Load with default LLMChain."""
-        llm = OpenAI(temperature=0.5, best_of=10, n=3, max_tokens=50)
-        return cls.from_llm(llm, objective, **kwargs)
+        raise NotImplementedError(
+            "This method is no longer implemented. Please use from_llm, e.g.\n"
+            "llm = OpenAI(temperature=0.5, best_of=10, n=3, max_tokens=50)\n"
+            "NatBotChain.from_llm(llm, objective)"
+        )

     @classmethod
     def from_llm(
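from_default is now a hard error, so existing callers move to from_llm. A sketch of the replacement path spelled out in the new message, assuming langchain_openai (the objective string is illustrative):

```python
from langchain.chains import NatBotChain
from langchain_openai import OpenAI

# The parameters mirror the old hidden default; choose your own as needed.
llm = OpenAI(temperature=0.5, best_of=10, n=3, max_tokens=50)
chain = NatBotChain.from_llm(llm, objective="Find a one-way flight to Lisbon")
```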
@@ -6,7 +6,6 @@ from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

 import requests
-from langchain_community.chat_models import ChatOpenAI
 from langchain_community.utilities.openapi import OpenAPISpec
 from langchain_core.callbacks import CallbackManagerForChainRun
 from langchain_core.language_models import BaseLanguageModel
@@ -272,9 +271,12 @@ def get_openapi_chain(
     if isinstance(spec, str):
         raise ValueError(f"Unable to parse spec from source {spec}")
     openai_fns, call_api_fn = openapi_spec_to_openai_fn(spec)
-    llm = llm or ChatOpenAI(
-        model="gpt-3.5-turbo-0613",
-    )
+    if not llm:
+        raise ValueError(
+            "Must provide an LLM for this chain. For example,\n"
+            "from langchain_openai import ChatOpenAI\n"
+            "llm = ChatOpenAI()\n"
+        )
     prompt = prompt or ChatPromptTemplate.from_template(
         "Use the provided API's to respond to this user query:\n\n{query}"
     )
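get_openapi_chain likewise stops falling back to a pinned gpt-3.5-turbo-0613 model. A sketch of the new required call, assuming this import path (the spec URL is a placeholder):

```python
from langchain.chains.openai_functions.openapi import get_openapi_chain
from langchain_openai import ChatOpenAI

chain = get_openapi_chain(
    "https://example.com/openapi.yaml",  # hypothetical spec source
    llm=ChatOpenAI(),  # now mandatory; omitting it raises ValueError
)
```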
@@ -3,7 +3,6 @@ from __future__ import annotations

 from typing import Any, Dict, List, Mapping, Optional

-from langchain_community.chat_models import ChatOpenAI
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.prompts import PromptTemplate
 from langchain_core.retrievers import BaseRetriever
@@ -42,6 +41,8 @@ class MultiRetrievalQAChain(MultiRouteChain):
         default_retriever: Optional[BaseRetriever] = None,
         default_prompt: Optional[PromptTemplate] = None,
         default_chain: Optional[Chain] = None,
+        *,
+        default_chain_llm: Optional[BaseLanguageModel] = None,
         **kwargs: Any,
     ) -> MultiRetrievalQAChain:
         if default_prompt and not default_retriever:
@@ -78,8 +79,20 @@ class MultiRetrievalQAChain(MultiRouteChain):
             prompt = PromptTemplate(
                 template=prompt_template, input_variables=["history", "query"]
             )
+            if default_chain_llm is None:
+                raise NotImplementedError(
+                    "default_chain_llm must be provided if default_chain is not "
+                    "specified. This API has been changed to avoid instantiating "
+                    "default LLMs on behalf of users.\n"
+                    "You can provide a conversation LLM like so:\n"
+                    "from langchain_openai import ChatOpenAI\n"
+                    "llm = ChatOpenAI()"
+                )
             _default_chain = ConversationChain(
-                llm=ChatOpenAI(), prompt=prompt, input_key="query", output_key="result"
+                llm=default_chain_llm,
+                prompt=prompt,
+                input_key="query",
+                output_key="result",
             )
         return cls(
             router_chain=router_chain,
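Because default_chain_llm is keyword-only, the fallback conversation chain now needs a model passed explicitly whenever default_chain is omitted. A sketch under those assumptions (the retriever contents are illustrative):

```python
from langchain.chains.router.multi_retrieval_qa import MultiRetrievalQAChain
from langchain_community.vectorstores.inmemory import InMemoryVectorStore
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

retriever = InMemoryVectorStore(embedding=OpenAIEmbeddings()).as_retriever()
retriever_infos = [
    {
        "name": "docs",
        "description": "answers questions about project documentation",
        "retriever": retriever,
    }
]
chain = MultiRetrievalQAChain.from_retrievers(
    ChatOpenAI(),
    retriever_infos,
    default_chain_llm=ChatOpenAI(),  # required now that no default is created
)
```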
@@ -1,8 +1,6 @@
 from typing import Any, Dict, List, Optional, Type

 from langchain_community.document_loaders.base import BaseLoader
-from langchain_community.embeddings.openai import OpenAIEmbeddings
-from langchain_community.llms.openai import OpenAI
 from langchain_community.vectorstores.inmemory import InMemoryVectorStore
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
@@ -38,7 +36,14 @@ class VectorStoreIndexWrapper(BaseModel):
         **kwargs: Any,
     ) -> str:
         """Query the vectorstore."""
-        llm = llm or OpenAI(temperature=0)
+        if llm is None:
+            raise NotImplementedError(
+                "This API has been changed to require an LLM. "
+                "Please provide an llm to use for querying the vectorstore.\n"
+                "For example,\n"
+                "from langchain_openai import OpenAI\n"
+                "llm = OpenAI(temperature=0)"
+            )
         retriever_kwargs = retriever_kwargs or {}
         chain = RetrievalQA.from_chain_type(
             llm, retriever=self.vectorstore.as_retriever(**retriever_kwargs), **kwargs
@@ -53,7 +58,14 @@ class VectorStoreIndexWrapper(BaseModel):
         **kwargs: Any,
     ) -> str:
         """Query the vectorstore."""
-        llm = llm or OpenAI(temperature=0)
+        if llm is None:
+            raise NotImplementedError(
+                "This API has been changed to require an LLM. "
+                "Please provide an llm to use for querying the vectorstore.\n"
+                "For example,\n"
+                "from langchain_openai import OpenAI\n"
+                "llm = OpenAI(temperature=0)"
+            )
         retriever_kwargs = retriever_kwargs or {}
         chain = RetrievalQA.from_chain_type(
             llm, retriever=self.vectorstore.as_retriever(**retriever_kwargs), **kwargs
@@ -68,7 +80,14 @@ class VectorStoreIndexWrapper(BaseModel):
         **kwargs: Any,
     ) -> dict:
         """Query the vectorstore and get back sources."""
-        llm = llm or OpenAI(temperature=0)
+        if llm is None:
+            raise NotImplementedError(
+                "This API has been changed to require an LLM. "
+                "Please provide an llm to use for querying the vectorstore.\n"
+                "For example,\n"
+                "from langchain_openai import OpenAI\n"
+                "llm = OpenAI(temperature=0)"
+            )
         retriever_kwargs = retriever_kwargs or {}
         chain = RetrievalQAWithSourcesChain.from_chain_type(
             llm, retriever=self.vectorstore.as_retriever(**retriever_kwargs), **kwargs
@@ -83,7 +102,14 @@ class VectorStoreIndexWrapper(BaseModel):
         **kwargs: Any,
     ) -> dict:
         """Query the vectorstore and get back sources."""
-        llm = llm or OpenAI(temperature=0)
+        if llm is None:
+            raise NotImplementedError(
+                "This API has been changed to require an LLM. "
+                "Please provide an llm to use for querying the vectorstore.\n"
+                "For example,\n"
+                "from langchain_openai import OpenAI\n"
+                "llm = OpenAI(temperature=0)"
+            )
         retriever_kwargs = retriever_kwargs or {}
         chain = RetrievalQAWithSourcesChain.from_chain_type(
             llm, retriever=self.vectorstore.as_retriever(**retriever_kwargs), **kwargs
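All four query helpers on VectorStoreIndexWrapper (query, aquery, and the *_with_sources variants) follow the same pattern: pass the llm in. A minimal sketch, assuming the wrapper is built directly over an in-memory store:

```python
from langchain.indexes.vectorstore import VectorStoreIndexWrapper
from langchain_community.vectorstores.inmemory import InMemoryVectorStore
from langchain_openai import OpenAI, OpenAIEmbeddings

store = InMemoryVectorStore(embedding=OpenAIEmbeddings())
index = VectorStoreIndexWrapper(vectorstore=store)
llm = OpenAI(temperature=0)  # explicit; the OpenAI(temperature=0) default is gone

answer = index.query("What changed in this release?", llm=llm)
cited = index.query_with_sources("What changed in this release?", llm=llm)
```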
@@ -95,7 +121,7 @@ class VectorstoreIndexCreator(BaseModel):
     """Logic for creating indexes."""

     vectorstore_cls: Type[VectorStore] = InMemoryVectorStore
-    embedding: Embeddings = Field(default_factory=OpenAIEmbeddings)
+    embedding: Embeddings
     text_splitter: TextSplitter = Field(default_factory=_get_default_text_splitter)
     vectorstore_kwargs: dict = Field(default_factory=dict)

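VectorstoreIndexCreator keeps its vector store and text splitter defaults but no longer assumes OpenAIEmbeddings. A sketch of the explicit form (the loader path is hypothetical):

```python
from langchain.indexes import VectorstoreIndexCreator
from langchain_community.document_loaders import TextLoader
from langchain_openai import OpenAIEmbeddings

# embedding is now a required field on the creator.
creator = VectorstoreIndexCreator(embedding=OpenAIEmbeddings())
index = creator.from_loaders([TextLoader("docs/overview.txt")])
```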