Type LLMChain.llm as runnable (#12385)

Bagatur 2023-10-27 11:52:01 -07:00 committed by GitHub
parent 224ec0cfd3
commit a8c68d4ffa
5 changed files with 79 additions and 23 deletions

View File

@@ -150,7 +150,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
         """
         inputs = self._get_inputs(docs, **kwargs)
         prompt = self.llm_chain.prompt.format(**inputs)
-        return self.llm_chain.llm.get_num_tokens(prompt)
+        return self.llm_chain._get_num_tokens(prompt)
 
     def combine_docs(
         self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any

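Why this call site (and the three similar ones in the chains below) changes: after this commit `llm_chain.llm` is typed as a Runnable, so it is no longer guaranteed to be a `BaseLanguageModel` exposing `get_num_tokens`. The chains therefore ask the `LLMChain` itself, which resolves the underlying model first via the `_get_num_tokens` / `_get_language_model` helpers added in the llm.py hunks further down. A minimal hedged sketch of the new call pattern; the chat model and prompt are illustrative assumptions, not part of the commit:

```python
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI  # illustrative model choice
from langchain.prompts import PromptTemplate

llm_chain = LLMChain(
    # After this commit, llm may be any Runnable, e.g. a binding, not just a model.
    llm=ChatOpenAI().bind(stop=["Observation:"]),
    prompt=PromptTemplate.from_template("{page_content}"),
)

# Call sites no longer reach into llm_chain.llm; the chain unwraps the binding
# internally before counting tokens.
token_count = llm_chain._get_num_tokens("some document text")
```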
View File

@@ -284,7 +284,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
             self.combine_docs_chain, StuffDocumentsChain
         ):
             tokens = [
-                self.combine_docs_chain.llm_chain.llm.get_num_tokens(doc.page_content)
+                self.combine_docs_chain.llm_chain._get_num_tokens(doc.page_content)
                 for doc in docs
             ]
             token_count = sum(tokens[:num_docs])

View File

@@ -2,7 +2,7 @@
 from __future__ import annotations
 
 import warnings
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
 
 from langchain.callbacks.manager import (
     AsyncCallbackManager,
@@ -17,12 +17,25 @@ from langchain.prompts.prompt import PromptTemplate
 from langchain.pydantic_v1 import Extra, Field
 from langchain.schema import (
     BaseLLMOutputParser,
+    BaseMessage,
     BasePromptTemplate,
+    ChatGeneration,
+    Generation,
     LLMResult,
     PromptValue,
     StrOutputParser,
 )
-from langchain.schema.language_model import BaseLanguageModel
+from langchain.schema.language_model import (
+    BaseLanguageModel,
+    LanguageModelInput,
+)
+from langchain.schema.runnable import (
+    Runnable,
+    RunnableBinding,
+    RunnableBranch,
+    RunnableWithFallbacks,
+)
+from langchain.schema.runnable.configurable import DynamicRunnable
 from langchain.utils.input import get_colored_text
@@ -48,7 +61,9 @@ class LLMChain(Chain):
 
     prompt: BasePromptTemplate
     """Prompt object to use."""
-    llm: BaseLanguageModel
+    llm: Union[
+        Runnable[LanguageModelInput, str], Runnable[LanguageModelInput, BaseMessage]
+    ]
     """Language model to call."""
     output_key: str = "text"  #: :meta private:
     output_parser: BaseLLMOutputParser = Field(default_factory=StrOutputParser)
@@ -100,12 +115,25 @@ class LLMChain(Chain):
     ) -> LLMResult:
         """Generate LLM result from inputs."""
         prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
-        return self.llm.generate_prompt(
-            prompts,
-            stop,
-            callbacks=run_manager.get_child() if run_manager else None,
-            **self.llm_kwargs,
-        )
+        callbacks = run_manager.get_child() if run_manager else None
+        if isinstance(self.llm, BaseLanguageModel):
+            return self.llm.generate_prompt(
+                prompts,
+                stop,
+                callbacks=callbacks,
+                **self.llm_kwargs,
+            )
+        else:
+            results = self.llm.bind(stop=stop, **self.llm_kwargs).batch(
+                cast(List, prompts), {"callbacks": callbacks}
+            )
+            generations: List[List[Generation]] = []
+            for res in results:
+                if isinstance(res, BaseMessage):
+                    generations.append([ChatGeneration(message=res)])
+                else:
+                    generations.append([Generation(text=res)])
+            return LLMResult(generations=generations)
 
     async def agenerate(
         self,
@@ -114,12 +142,25 @@ class LLMChain(Chain):
     ) -> LLMResult:
         """Generate LLM result from inputs."""
         prompts, stop = await self.aprep_prompts(input_list, run_manager=run_manager)
-        return await self.llm.agenerate_prompt(
-            prompts,
-            stop,
-            callbacks=run_manager.get_child() if run_manager else None,
-            **self.llm_kwargs,
-        )
+        callbacks = run_manager.get_child() if run_manager else None
+        if isinstance(self.llm, BaseLanguageModel):
+            return await self.llm.agenerate_prompt(
+                prompts,
+                stop,
+                callbacks=callbacks,
+                **self.llm_kwargs,
+            )
+        else:
+            results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch(
+                cast(List, prompts), {"callbacks": callbacks}
+            )
+            generations: List[List[Generation]] = []
+            for res in results:
+                if isinstance(res, BaseMessage):
+                    generations.append([ChatGeneration(message=res)])
+                else:
+                    generations.append([Generation(text=res)])
+            return LLMResult(generations=generations)
 
     def prep_prompts(
         self,
@@ -343,3 +384,22 @@ class LLMChain(Chain):
         """Create LLMChain from LLM and template."""
         prompt_template = PromptTemplate.from_template(template)
         return cls(llm=llm, prompt=prompt_template)
+
+    def _get_num_tokens(self, text: str) -> int:
+        return _get_language_model(self.llm).get_num_tokens(text)
+
+
+def _get_language_model(llm_like: Runnable) -> BaseLanguageModel:
+    if isinstance(llm_like, BaseLanguageModel):
+        return llm_like
+    elif isinstance(llm_like, RunnableBinding):
+        return _get_language_model(llm_like.bound)
+    elif isinstance(llm_like, RunnableWithFallbacks):
+        return _get_language_model(llm_like.runnable)
+    elif isinstance(llm_like, (RunnableBranch, DynamicRunnable)):
+        return _get_language_model(llm_like.default)
+    else:
+        raise ValueError(
+            f"Unable to extract BaseLanguageModel from llm_like object of type "
+            f"{type(llm_like)}"
+        )

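A minimal usage sketch (not part of the diff) of what the widened `llm` type enables: `LLMChain` can now wrap an arbitrary Runnable, such as a chat model with fallbacks or bound kwargs, instead of only a `BaseLanguageModel`. It assumes this era's `langchain` with the `openai` extra and an `OPENAI_API_KEY`; the model names are illustrative:

```python
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate

prompt = PromptTemplate.from_template("Tell me a joke about {topic}")

# A RunnableWithFallbacks is not a BaseLanguageModel, so generate() takes the
# new batch() branch above and wraps each returned BaseMessage in a ChatGeneration.
llm = ChatOpenAI(model="gpt-3.5-turbo").with_fallbacks([ChatOpenAI(model="gpt-4")])
chain = LLMChain(llm=llm, prompt=prompt)

result = chain.generate([{"topic": "parrots"}, {"topic": "llamas"}])
print(result.generations[0][0].text)

# Token counting still works: _get_language_model() unwraps the fallbacks
# wrapper to reach the underlying chat model.
print(chain._get_num_tokens("Tell me a joke about parrots"))
```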
View File

@@ -31,9 +31,7 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
             self.combine_documents_chain, StuffDocumentsChain
         ):
             tokens = [
-                self.combine_documents_chain.llm_chain.llm.get_num_tokens(
-                    doc.page_content
-                )
+                self.combine_documents_chain.llm_chain._get_num_tokens(doc.page_content)
                 for doc in docs
             ]
             token_count = sum(tokens[:num_docs])

View File

@@ -36,9 +36,7 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
             self.combine_documents_chain, StuffDocumentsChain
         ):
             tokens = [
-                self.combine_documents_chain.llm_chain.llm.get_num_tokens(
-                    doc.page_content
-                )
+                self.combine_documents_chain.llm_chain._get_num_tokens(doc.page_content)
                 for doc in docs
            ]
             token_count = sum(tokens[:num_docs])
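
Finally, a hedged sketch of how the new `_get_language_model` helper unwraps the wrapper types it handles; importing the private helper and the model choice are illustrative assumptions, not part of the commit:

```python
from langchain.chains.llm import _get_language_model
from langchain.chat_models import ChatOpenAI

model = ChatOpenAI()

# Nested wrappers are unwrapped recursively: RunnableWithFallbacks -> .runnable,
# RunnableBinding -> .bound (RunnableBranch and DynamicRunnable -> .default).
wrapped = model.bind(stop=["\n"]).with_fallbacks([ChatOpenAI(model="gpt-4")])
assert _get_language_model(wrapped) is model

# Anything that does not bottom out in a BaseLanguageModel raises ValueError.
```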