mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-08 16:06:30 +00:00
Type LLMChain.llm as runnable (#12385)
This commit is contained in:
parent
224ec0cfd3
commit
a8c68d4ffa
@ -150,7 +150,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
|
|||||||
"""
|
"""
|
||||||
inputs = self._get_inputs(docs, **kwargs)
|
inputs = self._get_inputs(docs, **kwargs)
|
||||||
prompt = self.llm_chain.prompt.format(**inputs)
|
prompt = self.llm_chain.prompt.format(**inputs)
|
||||||
return self.llm_chain.llm.get_num_tokens(prompt)
|
return self.llm_chain._get_num_tokens(prompt)
|
||||||
|
|
||||||
def combine_docs(
|
def combine_docs(
|
||||||
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
|
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
|
||||||
|
@ -284,7 +284,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
|
|||||||
self.combine_docs_chain, StuffDocumentsChain
|
self.combine_docs_chain, StuffDocumentsChain
|
||||||
):
|
):
|
||||||
tokens = [
|
tokens = [
|
||||||
self.combine_docs_chain.llm_chain.llm.get_num_tokens(doc.page_content)
|
self.combine_docs_chain.llm_chain._get_num_tokens(doc.page_content)
|
||||||
for doc in docs
|
for doc in docs
|
||||||
]
|
]
|
||||||
token_count = sum(tokens[:num_docs])
|
token_count = sum(tokens[:num_docs])
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManager,
|
AsyncCallbackManager,
|
||||||
@ -17,12 +17,25 @@ from langchain.prompts.prompt import PromptTemplate
|
|||||||
from langchain.pydantic_v1 import Extra, Field
|
from langchain.pydantic_v1 import Extra, Field
|
||||||
from langchain.schema import (
|
from langchain.schema import (
|
||||||
BaseLLMOutputParser,
|
BaseLLMOutputParser,
|
||||||
|
BaseMessage,
|
||||||
BasePromptTemplate,
|
BasePromptTemplate,
|
||||||
|
ChatGeneration,
|
||||||
|
Generation,
|
||||||
LLMResult,
|
LLMResult,
|
||||||
PromptValue,
|
PromptValue,
|
||||||
StrOutputParser,
|
StrOutputParser,
|
||||||
)
|
)
|
||||||
from langchain.schema.language_model import BaseLanguageModel
|
from langchain.schema.language_model import (
|
||||||
|
BaseLanguageModel,
|
||||||
|
LanguageModelInput,
|
||||||
|
)
|
||||||
|
from langchain.schema.runnable import (
|
||||||
|
Runnable,
|
||||||
|
RunnableBinding,
|
||||||
|
RunnableBranch,
|
||||||
|
RunnableWithFallbacks,
|
||||||
|
)
|
||||||
|
from langchain.schema.runnable.configurable import DynamicRunnable
|
||||||
from langchain.utils.input import get_colored_text
|
from langchain.utils.input import get_colored_text
|
||||||
|
|
||||||
|
|
||||||
@ -48,7 +61,9 @@ class LLMChain(Chain):
|
|||||||
|
|
||||||
prompt: BasePromptTemplate
|
prompt: BasePromptTemplate
|
||||||
"""Prompt object to use."""
|
"""Prompt object to use."""
|
||||||
llm: BaseLanguageModel
|
llm: Union[
|
||||||
|
Runnable[LanguageModelInput, str], Runnable[LanguageModelInput, BaseMessage]
|
||||||
|
]
|
||||||
"""Language model to call."""
|
"""Language model to call."""
|
||||||
output_key: str = "text" #: :meta private:
|
output_key: str = "text" #: :meta private:
|
||||||
output_parser: BaseLLMOutputParser = Field(default_factory=StrOutputParser)
|
output_parser: BaseLLMOutputParser = Field(default_factory=StrOutputParser)
|
||||||
@ -100,12 +115,25 @@ class LLMChain(Chain):
|
|||||||
) -> LLMResult:
|
) -> LLMResult:
|
||||||
"""Generate LLM result from inputs."""
|
"""Generate LLM result from inputs."""
|
||||||
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
|
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
|
||||||
return self.llm.generate_prompt(
|
callbacks = run_manager.get_child() if run_manager else None
|
||||||
prompts,
|
if isinstance(self.llm, BaseLanguageModel):
|
||||||
stop,
|
return self.llm.generate_prompt(
|
||||||
callbacks=run_manager.get_child() if run_manager else None,
|
prompts,
|
||||||
**self.llm_kwargs,
|
stop,
|
||||||
)
|
callbacks=callbacks,
|
||||||
|
**self.llm_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
results = self.llm.bind(stop=stop, **self.llm_kwargs).batch(
|
||||||
|
cast(List, prompts), {"callbacks": callbacks}
|
||||||
|
)
|
||||||
|
generations: List[List[Generation]] = []
|
||||||
|
for res in results:
|
||||||
|
if isinstance(res, BaseMessage):
|
||||||
|
generations.append([ChatGeneration(message=res)])
|
||||||
|
else:
|
||||||
|
generations.append([Generation(text=res)])
|
||||||
|
return LLMResult(generations=generations)
|
||||||
|
|
||||||
async def agenerate(
|
async def agenerate(
|
||||||
self,
|
self,
|
||||||
@ -114,12 +142,25 @@ class LLMChain(Chain):
|
|||||||
) -> LLMResult:
|
) -> LLMResult:
|
||||||
"""Generate LLM result from inputs."""
|
"""Generate LLM result from inputs."""
|
||||||
prompts, stop = await self.aprep_prompts(input_list, run_manager=run_manager)
|
prompts, stop = await self.aprep_prompts(input_list, run_manager=run_manager)
|
||||||
return await self.llm.agenerate_prompt(
|
callbacks = run_manager.get_child() if run_manager else None
|
||||||
prompts,
|
if isinstance(self.llm, BaseLanguageModel):
|
||||||
stop,
|
return await self.llm.agenerate_prompt(
|
||||||
callbacks=run_manager.get_child() if run_manager else None,
|
prompts,
|
||||||
**self.llm_kwargs,
|
stop,
|
||||||
)
|
callbacks=callbacks,
|
||||||
|
**self.llm_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch(
|
||||||
|
cast(List, prompts), {"callbacks": callbacks}
|
||||||
|
)
|
||||||
|
generations: List[List[Generation]] = []
|
||||||
|
for res in results:
|
||||||
|
if isinstance(res, BaseMessage):
|
||||||
|
generations.append([ChatGeneration(message=res)])
|
||||||
|
else:
|
||||||
|
generations.append([Generation(text=res)])
|
||||||
|
return LLMResult(generations=generations)
|
||||||
|
|
||||||
def prep_prompts(
|
def prep_prompts(
|
||||||
self,
|
self,
|
||||||
@ -343,3 +384,22 @@ class LLMChain(Chain):
|
|||||||
"""Create LLMChain from LLM and template."""
|
"""Create LLMChain from LLM and template."""
|
||||||
prompt_template = PromptTemplate.from_template(template)
|
prompt_template = PromptTemplate.from_template(template)
|
||||||
return cls(llm=llm, prompt=prompt_template)
|
return cls(llm=llm, prompt=prompt_template)
|
||||||
|
|
||||||
|
def _get_num_tokens(self, text: str) -> int:
|
||||||
|
return _get_language_model(self.llm).get_num_tokens(text)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_language_model(llm_like: Runnable) -> BaseLanguageModel:
|
||||||
|
if isinstance(llm_like, BaseLanguageModel):
|
||||||
|
return llm_like
|
||||||
|
elif isinstance(llm_like, RunnableBinding):
|
||||||
|
return _get_language_model(llm_like.bound)
|
||||||
|
elif isinstance(llm_like, RunnableWithFallbacks):
|
||||||
|
return _get_language_model(llm_like.runnable)
|
||||||
|
elif isinstance(llm_like, (RunnableBranch, DynamicRunnable)):
|
||||||
|
return _get_language_model(llm_like.default)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unable to extract BaseLanguageModel from llm_like object of type "
|
||||||
|
f"{type(llm_like)}"
|
||||||
|
)
|
||||||
|
@ -31,9 +31,7 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
|
|||||||
self.combine_documents_chain, StuffDocumentsChain
|
self.combine_documents_chain, StuffDocumentsChain
|
||||||
):
|
):
|
||||||
tokens = [
|
tokens = [
|
||||||
self.combine_documents_chain.llm_chain.llm.get_num_tokens(
|
self.combine_documents_chain.llm_chain._get_num_tokens(doc.page_content)
|
||||||
doc.page_content
|
|
||||||
)
|
|
||||||
for doc in docs
|
for doc in docs
|
||||||
]
|
]
|
||||||
token_count = sum(tokens[:num_docs])
|
token_count = sum(tokens[:num_docs])
|
||||||
|
@ -36,9 +36,7 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
|
|||||||
self.combine_documents_chain, StuffDocumentsChain
|
self.combine_documents_chain, StuffDocumentsChain
|
||||||
):
|
):
|
||||||
tokens = [
|
tokens = [
|
||||||
self.combine_documents_chain.llm_chain.llm.get_num_tokens(
|
self.combine_documents_chain.llm_chain._get_num_tokens(doc.page_content)
|
||||||
doc.page_content
|
|
||||||
)
|
|
||||||
for doc in docs
|
for doc in docs
|
||||||
]
|
]
|
||||||
token_count = sum(tokens[:num_docs])
|
token_count = sum(tokens[:num_docs])
|
||||||
|
Loading…
Reference in New Issue
Block a user