mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-23 20:23:59 +00:00
stash
This commit is contained in:
@@ -6,7 +6,6 @@ from langchain.schema import (
|
||||
BaseChatMessageHistory,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
_message_to_dict,
|
||||
messages_from_dict,
|
||||
messages_to_dict,
|
||||
)
|
||||
@@ -64,7 +63,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
messages = messages_to_dict(self.messages)
|
||||
_message = _message_to_dict(message)
|
||||
_message = message.dict()
|
||||
messages.append(_message)
|
||||
|
||||
try:
|
||||
|
||||
@@ -7,7 +7,6 @@ from langchain.schema import (
|
||||
BaseChatMessageHistory,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
_message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
|
||||
@@ -82,7 +81,7 @@ class MongoDBChatMessageHistory(BaseChatMessageHistory):
|
||||
self.collection.insert_one(
|
||||
{
|
||||
"SessionId": self.session_id,
|
||||
"History": json.dumps(_message_to_dict(message)),
|
||||
"History": json.dumps(message.dict()),
|
||||
}
|
||||
)
|
||||
except errors.WriteError as err:
|
||||
|
||||
@@ -7,7 +7,6 @@ from langchain.schema import (
|
||||
BaseChatMessageHistory,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
_message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
|
||||
@@ -68,9 +67,7 @@ class PostgresChatMessageHistory(BaseChatMessageHistory):
|
||||
query = sql.SQL("INSERT INTO {} (session_id, message) VALUES (%s, %s);").format(
|
||||
sql.Identifier(self.table_name)
|
||||
)
|
||||
self.cursor.execute(
|
||||
query, (self.session_id, json.dumps(_message_to_dict(message)))
|
||||
)
|
||||
self.cursor.execute(query, (self.session_id, json.dumps(message.dict())))
|
||||
self.connection.commit()
|
||||
|
||||
def clear(self) -> None:
|
||||
|
||||
@@ -7,7 +7,6 @@ from langchain.schema import (
|
||||
BaseChatMessageHistory,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
_message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
|
||||
@@ -60,7 +59,7 @@ class RedisChatMessageHistory(BaseChatMessageHistory):
|
||||
|
||||
def append(self, message: BaseMessage) -> None:
|
||||
"""Append the message to the record in Redis"""
|
||||
self.redis_client.lpush(self.key, json.dumps(_message_to_dict(message)))
|
||||
self.redis_client.lpush(self.key, json.dumps(message.dict()))
|
||||
if self.ttl:
|
||||
self.redis_client.expire(self.key, self.ttl)
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ from langchain.schema import (
|
||||
BaseChatMessageHistory,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
_message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
|
||||
@@ -69,7 +68,7 @@ class SQLChatMessageHistory(BaseChatMessageHistory):
|
||||
def append(self, message: BaseMessage) -> None:
|
||||
"""Append the message to the record in db"""
|
||||
with self.Session() as session:
|
||||
jsonstr = json.dumps(_message_to_dict(message))
|
||||
jsonstr = json.dumps(message.dict())
|
||||
session.add(self.Message(session_id=self.session_id, message=jsonstr))
|
||||
session.commit()
|
||||
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import yaml
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Sequence, Tuple, Type, TypeVar, Union
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain.memory.buffer import get_buffer_string
|
||||
@@ -87,6 +87,16 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
|
||||
def input_variables(self) -> List[str]:
|
||||
return self.prompt.input_variables
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def _type(self) -> str:
|
||||
"""The type of MessagePromptTemplate."""
|
||||
|
||||
def dict(self, *args, **kwargs):
|
||||
result = super().dict(*args, **kwargs)
|
||||
result["_type"] = self._type
|
||||
return result
|
||||
|
||||
|
||||
class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
role: str
|
||||
@@ -97,14 +107,22 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
content=text, role=self.role, additional_kwargs=self.additional_kwargs
|
||||
)
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
"""The type of MessagePromptTemplate."""
|
||||
return "chat-message-prompt-template"
|
||||
|
||||
|
||||
class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
role: str = "human"
|
||||
|
||||
def format(self, **kwargs: Any) -> BaseMessage:
|
||||
text = self.prompt.format(**kwargs)
|
||||
return HumanMessage(content=text, additional_kwargs=self.additional_kwargs)
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
"""The type of MessagePromptTemplate."""
|
||||
return "human-message-prompt-template"
|
||||
|
||||
|
||||
class AIMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
role: str = "ai"
|
||||
@@ -113,6 +131,11 @@ class AIMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
text = self.prompt.format(**kwargs)
|
||||
return AIMessage(content=text, additional_kwargs=self.additional_kwargs)
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
"""The type of MessagePromptTemplate."""
|
||||
return "ai-message-prompt-template"
|
||||
|
||||
|
||||
class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
role: str = "system"
|
||||
@@ -121,6 +144,11 @@ class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
text = self.prompt.format(**kwargs)
|
||||
return SystemMessage(content=text, additional_kwargs=self.additional_kwargs)
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
"""The type of MessagePromptTemplate."""
|
||||
return "system-message-prompt-template"
|
||||
|
||||
|
||||
class ChatPromptValue(PromptValue):
|
||||
messages: List[BaseMessage]
|
||||
@@ -149,15 +177,7 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
|
||||
|
||||
class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
||||
input_variables: List[str]
|
||||
messages: List[
|
||||
Union[
|
||||
BaseMessagePromptTemplate,
|
||||
BaseMessage,
|
||||
AIMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
]
|
||||
]
|
||||
messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
|
||||
|
||||
@classmethod
|
||||
def from_template(cls, template: str, **kwargs: Any) -> ChatPromptTemplate:
|
||||
@@ -224,19 +244,17 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
||||
@property
|
||||
def _prompt_type(self) -> str:
|
||||
"""Return the prompt type key."""
|
||||
return "chatPrompt"
|
||||
return "chat_prompt"
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
prompt_dict = json.loads(self.json())
|
||||
|
||||
prompt_dict["_type"] = self._prompt_type
|
||||
for i, message in enumerate(self.messages):
|
||||
if isinstance(message, SystemMessagePromptTemplate):
|
||||
prompt_dict["messages"][i]["prompt"]["role"] = "system"
|
||||
elif isinstance(message, HumanMessagePromptTemplate):
|
||||
prompt_dict["messages"][i]["prompt"]["role"] = "human"
|
||||
elif isinstance(message, AIMessagePromptTemplate):
|
||||
prompt_dict["messages"][i]["prompt"]["role"] = "ai"
|
||||
prompt_dict = super().dict(**kwargs)
|
||||
# for i, message in enumerate(self.messages):
|
||||
# if isinstance(message, SystemMessagePromptTemplate):
|
||||
# prompt_dict["messages"][i]["prompt"]["role"] = "system"
|
||||
# elif isinstance(message, HumanMessagePromptTemplate):
|
||||
# prompt_dict["messages"][i]["prompt"]["role"] = "human"
|
||||
# elif isinstance(message, AIMessagePromptTemplate):
|
||||
# prompt_dict["messages"][i]["prompt"]["role"] = "ai"
|
||||
|
||||
return prompt_dict
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ from langchain.prompts.chat import (
|
||||
)
|
||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import message_from_dict
|
||||
from langchain.utilities.loading import try_load_from_hub
|
||||
|
||||
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
|
||||
@@ -129,14 +130,21 @@ def _load_chat_prompt(config: dict) -> ChatPromptTemplate:
|
||||
|
||||
messages = []
|
||||
for message in config["messages"]:
|
||||
role = message["prompt"].get("role")
|
||||
template = message["prompt"]["template"]
|
||||
if role == "human":
|
||||
messages.append(HumanMessagePromptTemplate.from_template(template))
|
||||
elif role == "ai":
|
||||
messages.append(AIMessagePromptTemplate.from_template(template))
|
||||
_type = message.pop("_type")
|
||||
if _type == "human-message-prompt-template":
|
||||
prompt = load_prompt_from_config(message.pop("prompt"))
|
||||
_message = HumanMessagePromptTemplate(**{"prompt": prompt, **message})
|
||||
elif _type == "ai-message-prompt-template":
|
||||
prompt = load_prompt_from_config(message.pop("prompt"))
|
||||
_message = AIMessagePromptTemplate(**{"prompt": prompt, **message})
|
||||
elif _type == "system-message-prompt-template":
|
||||
prompt = load_prompt_from_config(message.pop("prompt"))
|
||||
_message = SystemMessagePromptTemplate(**{"prompt": prompt, **message})
|
||||
elif _type == "base-message":
|
||||
_message = message_from_dict(message)
|
||||
else: # role == system
|
||||
messages.append(SystemMessagePromptTemplate.from_template(template))
|
||||
raise ValueError
|
||||
messages.append(_message)
|
||||
|
||||
return ChatPromptTemplate.from_messages(messages)
|
||||
|
||||
@@ -185,7 +193,7 @@ def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate:
|
||||
|
||||
type_to_loader_dict = {
|
||||
"prompt": _load_prompt,
|
||||
"chatPrompt": _load_chat_prompt,
|
||||
"chat_prompt": _load_chat_prompt,
|
||||
"few_shot": _load_few_shot_prompt,
|
||||
# "few_shot_with_templates": _load_few_shot_with_templates_prompt,
|
||||
}
|
||||
|
||||
@@ -75,6 +75,13 @@ class BaseMessage(BaseModel):
|
||||
def type(self) -> str:
|
||||
"""Type of the message, used for serialization."""
|
||||
|
||||
def dict(self, *args, **kwargs):
|
||||
return {
|
||||
"type": self.type,
|
||||
"data": super().dict(*args, **kwargs),
|
||||
"_type": "base-message",
|
||||
}
|
||||
|
||||
|
||||
class HumanMessage(BaseMessage):
|
||||
"""Type of message that is spoken by the human."""
|
||||
@@ -118,15 +125,11 @@ class ChatMessage(BaseMessage):
|
||||
return "chat"
|
||||
|
||||
|
||||
def _message_to_dict(message: BaseMessage) -> dict:
|
||||
return {"type": message.type, "data": message.dict()}
|
||||
|
||||
|
||||
def messages_to_dict(messages: List[BaseMessage]) -> List[dict]:
|
||||
return [_message_to_dict(m) for m in messages]
|
||||
return [m.dict() for m in messages]
|
||||
|
||||
|
||||
def _message_from_dict(message: dict) -> BaseMessage:
|
||||
def message_from_dict(message: dict) -> BaseMessage:
|
||||
_type = message["type"]
|
||||
if _type == "human":
|
||||
return HumanMessage(**message["data"])
|
||||
@@ -141,7 +144,7 @@ def _message_from_dict(message: dict) -> BaseMessage:
|
||||
|
||||
|
||||
def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
|
||||
return [_message_from_dict(m) for m in messages]
|
||||
return [message_from_dict(m) for m in messages]
|
||||
|
||||
|
||||
class ChatGeneration(Generation):
|
||||
|
||||
@@ -5,9 +5,16 @@ from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
from langchain.prompts.chat import (
|
||||
AIMessagePromptTemplate,
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||
from langchain.prompts.loading import load_prompt
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
|
||||
@contextmanager
|
||||
@@ -71,6 +78,23 @@ def test_saving_loading_round_trip(tmp_path: Path) -> None:
|
||||
assert loaded_prompt == few_shot_prompt
|
||||
|
||||
|
||||
def test_saving_chat_loading_round_trip(tmp_path: Path) -> None:
|
||||
"""Test equality when saving and loading a chat prompt."""
|
||||
message_list = [
|
||||
[HumanMessage(content="hi"), HumanMessagePromptTemplate.from_template("{foo}")],
|
||||
[AIMessage(content="hi"), AIMessagePromptTemplate.from_template("{foo}")],
|
||||
[
|
||||
SystemMessage(content="hi"),
|
||||
SystemMessagePromptTemplate.from_template("{foo}"),
|
||||
],
|
||||
]
|
||||
for messages in message_list:
|
||||
simple_prompt = ChatPromptTemplate.from_messages(messages)
|
||||
simple_prompt.save(file_path=tmp_path / "prompt.yaml")
|
||||
loaded_prompt = load_prompt(tmp_path / "prompt.yaml")
|
||||
assert loaded_prompt == simple_prompt
|
||||
|
||||
|
||||
def test_loading_with_template_as_file() -> None:
|
||||
"""Test loading when the template is a file."""
|
||||
with change_directory():
|
||||
|
||||
Reference in New Issue
Block a user