Mirror of https://github.com/hwchase17/langchain.git, synced 2025-06-05 06:33:20 +00:00
Type LLMChain.llm as runnable (#12385)

commit a8c68d4ffa
parent 224ec0cfd3
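
This commit widens LLMChain.llm from BaseLanguageModel to any Runnable that maps a prompt to either a string or a BaseMessage, so decorated models (bound kwargs, fallbacks, branches, configurable fields) can be passed in directly. A minimal sketch of what the widened type permits; the model classes and fallback wiring here are illustrative, not taken from the diff:

    from langchain.chains import LLMChain
    from langchain.chat_models import ChatAnthropic, ChatOpenAI
    from langchain.prompts import PromptTemplate

    prompt = PromptTemplate.from_template("Summarize: {text}")

    # A chat model with bound kwargs plus a fallback is a Runnable, not a
    # BaseLanguageModel -- under the old annotation this would not type-check.
    llm = ChatOpenAI(model="gpt-4").bind(stop=["\n\n"]).with_fallbacks(
        [ChatAnthropic()]
    )
    chain = LLMChain(llm=llm, prompt=prompt)

The hunks below then route every place that reached into llm directly through helpers that tolerate such wrappers.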
libs/langchain/langchain/chains/combine_documents/stuff.py
@@ -150,7 +150,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
         """
         inputs = self._get_inputs(docs, **kwargs)
         prompt = self.llm_chain.prompt.format(**inputs)
-        return self.llm_chain.llm.get_num_tokens(prompt)
+        return self.llm_chain._get_num_tokens(prompt)

     def combine_docs(
         self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
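
The change above exists because self.llm_chain.llm may now be a wrapper runnable that does not itself expose get_num_tokens; the new chain-level _get_num_tokens (added further down in this diff) unwraps to the underlying model first. Roughly, reusing the illustrative objects from the sketch above:

    # wrapped is a RunnableBinding; token counting goes through the chain,
    # which peels the wrapper back to the ChatOpenAI instance underneath.
    wrapped = ChatOpenAI().bind(stop=["\n\n"])
    chain = LLMChain(llm=wrapped, prompt=prompt)
    n_tokens = chain._get_num_tokens("some prompt text")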
libs/langchain/langchain/chains/conversational_retrieval/base.py
@@ -284,7 +284,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
             self.combine_docs_chain, StuffDocumentsChain
         ):
             tokens = [
-                self.combine_docs_chain.llm_chain.llm.get_num_tokens(doc.page_content)
+                self.combine_docs_chain.llm_chain._get_num_tokens(doc.page_content)
                 for doc in docs
             ]
             token_count = sum(tokens[:num_docs])
libs/langchain/langchain/chains/llm.py
@@ -2,7 +2,7 @@
 from __future__ import annotations

 import warnings
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast

 from langchain.callbacks.manager import (
     AsyncCallbackManager,
@@ -17,12 +17,25 @@ from langchain.prompts.prompt import PromptTemplate
 from langchain.pydantic_v1 import Extra, Field
 from langchain.schema import (
     BaseLLMOutputParser,
+    BaseMessage,
     BasePromptTemplate,
+    ChatGeneration,
+    Generation,
     LLMResult,
     PromptValue,
     StrOutputParser,
 )
-from langchain.schema.language_model import BaseLanguageModel
+from langchain.schema.language_model import (
+    BaseLanguageModel,
+    LanguageModelInput,
+)
+from langchain.schema.runnable import (
+    Runnable,
+    RunnableBinding,
+    RunnableBranch,
+    RunnableWithFallbacks,
+)
+from langchain.schema.runnable.configurable import DynamicRunnable
 from langchain.utils.input import get_colored_text
@@ -48,7 +61,9 @@ class LLMChain(Chain):

     prompt: BasePromptTemplate
     """Prompt object to use."""
-    llm: BaseLanguageModel
+    llm: Union[
+        Runnable[LanguageModelInput, str], Runnable[LanguageModelInput, BaseMessage]
+    ]
     """Language model to call."""
     output_key: str = "text"  #: :meta private:
     output_parser: BaseLLMOutputParser = Field(default_factory=StrOutputParser)
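
Read as a contract: anything satisfying Runnable[LanguageModelInput, str] (completion-style) or Runnable[LanguageModelInput, BaseMessage] (chat-style) now fits the field. A toy illustration of the string-producing case, not from the diff and only meant to show what the annotation admits:

    from langchain.schema.runnable import RunnableLambda

    # Completion-style: formatted prompt in, plain string out. This satisfies
    # Runnable[LanguageModelInput, str], so the new annotation accepts it.
    fake_llm = RunnableLambda(lambda prompt_value: "stubbed completion")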
@@ -100,12 +115,25 @@ class LLMChain(Chain):
     ) -> LLMResult:
         """Generate LLM result from inputs."""
         prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
-        return self.llm.generate_prompt(
-            prompts,
-            stop,
-            callbacks=run_manager.get_child() if run_manager else None,
-            **self.llm_kwargs,
-        )
+        callbacks = run_manager.get_child() if run_manager else None
+        if isinstance(self.llm, BaseLanguageModel):
+            return self.llm.generate_prompt(
+                prompts,
+                stop,
+                callbacks=callbacks,
+                **self.llm_kwargs,
+            )
+        else:
+            results = self.llm.bind(stop=stop, **self.llm_kwargs).batch(
+                cast(List, prompts), {"callbacks": callbacks}
+            )
+            generations: List[List[Generation]] = []
+            for res in results:
+                if isinstance(res, BaseMessage):
+                    generations.append([ChatGeneration(message=res)])
+                else:
+                    generations.append([Generation(text=res)])
+            return LLMResult(generations=generations)

     async def agenerate(
         self,
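
generate() therefore forks: BaseLanguageModel instances keep the classic generate_prompt() path, while any other runnable is bound with the stop tokens and batched, and its raw str/BaseMessage outputs are wrapped into Generation/ChatGeneration rows so callers still receive an LLMResult. A hypothetical call exercising the new branch, with chain built as in the first sketch:

    # llm is a RunnableWithFallbacks, so the isinstance check fails and the
    # bind(...).batch(...) branch runs; chat output arrives as ChatGeneration.
    result = chain.generate([{"text": "first doc"}, {"text": "second doc"}])
    first_generation = result.generations[0][0]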
@@ -114,12 +142,25 @@ class LLMChain(Chain):
     ) -> LLMResult:
         """Generate LLM result from inputs."""
         prompts, stop = await self.aprep_prompts(input_list, run_manager=run_manager)
-        return await self.llm.agenerate_prompt(
-            prompts,
-            stop,
-            callbacks=run_manager.get_child() if run_manager else None,
-            **self.llm_kwargs,
-        )
+        callbacks = run_manager.get_child() if run_manager else None
+        if isinstance(self.llm, BaseLanguageModel):
+            return await self.llm.agenerate_prompt(
+                prompts,
+                stop,
+                callbacks=callbacks,
+                **self.llm_kwargs,
+            )
+        else:
+            results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch(
+                cast(List, prompts), {"callbacks": callbacks}
+            )
+            generations: List[List[Generation]] = []
+            for res in results:
+                if isinstance(res, BaseMessage):
+                    generations.append([ChatGeneration(message=res)])
+                else:
+                    generations.append([Generation(text=res)])
+            return LLMResult(generations=generations)

     def prep_prompts(
         self,
@@ -343,3 +384,22 @@ class LLMChain(Chain):
         """Create LLMChain from LLM and template."""
         prompt_template = PromptTemplate.from_template(template)
         return cls(llm=llm, prompt=prompt_template)
+
+    def _get_num_tokens(self, text: str) -> int:
+        return _get_language_model(self.llm).get_num_tokens(text)
+
+
+def _get_language_model(llm_like: Runnable) -> BaseLanguageModel:
+    if isinstance(llm_like, BaseLanguageModel):
+        return llm_like
+    elif isinstance(llm_like, RunnableBinding):
+        return _get_language_model(llm_like.bound)
+    elif isinstance(llm_like, RunnableWithFallbacks):
+        return _get_language_model(llm_like.runnable)
+    elif isinstance(llm_like, (RunnableBranch, DynamicRunnable)):
+        return _get_language_model(llm_like.default)
+    else:
+        raise ValueError(
+            f"Unable to extract BaseLanguageModel from llm_like object of type "
+            f"{type(llm_like)}"
+        )
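
_get_language_model recursively peels wrapper layers (bindings, fallbacks, branches, configurables) until it reaches a concrete BaseLanguageModel, which is what makes chain-level token counting possible. An illustrative unwrap, assuming the private helper is imported from langchain.chains.llm:

    from langchain.chains.llm import _get_language_model

    llm = ChatOpenAI().bind(stop=["\n"]).with_fallbacks([ChatAnthropic()])
    # Unwrap order: RunnableWithFallbacks.runnable -> RunnableBinding.bound
    # -> ChatOpenAI, whose get_num_tokens() is then usable as before.
    base = _get_language_model(llm)
    token_count = base.get_num_tokens("how many tokens is this?")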
libs/langchain/langchain/chains/qa_with_sources/retrieval.py
@@ -31,9 +31,7 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
             self.combine_documents_chain, StuffDocumentsChain
         ):
             tokens = [
-                self.combine_documents_chain.llm_chain.llm.get_num_tokens(
-                    doc.page_content
-                )
+                self.combine_documents_chain.llm_chain._get_num_tokens(doc.page_content)
                 for doc in docs
             ]
             token_count = sum(tokens[:num_docs])
libs/langchain/langchain/chains/qa_with_sources/vector_db.py
@@ -36,9 +36,7 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
             self.combine_documents_chain, StuffDocumentsChain
         ):
             tokens = [
-                self.combine_documents_chain.llm_chain.llm.get_num_tokens(
-                    doc.page_content
-                )
+                self.combine_documents_chain.llm_chain._get_num_tokens(doc.page_content)
                 for doc in docs
            ]
             token_count = sum(tokens[:num_docs])