diff --git a/libs/langchain/langchain/chat_models/base.py b/libs/langchain/langchain/chat_models/base.py
index f4f4850a0ac..05097e0d072 100644
--- a/libs/langchain/langchain/chat_models/base.py
+++ b/libs/langchain/langchain/chat_models/base.py
@@ -419,7 +419,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         llm_kwargs = prompts[0].llm_kwargs
         for prompt in prompts:
             if prompt.llm_kwargs != llm_kwargs:
-                raise ValueError("All prompt kwargs must be the same when calling in batch")
+                raise ValueError(
+                    "All prompt kwargs must be the same when calling in batch"
+                )
         kwargs = {**llm_kwargs, **kwargs}
         prompt_messages = [p.to_messages() for p in prompts]
         return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)
@@ -434,7 +436,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         llm_kwargs = prompts[0].llm_kwargs
         for prompt in prompts:
             if prompt.llm_kwargs != llm_kwargs:
-                raise ValueError("All prompt kwargs must be the same when calling in batch")
+                raise ValueError(
+                    "All prompt kwargs must be the same when calling in batch"
+                )
         kwargs = {**llm_kwargs, **kwargs}
         prompt_messages = [p.to_messages() for p in prompts]
         return await self.agenerate(
diff --git a/libs/langchain/langchain/llms/base.py b/libs/langchain/langchain/llms/base.py
index 95c2b8d3607..7654823c9b2 100644
--- a/libs/langchain/langchain/llms/base.py
+++ b/libs/langchain/langchain/llms/base.py
@@ -470,7 +470,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         llm_kwargs = prompts[0].llm_kwargs
         for prompt in prompts:
             if prompt.llm_kwargs != llm_kwargs:
-                raise ValueError("All prompt kwargs must be the same when calling in batch")
+                raise ValueError(
+                    "All prompt kwargs must be the same when calling in batch"
+                )
         kwargs = {**llm_kwargs, **kwargs}
         prompt_strings = [p.to_string() for p in prompts]
         return self.generate(prompt_strings, stop=stop, callbacks=callbacks, **kwargs)
@@ -485,7 +487,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         llm_kwargs = prompts[0].llm_kwargs
         for prompt in prompts:
             if prompt.llm_kwargs != llm_kwargs:
-                raise ValueError("All prompt kwargs must be the same when calling in batch")
+                raise ValueError(
+                    "All prompt kwargs must be the same when calling in batch"
+                )
         kwargs = {**llm_kwargs, **kwargs}
         prompt_strings = [p.to_string() for p in prompts]
         return await self.agenerate(
diff --git a/libs/langchain/langchain/prompts/chat.py b/libs/langchain/langchain/prompts/chat.py
index ef163bd6377..f75927a1120 100644
--- a/libs/langchain/langchain/prompts/chat.py
+++ b/libs/langchain/langchain/prompts/chat.py
@@ -298,6 +298,7 @@ class ChatPromptValue(PromptValue):
 
 class BaseChatPromptTemplate(BasePromptTemplate, ABC):
     """Base class for chat prompt templates."""
+
     llm_kwargs: dict = Field(default_factory=dict)
 
     @property
@@ -488,9 +489,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
 
     @classmethod
     def from_messages(
-        cls,
-        messages: Sequence[MessageLikeRepresentation],
-        **kwargs: Any
+        cls, messages: Sequence[MessageLikeRepresentation], **kwargs: Any
     ) -> ChatPromptTemplate:
         """Create a chat prompt template from a variety of message formats.
 
diff --git a/libs/langchain/langchain/schema/prompt.py b/libs/langchain/langchain/schema/prompt.py
index 10d5759ed70..340f5933b6e 100644
--- a/libs/langchain/langchain/schema/prompt.py
+++ b/libs/langchain/langchain/schema/prompt.py
@@ -3,9 +3,8 @@ from __future__ import annotations
 
 from abc import ABC, abstractmethod
 from typing import List
 
-from langchain.pydantic_v1 import Field
-
 from langchain.load.serializable import Serializable
+from langchain.pydantic_v1 import Field
 from langchain.schema.messages import BaseMessage
 
@@ -15,6 +14,7 @@ class PromptValue(Serializable, ABC):
     PromptValues can be converted to both LLM (pure text-generation) inputs and
     ChatModel inputs.
     """
+
     llm_kwargs: dict = Field(default_factory=dict)
 
     @property
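For reference, a minimal sketch of how the new `llm_kwargs` field can be exercised end to end. It constructs `ChatPromptValue` objects directly rather than going through a template (the `BaseChatPromptTemplate.llm_kwargs` plumbing is not shown in these hunks), and `ChatOpenAI` with a `temperature` kwarg is an illustrative assumption — whether a given model honors ad-hoc generate kwargs depends on its implementation:

```python
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import ChatPromptValue
from langchain.schema.messages import HumanMessage

# Illustrative model; any BaseChatModel subclass would do.
model = ChatOpenAI()

# Each PromptValue now carries its own llm_kwargs (see schema/prompt.py above).
prompt_a = ChatPromptValue(
    messages=[HumanMessage(content="Tell me a joke")],
    llm_kwargs={"temperature": 0.9},  # "temperature" is an assumed model kwarg
)
prompt_b = ChatPromptValue(
    messages=[HumanMessage(content="Tell me another joke")],
    llm_kwargs={"temperature": 0.9},
)

# generate_prompt merges the shared llm_kwargs into the call's kwargs,
# per `kwargs = {**llm_kwargs, **kwargs}` in the hunks above.
result = model.generate_prompt([prompt_a, prompt_b])

# Batching prompts whose llm_kwargs differ trips the new consistency check:
prompt_c = ChatPromptValue(
    messages=[HumanMessage(content="One more")],
    llm_kwargs={"temperature": 0.1},
)
# model.generate_prompt([prompt_a, prompt_c])
# -> ValueError: All prompt kwargs must be the same when calling in batch
```

Note that the merge order `{**llm_kwargs, **kwargs}` gives the caller's explicit kwargs precedence, so arguments passed directly to `generate_prompt` still override per-prompt defaults.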