From d80c612c923a5aff616e9862fad1568204523933 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Fri, 2 Feb 2024 10:24:02 -0800 Subject: [PATCH] core[patch]: Message content as positional arg (#16921) --- libs/core/langchain_core/messages/base.py | 6 ++++++ libs/core/langchain_core/prompt_values.py | 4 ++-- libs/core/langchain_core/prompts/chat.py | 2 +- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/libs/core/langchain_core/messages/base.py b/libs/core/langchain_core/messages/base.py index b2fd76b6592..bf1c74f490d 100644 --- a/libs/core/langchain_core/messages/base.py +++ b/libs/core/langchain_core/messages/base.py @@ -28,6 +28,12 @@ class BaseMessage(Serializable): class Config: extra = Extra.allow + def __init__( + self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any + ) -> None: + """Pass in content as positional arg.""" + return super().__init__(content=content, **kwargs) + @classmethod def is_lc_serializable(cls) -> bool: """Return whether this class is serializable.""" diff --git a/libs/core/langchain_core/prompt_values.py b/libs/core/langchain_core/prompt_values.py index 4c599f9f6a0..18f62d9f2a5 100644 --- a/libs/core/langchain_core/prompt_values.py +++ b/libs/core/langchain_core/prompt_values.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import List, Literal, Sequence +from typing import List, Literal, Sequence, cast from typing_extensions import TypedDict @@ -105,7 +105,7 @@ class ImagePromptValue(PromptValue): def to_messages(self) -> List[BaseMessage]: """Return prompt as messages.""" - return [HumanMessage(content=[self.image_url])] + return [HumanMessage(content=[cast(dict, self.image_url)])] class ChatPromptValueConcrete(ChatPromptValue): diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index 4fbd0666b16..105daba2dfb 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -446,7 +446,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): content=text, additional_kwargs=self.additional_kwargs ) else: - content = [] + content: List = [] for prompt in self.prompt: inputs = {var: kwargs[var] for var in prompt.input_variables} if isinstance(prompt, StringPromptTemplate):