community[minor]: Prem Templates (#22783)

This PR adds the feature add Prem Template feature in ChatPremAI.
Additionally it fixes a minor bug for API auth error when API passed
through arguments.
This commit is contained in:
Anindyadeep
2024-06-14 08:29:28 +05:30
committed by GitHub
parent 4160b700e6
commit c417803908
3 changed files with 159 additions and 34 deletions

View File

@@ -149,26 +149,49 @@ def _convert_delta_response_to_message_chunk(
def _messages_to_prompt_dict(
input_messages: List[BaseMessage],
) -> Tuple[Optional[str], List[Dict[str, str]]]:
template_id: Optional[str] = None,
) -> Tuple[Optional[str], List[Dict[str, Any]]]:
"""Converts a list of LangChain Messages into a simple dict
which is the message structure in Prem"""
system_prompt: Optional[str] = None
examples_and_messages: List[Dict[str, str]] = []
examples_and_messages: List[Dict[str, Any]] = []
for input_msg in input_messages:
if isinstance(input_msg, SystemMessage):
system_prompt = str(input_msg.content)
elif isinstance(input_msg, HumanMessage):
examples_and_messages.append(
{"role": "user", "content": str(input_msg.content)}
)
elif isinstance(input_msg, AIMessage):
examples_and_messages.append(
{"role": "assistant", "content": str(input_msg.content)}
)
else:
raise ChatPremAPIError("No such role explicitly exists")
if template_id is not None:
params: Dict[str, str] = {}
for input_msg in input_messages:
if isinstance(input_msg, SystemMessage):
system_prompt = str(input_msg.content)
else:
assert (input_msg.id is not None) and (input_msg.id != ""), ValueError(
"When using prompt template there should be id associated ",
"with each HumanMessage",
)
params[str(input_msg.id)] = str(input_msg.content)
examples_and_messages.append(
{"role": "user", "template_id": template_id, "params": params}
)
for input_msg in input_messages:
if isinstance(input_msg, AIMessage):
examples_and_messages.append(
{"role": "assistant", "content": str(input_msg.content)}
)
else:
for input_msg in input_messages:
if isinstance(input_msg, SystemMessage):
system_prompt = str(input_msg.content)
elif isinstance(input_msg, HumanMessage):
examples_and_messages.append(
{"role": "user", "content": str(input_msg.content)}
)
elif isinstance(input_msg, AIMessage):
examples_and_messages.append(
{"role": "assistant", "content": str(input_msg.content)}
)
else:
raise ChatPremAPIError("No such role explicitly exists")
return system_prompt, examples_and_messages
@@ -238,10 +261,14 @@ class ChatPremAI(BaseChatModel, BaseModel):
) from error
try:
premai_api_key = get_from_dict_or_env(
premai_api_key: Union[str, SecretStr] = get_from_dict_or_env(
values, "premai_api_key", "PREMAI_API_KEY"
)
values["client"] = Prem(api_key=premai_api_key)
values["client"] = Prem(
api_key=premai_api_key
if isinstance(premai_api_key, str)
else premai_api_key._secret_value
)
except Exception as error:
raise ValueError("Your API Key is incorrect. Please try again.") from error
return values
@@ -293,7 +320,12 @@ class ChatPremAI(BaseChatModel, BaseModel):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
system_prompt, messages_to_pass = _messages_to_prompt_dict(messages) # type: ignore
if "template_id" in kwargs:
system_prompt, messages_to_pass = _messages_to_prompt_dict(
messages, template_id=kwargs["template_id"]
)
else:
system_prompt, messages_to_pass = _messages_to_prompt_dict(messages) # type: ignore
if system_prompt is not None and system_prompt != "":
kwargs["system_prompt"] = system_prompt
@@ -317,7 +349,12 @@ class ChatPremAI(BaseChatModel, BaseModel):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
system_prompt, messages_to_pass = _messages_to_prompt_dict(messages)
if "template_id" in kwargs:
system_prompt, messages_to_pass = _messages_to_prompt_dict(
messages, template_id=kwargs["template_id"]
) # type: ignore
else:
system_prompt, messages_to_pass = _messages_to_prompt_dict(messages) # type: ignore
if stop is not None:
logger.warning("stop is not supported in langchain streaming")