mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-03 20:16:52 +00:00
community[minor]: Prem Templates (#22783)
This PR adds the Prem Template feature in ChatPremAI. Additionally, it fixes a minor bug causing an API auth error when the API key is passed through arguments.
This commit is contained in:
@@ -149,26 +149,49 @@ def _convert_delta_response_to_message_chunk(
|
||||
|
||||
def _messages_to_prompt_dict(
    input_messages: List[BaseMessage],
    template_id: Optional[str] = None,
) -> Tuple[Optional[str], List[Dict[str, Any]]]:
    """Converts a list of LangChain Messages into a simple dict
    which is the message structure in Prem.

    Args:
        input_messages: Messages to convert. A ``SystemMessage`` becomes the
            returned system prompt; ``HumanMessage`` / ``AIMessage`` become
            ``user`` / ``assistant`` entries.
        template_id: Optional Prem prompt-template id. When given, every
            non-system message must carry a non-empty ``id``; its content is
            sent as a template parameter keyed by that id instead of as a
            plain user message.

    Returns:
        A ``(system_prompt, messages)`` tuple. ``system_prompt`` is ``None``
        when no ``SystemMessage`` was seen; ``messages`` is the list of role
        dicts in Prem's expected shape.

    Raises:
        ValueError: If ``template_id`` is given and a non-system message has
            no id.
        ChatPremAPIError: If a message role is not recognized (non-template
            path only).
    """
    system_prompt: Optional[str] = None
    examples_and_messages: List[Dict[str, Any]] = []

    if template_id is not None:
        # Template mode: non-system message contents become template params,
        # keyed by each message's id.
        params: Dict[str, str] = {}
        for input_msg in input_messages:
            if isinstance(input_msg, SystemMessage):
                system_prompt = str(input_msg.content)
            else:
                # Explicit raise instead of `assert`: asserts are stripped
                # under `python -O`, and the original
                # `assert cond, ValueError(...)` raised AssertionError with
                # the ValueError as its message — the ValueError itself was
                # never raised.
                if input_msg.id is None or input_msg.id == "":
                    raise ValueError(
                        "When using prompt template there should be id "
                        "associated with each HumanMessage"
                    )
                params[str(input_msg.id)] = str(input_msg.content)

        examples_and_messages.append(
            {"role": "user", "template_id": template_id, "params": params}
        )

        # Assistant turns are still passed through as plain content.
        for input_msg in input_messages:
            if isinstance(input_msg, AIMessage):
                examples_and_messages.append(
                    {"role": "assistant", "content": str(input_msg.content)}
                )
    else:
        for input_msg in input_messages:
            if isinstance(input_msg, SystemMessage):
                system_prompt = str(input_msg.content)
            elif isinstance(input_msg, HumanMessage):
                examples_and_messages.append(
                    {"role": "user", "content": str(input_msg.content)}
                )
            elif isinstance(input_msg, AIMessage):
                examples_and_messages.append(
                    {"role": "assistant", "content": str(input_msg.content)}
                )
            else:
                raise ChatPremAPIError("No such role explicitly exists")
    return system_prompt, examples_and_messages
|
||||
|
||||
|
||||
@@ -238,10 +261,14 @@ class ChatPremAI(BaseChatModel, BaseModel):
|
||||
) from error
|
||||
|
||||
try:
|
||||
premai_api_key = get_from_dict_or_env(
|
||||
premai_api_key: Union[str, SecretStr] = get_from_dict_or_env(
|
||||
values, "premai_api_key", "PREMAI_API_KEY"
|
||||
)
|
||||
values["client"] = Prem(api_key=premai_api_key)
|
||||
values["client"] = Prem(
|
||||
api_key=premai_api_key
|
||||
if isinstance(premai_api_key, str)
|
||||
else premai_api_key._secret_value
|
||||
)
|
||||
except Exception as error:
|
||||
raise ValueError("Your API Key is incorrect. Please try again.") from error
|
||||
return values
|
||||
@@ -293,7 +320,12 @@ class ChatPremAI(BaseChatModel, BaseModel):
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
system_prompt, messages_to_pass = _messages_to_prompt_dict(messages) # type: ignore
|
||||
if "template_id" in kwargs:
|
||||
system_prompt, messages_to_pass = _messages_to_prompt_dict(
|
||||
messages, template_id=kwargs["template_id"]
|
||||
)
|
||||
else:
|
||||
system_prompt, messages_to_pass = _messages_to_prompt_dict(messages) # type: ignore
|
||||
|
||||
if system_prompt is not None and system_prompt != "":
|
||||
kwargs["system_prompt"] = system_prompt
|
||||
@@ -317,7 +349,12 @@ class ChatPremAI(BaseChatModel, BaseModel):
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
system_prompt, messages_to_pass = _messages_to_prompt_dict(messages)
|
||||
if "template_id" in kwargs:
|
||||
system_prompt, messages_to_pass = _messages_to_prompt_dict(
|
||||
messages, template_id=kwargs["template_id"]
|
||||
) # type: ignore
|
||||
else:
|
||||
system_prompt, messages_to_pass = _messages_to_prompt_dict(messages) # type: ignore
|
||||
|
||||
if stop is not None:
|
||||
logger.warning("stop is not supported in langchain streaming")
|
||||
|
Reference in New Issue
Block a user