From a8c68d4ffa8564d9e8e67e53c84e8b3582a8084c Mon Sep 17 00:00:00 2001
From: Bagatur <22008038+baskaryan@users.noreply.github.com>
Date: Fri, 27 Oct 2023 11:52:01 -0700
Subject: [PATCH] Type LLMChain.llm as runnable (#12385)

---
 .../chains/combine_documents/stuff.py         |  2 +-
 .../chains/conversational_retrieval/base.py   |  2 +-
 libs/langchain/langchain/chains/llm.py        | 90 +++++++++++++++----
 .../chains/qa_with_sources/retrieval.py       |  4 +-
 .../chains/qa_with_sources/vector_db.py       |  4 +-
 5 files changed, 79 insertions(+), 23 deletions(-)

diff --git a/libs/langchain/langchain/chains/combine_documents/stuff.py b/libs/langchain/langchain/chains/combine_documents/stuff.py
index e5b73a17f2b..063efe6244d 100644
--- a/libs/langchain/langchain/chains/combine_documents/stuff.py
+++ b/libs/langchain/langchain/chains/combine_documents/stuff.py
@@ -150,7 +150,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
         """
         inputs = self._get_inputs(docs, **kwargs)
         prompt = self.llm_chain.prompt.format(**inputs)
-        return self.llm_chain.llm.get_num_tokens(prompt)
+        return self.llm_chain._get_num_tokens(prompt)
 
     def combine_docs(
         self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
diff --git a/libs/langchain/langchain/chains/conversational_retrieval/base.py b/libs/langchain/langchain/chains/conversational_retrieval/base.py
index 5347b29a610..6812fb3907e 100644
--- a/libs/langchain/langchain/chains/conversational_retrieval/base.py
+++ b/libs/langchain/langchain/chains/conversational_retrieval/base.py
@@ -284,7 +284,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
             self.combine_docs_chain, StuffDocumentsChain
         ):
             tokens = [
-                self.combine_docs_chain.llm_chain.llm.get_num_tokens(doc.page_content)
+                self.combine_docs_chain.llm_chain._get_num_tokens(doc.page_content)
                 for doc in docs
             ]
             token_count = sum(tokens[:num_docs])
diff --git a/libs/langchain/langchain/chains/llm.py b/libs/langchain/langchain/chains/llm.py
index 1f251a677ba..33555f4d86a 100644
--- a/libs/langchain/langchain/chains/llm.py
+++ b/libs/langchain/langchain/chains/llm.py
@@ -2,7 +2,7 @@
 from __future__ import annotations
 
 import warnings
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
 
 from langchain.callbacks.manager import (
     AsyncCallbackManager,
@@ -17,12 +17,25 @@ from langchain.prompts.prompt import PromptTemplate
 from langchain.pydantic_v1 import Extra, Field
 from langchain.schema import (
     BaseLLMOutputParser,
+    BaseMessage,
     BasePromptTemplate,
+    ChatGeneration,
+    Generation,
     LLMResult,
     PromptValue,
     StrOutputParser,
 )
-from langchain.schema.language_model import BaseLanguageModel
+from langchain.schema.language_model import (
+    BaseLanguageModel,
+    LanguageModelInput,
+)
+from langchain.schema.runnable import (
+    Runnable,
+    RunnableBinding,
+    RunnableBranch,
+    RunnableWithFallbacks,
+)
+from langchain.schema.runnable.configurable import DynamicRunnable
 from langchain.utils.input import get_colored_text
 
 
@@ -48,7 +61,9 @@ class LLMChain(Chain):
 
     prompt: BasePromptTemplate
     """Prompt object to use."""
-    llm: BaseLanguageModel
+    llm: Union[
+        Runnable[LanguageModelInput, str], Runnable[LanguageModelInput, BaseMessage]
+    ]
     """Language model to call."""
     output_key: str = "text"  #: :meta private:
     output_parser: BaseLLMOutputParser = Field(default_factory=StrOutputParser)
@@ -100,12 +115,25 @@ class LLMChain(Chain):
     ) -> LLMResult:
         """Generate LLM result from inputs."""
         prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
-        return self.llm.generate_prompt(
-            prompts,
-            stop,
-            callbacks=run_manager.get_child() if run_manager else None,
-            **self.llm_kwargs,
-        )
+        callbacks = run_manager.get_child() if run_manager else None
+        if isinstance(self.llm, BaseLanguageModel):
+            return self.llm.generate_prompt(
+                prompts,
+                stop,
+                callbacks=callbacks,
+                **self.llm_kwargs,
+            )
+        else:
+            results = self.llm.bind(stop=stop, **self.llm_kwargs).batch(
+                cast(List, prompts), {"callbacks": callbacks}
+            )
+            generations: List[List[Generation]] = []
+            for res in results:
+                if isinstance(res, BaseMessage):
+                    generations.append([ChatGeneration(message=res)])
+                else:
+                    generations.append([Generation(text=res)])
+            return LLMResult(generations=generations)
 
     async def agenerate(
         self,
@@ -114,12 +142,25 @@ class LLMChain(Chain):
     ) -> LLMResult:
         """Generate LLM result from inputs."""
         prompts, stop = await self.aprep_prompts(input_list, run_manager=run_manager)
-        return await self.llm.agenerate_prompt(
-            prompts,
-            stop,
-            callbacks=run_manager.get_child() if run_manager else None,
-            **self.llm_kwargs,
-        )
+        callbacks = run_manager.get_child() if run_manager else None
+        if isinstance(self.llm, BaseLanguageModel):
+            return await self.llm.agenerate_prompt(
+                prompts,
+                stop,
+                callbacks=callbacks,
+                **self.llm_kwargs,
+            )
+        else:
+            results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch(
+                cast(List, prompts), {"callbacks": callbacks}
+            )
+            generations: List[List[Generation]] = []
+            for res in results:
+                if isinstance(res, BaseMessage):
+                    generations.append([ChatGeneration(message=res)])
+                else:
+                    generations.append([Generation(text=res)])
+            return LLMResult(generations=generations)
 
     def prep_prompts(
         self,
@@ -343,3 +384,22 @@ class LLMChain(Chain):
         """Create LLMChain from LLM and template."""
         prompt_template = PromptTemplate.from_template(template)
         return cls(llm=llm, prompt=prompt_template)
+
+    def _get_num_tokens(self, text: str) -> int:
+        return _get_language_model(self.llm).get_num_tokens(text)
+
+
+def _get_language_model(llm_like: Runnable) -> BaseLanguageModel:
+    if isinstance(llm_like, BaseLanguageModel):
+        return llm_like
+    elif isinstance(llm_like, RunnableBinding):
+        return _get_language_model(llm_like.bound)
+    elif isinstance(llm_like, RunnableWithFallbacks):
+        return _get_language_model(llm_like.runnable)
+    elif isinstance(llm_like, (RunnableBranch, DynamicRunnable)):
+        return _get_language_model(llm_like.default)
+    else:
+        raise ValueError(
+            f"Unable to extract BaseLanguageModel from llm_like object of type "
+            f"{type(llm_like)}"
+        )
diff --git a/libs/langchain/langchain/chains/qa_with_sources/retrieval.py b/libs/langchain/langchain/chains/qa_with_sources/retrieval.py
index 80018950d96..d47c43e51cf 100644
--- a/libs/langchain/langchain/chains/qa_with_sources/retrieval.py
+++ b/libs/langchain/langchain/chains/qa_with_sources/retrieval.py
@@ -31,9 +31,7 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
             self.combine_documents_chain, StuffDocumentsChain
         ):
             tokens = [
-                self.combine_documents_chain.llm_chain.llm.get_num_tokens(
-                    doc.page_content
-                )
+                self.combine_documents_chain.llm_chain._get_num_tokens(doc.page_content)
                 for doc in docs
             ]
             token_count = sum(tokens[:num_docs])
diff --git a/libs/langchain/langchain/chains/qa_with_sources/vector_db.py b/libs/langchain/langchain/chains/qa_with_sources/vector_db.py
index 44659d91703..8bb432ce88f 100644
--- a/libs/langchain/langchain/chains/qa_with_sources/vector_db.py
+++ b/libs/langchain/langchain/chains/qa_with_sources/vector_db.py
@@ -36,9 +36,7 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
             self.combine_documents_chain, StuffDocumentsChain
         ):
             tokens = [
-                self.combine_documents_chain.llm_chain.llm.get_num_tokens(
-                    doc.page_content
-                )
+                self.combine_documents_chain.llm_chain._get_num_tokens(doc.page_content)
                 for doc in docs
            ]
             token_count = sum(tokens[:num_docs])
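
Usage sketch (not part of the patch): the typing change above lets LLMChain.llm
be any Runnable that maps a prompt to a string or a BaseMessage, not only a
BaseLanguageModel. A minimal illustration, assuming langchain at the commit
above and a configured OpenAI API key; the model names are placeholders:

    from langchain.chains import LLMChain
    from langchain.chat_models import ChatOpenAI
    from langchain.prompts import PromptTemplate

    prompt = PromptTemplate.from_template("Tell me a joke about {topic}")

    # A plain model still takes the isinstance(self.llm, BaseLanguageModel)
    # branch and goes through generate_prompt exactly as before.
    chain = LLMChain(llm=ChatOpenAI(), prompt=prompt)

    # A wrapped runnable now type-checks too. The new else-branch binds stop
    # and llm_kwargs, calls .batch() on the prompts, and wraps each result in
    # a ChatGeneration (BaseMessage output) or a Generation (string output).
    resilient = ChatOpenAI(model="gpt-4").with_fallbacks([ChatOpenAI()])
    chain = LLMChain(llm=resilient, prompt=prompt)
    print(chain.run(topic="runnables"))

    # Token counting unwraps bindings, fallbacks, and branches back to the
    # underlying model via the new _get_language_model helper, which is why
    # the callers in stuff.py and qa_with_sources now go through
    # LLMChain._get_num_tokens instead of touching .llm directly.
    print(chain._get_num_tokens("Tell me a joke about runnables"))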