langchain/libs/community/langchain_community/chat_models/human.py
Erick Friis c2a3021bb0
multiple: pydantic 2 compatibility, v0.3 (#26443)
Signed-off-by: ChengZi <chen.zhang@zilliz.com>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Dan O'Donovan <dan.odonovan@gmail.com>
Co-authored-by: Tom Daniel Grande <tomdgrande@gmail.com>
Co-authored-by: Grande <Tom.Daniel.Grande@statsbygg.no>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: ccurme <chester.curme@gmail.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Co-authored-by: Tomaz Bratanic <bratanic.tomaz@gmail.com>
Co-authored-by: ZhangShenao <15201440436@163.com>
Co-authored-by: Friso H. Kingma <fhkingma@gmail.com>
Co-authored-by: ChengZi <chen.zhang@zilliz.com>
Co-authored-by: Nuno Campos <nuno@langchain.dev>
Co-authored-by: Morgante Pell <morgantep@google.com>
2024-09-13 14:38:45 -07:00

112 lines
3.6 KiB
Python

"""ChatModel wrapper which returns user input as the response."""
from io import StringIO
from typing import Any, Callable, Dict, List, Mapping, Optional
import yaml
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
BaseMessage,
HumanMessage,
_message_from_dict,
messages_to_dict,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from pydantic import Field
from langchain_community.llms.utils import enforce_stop_tokens
def _display_messages(messages: List[BaseMessage]) -> None:
    """Print every message to stdout as a YAML document framed by banners.

    Args:
        messages: The messages to show; they are converted with
            ``messages_to_dict`` before being dumped as YAML.
    """
    # Dump settings shared by all messages: block style, original key order,
    # raw unicode, and a very large width so long values are not wrapped.
    dump_options: Dict[str, Any] = {
        "default_flow_style": False,
        "sort_keys": False,
        "allow_unicode": True,
        "width": 10000,
        "line_break": None,
    }
    for serialized in messages_to_dict(messages):
        rendered = yaml.dump(serialized, **dump_options)
        print("\n", "======= start of message =======", "\n\n")  # noqa: T201
        print(rendered)  # noqa: T201
        print("======= end of message =======", "\n\n")  # noqa: T201
def _collect_yaml_input(
    messages: List[BaseMessage], stop: Optional[List[str]] = None
) -> BaseMessage:
    """Read a YAML-encoded message from stdin and return it as a message.

    Lines are read with ``input()`` until a blank (whitespace-only) line is
    entered or a line containing any of the ``stop`` sequences is seen; the
    terminating line itself is discarded. The collected lines are joined with
    newlines, parsed as YAML and converted with ``_message_from_dict``.

    Args:
        messages: The conversation so far. Not used by this function; it is
            accepted because ``HumanInputChatModel._generate`` calls its
            input function as ``input_func(messages, stop=...)``.
        stop: Optional stop sequences. They end input collection and, when
            given, are also applied to the parsed message's string content
            via ``enforce_stop_tokens``.

    Returns:
        The parsed message, or an empty ``HumanMessage`` when the input
        contained no YAML document (e.g. nothing was typed).

    Raises:
        ValueError: If the input is not valid YAML, cannot be converted into
            a message, or ``stop`` is given and the message content is not a
            string.
    """
    lines = []
    while True:
        line = input()
        if not line.strip():
            break
        if stop and any(seq in line for seq in stop):
            break
        lines.append(line)
    yaml_string = "\n".join(lines)

    # Each failure mode gets its own narrow ``try`` so that one error is not
    # re-labelled as another (previously the stop-token ValueError below was
    # caught and reported as "Invalid message entered.").
    try:
        payload = yaml.safe_load(StringIO(yaml_string))
    except yaml.YAMLError as err:
        raise ValueError("Invalid YAML string entered.") from err

    # An empty document parses to None. Handle it before conversion: the
    # converter expects a mapping, so a None check placed after it can never
    # produce the empty-message fallback.
    if payload is None:
        return HumanMessage(content="")

    try:
        message = _message_from_dict(payload)
    except (KeyError, TypeError, ValueError) as err:
        # Valid YAML that is not a well-formed message dict (a bare scalar,
        # missing keys, unknown type, failed validation, ...).
        raise ValueError("Invalid message entered.") from err

    if stop:
        if not isinstance(message.content, str):
            raise ValueError("Cannot use when output is not a string.")
        message.content = enforce_stop_tokens(message.content, stop)
    return message
class HumanInputChatModel(BaseChatModel):
    """ChatModel which returns user input as the response."""

    # Called as ``input_func(messages, stop=stop, **input_kwargs)``; its
    # return value is used as the generated message.
    input_func: Callable = Field(default_factory=lambda: _collect_yaml_input)
    # Called as ``message_func(messages, **message_kwargs)`` to show the
    # incoming messages before input is collected.
    message_func: Callable = Field(default_factory=lambda: _display_messages)
    # Not referenced by any method of this class; kept as part of the
    # model's public fields.
    separator: str = "\n"
    # Extra keyword arguments forwarded to ``input_func``.
    input_kwargs: Mapping[str, Any] = {}
    # Extra keyword arguments forwarded to ``message_func``.
    message_kwargs: Mapping[str, Any] = {}

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return the names of the configured input and message callables."""

        def _label(func: Callable) -> str:
            # Not every callable has ``__name__`` (e.g. ``functools.partial``
            # objects or instances defining ``__call__``); fall back to the
            # type name instead of raising AttributeError.
            return getattr(func, "__name__", type(func).__name__)

        return {
            "input_func": _label(self.input_func),
            "message_func": _label(self.message_func),
        }

    @property
    def _llm_type(self) -> str:
        """Returns the type of LLM."""
        return "human-input-chat-model"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Display the messages to the user and return their input as a response.

        Args:
            messages: The messages to be displayed to the user.
            stop: A list of stop strings, forwarded to ``input_func``.
            run_manager: Currently not used.
            **kwargs: Currently not used.

        Returns:
            A ``ChatResult`` with a single generation wrapping the message
            returned by ``input_func``.
        """
        self.message_func(messages, **self.message_kwargs)
        user_input = self.input_func(messages, stop=stop, **self.input_kwargs)
        return ChatResult(generations=[ChatGeneration(message=user_input)])