This commit is contained in:
अंkur गोswami 2025-07-28 16:41:39 -04:00 committed by GitHub
commit 3ddfd1f79d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -51,7 +51,12 @@ from langchain_core.outputs import (
ChatResult,
LLMResult,
)
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.runnables import (
Runnable,
RunnableConfig,
RunnableMap,
RunnablePassthrough,
)
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import (
convert_to_json_schema,
@ -714,6 +719,34 @@ class ChatHuggingFace(BaseChatModel):
)
yield generation_chunk
def batch(
self,
inputs: list[LanguageModelInput],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> list[BaseMessage]:
if not inputs:
return []
if _is_huggingface_pipeline(self.llm):
prompt_messages = [
self._convert_input(input_).to_messages() for input_ in inputs
]
llm_inputs = list(map(self._to_chat_prompt, prompt_messages))
llm_results = self.llm._generate(prompts=llm_inputs)
chat_result = self._to_chat_result(llm_results)
return [gen.message for gen in chat_result.generations]
return super().batch(
inputs=inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
)
def _to_chat_prompt(
self,
messages: list[BaseMessage],
@ -750,10 +783,10 @@ class ChatHuggingFace(BaseChatModel):
@staticmethod
def _to_chat_result(llm_result: LLMResult) -> ChatResult:
chat_generations = []
for g in llm_result.generations[0]:
for g in llm_result.generations:
chat_generation = ChatGeneration(
message=AIMessage(content=g.text), generation_info=g.generation_info
message=AIMessage(content=g[0].text),
generation_info=g[0].generation_info,
)
chat_generations.append(chat_generation)