From 0c8792d6fd41adada9549f15191f5d005a25847f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E0=A4=85=E0=A4=82kur=20=E0=A4=97=E0=A5=8Bswami?=
Date: Sun, 8 Jun 2025 23:24:08 +0530
Subject: [PATCH 1/3] batching for huggingface pipelines

---
 .../chat_models/huggingface.py | 37 +++++++++++++++++--
 1 file changed, 33 insertions(+), 4 deletions(-)

diff --git a/libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py b/libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py
index cceed0e1e5a..8fead435ab0 100644
--- a/libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py
+++ b/libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py
@@ -48,7 +48,7 @@ from langchain_core.outputs import (
     ChatResult,
     LLMResult,
 )
-from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
+from langchain_core.runnables import Runnable, RunnableConfig, RunnableMap, RunnablePassthrough
 from langchain_core.tools import BaseTool
 from langchain_core.utils.function_calling import (
     convert_to_json_schema,
@@ -715,6 +715,36 @@ class ChatHuggingFace(BaseChatModel):
             )
             yield generation_chunk
 
+    def batch(
+        self,
+        inputs: list[LanguageModelInput],
+        config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
+        *,
+        return_exceptions: bool = False,
+        **kwargs: Optional[Any],
+    ) -> list[BaseMessage]:
+        if not inputs:
+            return []
+
+        if _is_huggingface_pipeline(self.llm):
+            prompt_messages = [
+                self._convert_input(input).to_messages() for input in inputs]
+
+            llm_inputs = list(map(self._to_chat_prompt, prompt_messages))
+
+            llm_results = self.llm._generate(
+                prompts=llm_inputs
+            )
+            chat_result = self._to_chat_result(llm_results)
+            return [gen.message for gen in chat_result.generations]
+        else:
+            return super().batch(
+                inputs=inputs,
+                config=config,
+                return_exceptions=return_exceptions,
+                **kwargs,
+            )
+
     def _to_chat_prompt(
         self,
         messages: list[BaseMessage],
@@ -749,10 +779,9 @@ class ChatHuggingFace(BaseChatModel):
     @staticmethod
     def _to_chat_result(llm_result: LLMResult) -> ChatResult:
         chat_generations = []
-
-        for g in llm_result.generations[0]:
+        for g in llm_result.generations:
             chat_generation = ChatGeneration(
-                message=AIMessage(content=g.text), generation_info=g.generation_info
+                message=AIMessage(content=g[0].text), generation_info=g[0].generation_info
             )
             chat_generations.append(chat_generation)
 

From 96d648306f642cc395879d0d7a9c94092f6c11cb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E0=A4=85=E0=A4=82kur=20=E0=A4=97=E0=A5=8Bswami?=
Date: Mon, 9 Jun 2025 00:01:44 +0530
Subject: [PATCH 2/3] code formatting

---
 .../chat_models/huggingface.py | 23 +++++++++++--------
 1 file changed, 14 insertions(+), 9 deletions(-)

diff --git a/libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py b/libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py
index 8fead435ab0..b176a1e2e14 100644
--- a/libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py
+++ b/libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py
@@ -48,7 +48,12 @@ from langchain_core.outputs import (
     ChatResult,
     LLMResult,
 )
-from langchain_core.runnables import Runnable, RunnableConfig, RunnableMap, RunnablePassthrough
+from langchain_core.runnables import (
+    Runnable,
+    RunnableConfig,
+    RunnableMap,
+    RunnablePassthrough,
+)
 from langchain_core.tools import BaseTool
 from langchain_core.utils.function_calling import (
     convert_to_json_schema,
@@ -725,16 +730,15 @@ class ChatHuggingFace(BaseChatModel):
     ) -> list[BaseMessage]:
         if not inputs:
             return []
-
+
         if _is_huggingface_pipeline(self.llm):
             prompt_messages = [
-                self._convert_input(input).to_messages() for input in inputs]
+                self._convert_input(input).to_messages() for input in inputs
+            ]
 
             llm_inputs = list(map(self._to_chat_prompt, prompt_messages))
-
-            llm_results = self.llm._generate(
-                prompts=llm_inputs
-            )
+
+            llm_results = self.llm._generate(prompts=llm_inputs)
             chat_result = self._to_chat_result(llm_results)
             return [gen.message for gen in chat_result.generations]
         else:
@@ -744,7 +748,7 @@ class ChatHuggingFace(BaseChatModel):
                 return_exceptions=return_exceptions,
                 **kwargs,
             )
-
+
     def _to_chat_prompt(
         self,
         messages: list[BaseMessage],
@@ -781,7 +785,8 @@ class ChatHuggingFace(BaseChatModel):
         chat_generations = []
         for g in llm_result.generations:
             chat_generation = ChatGeneration(
-                message=AIMessage(content=g[0].text), generation_info=g[0].generation_info
+                message=AIMessage(content=g[0].text),
+                generation_info=g[0].generation_info,
             )
             chat_generations.append(chat_generation)
 

From 52129ab40c75cb0b7e6e73ac9401198e08fa2060 Mon Sep 17 00:00:00 2001
From: Mason Daugherty
Date: Wed, 16 Jul 2025 11:47:00 -0400
Subject: [PATCH 3/3] lint

---
 .../chat_models/huggingface.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py b/libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py
index 96bd6c9a3b4..483b1a4d975 100644
--- a/libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py
+++ b/libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py
@@ -732,7 +732,7 @@ class ChatHuggingFace(BaseChatModel):
 
         if _is_huggingface_pipeline(self.llm):
             prompt_messages = [
-                self._convert_input(input).to_messages() for input in inputs
+                self._convert_input(input_).to_messages() for input_ in inputs
             ]
 
             llm_inputs = list(map(self._to_chat_prompt, prompt_messages))
@@ -740,13 +740,12 @@ class ChatHuggingFace(BaseChatModel):
             llm_results = self.llm._generate(prompts=llm_inputs)
             chat_result = self._to_chat_result(llm_results)
             return [gen.message for gen in chat_result.generations]
-        else:
-            return super().batch(
-                inputs=inputs,
-                config=config,
-                return_exceptions=return_exceptions,
-                **kwargs,
-            )
+        return super().batch(
+            inputs=inputs,
+            config=config,
+            return_exceptions=return_exceptions,
+            **kwargs,
+        )
 
     def _to_chat_prompt(
         self,
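
A minimal usage sketch of the batched code path this series adds, assuming a
local HuggingFacePipeline backend. The model id and generation parameters are
illustrative placeholders, not taken from the patches:

    from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline

    # Build a local text-generation pipeline. Any model whose tokenizer has a
    # chat template should work; this particular id is only an example.
    llm = HuggingFacePipeline.from_model_id(
        model_id="microsoft/Phi-3-mini-4k-instruct",  # illustrative model id
        task="text-generation",
        pipeline_kwargs={"max_new_tokens": 64},
    )
    chat = ChatHuggingFace(llm=llm)

    # Because self.llm is a HuggingFacePipeline, batch() now routes all prompts
    # through a single llm._generate call instead of one call per input.
    results = chat.batch(["What is the capital of France?", "What is 2 + 2?"])
    for message in results:
        print(message.content)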