mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-28 10:39:23 +00:00
Make BaseStringMessagePromptTemplate.from_template return type generic (#4523)
# Make BaseStringMessagePromptTemplate.from_template return type generic I use mypy to check type on my code that uses langchain. Currently after I load a prompt and convert it to a system prompt I have to explicitly cast it which is quite ugly (and not necessary): ``` prompt_template = load_prompt("prompt.yaml") system_prompt_template = cast( SystemMessagePromptTemplate, SystemMessagePromptTemplate.from_template(prompt_template.template), ) ``` With this PR, the code would simply be: ``` prompt_template = load_prompt("prompt.yaml") system_prompt_template = SystemMessagePromptTemplate.from_template(prompt_template.template) ``` Given how much langchain uses inheritance, I think this type hinting could be applied in a bunch more places, e.g. load_prompt also return a `FewShotPromptTemplate` or a `PromptTemplate` but without typing the type checkers aren't able to infer that. Let me know if you agree and I can take a look at implementing that as well. @hwchase17 - project lead DataLoaders - @eyurtsev
This commit is contained in:
parent
446b60d803
commit
97e7dc1502
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, List, Sequence, Tuple, Type, Union
|
||||
from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -58,12 +58,19 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
||||
return [self.variable_name]
|
||||
|
||||
|
||||
MessagePromptTemplateT = TypeVar(
|
||||
"MessagePromptTemplateT", bound="BaseStringMessagePromptTemplate"
|
||||
)
|
||||
|
||||
|
||||
class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
|
||||
prompt: StringPromptTemplate
|
||||
additional_kwargs: dict = Field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_template(cls, template: str, **kwargs: Any) -> BaseMessagePromptTemplate:
|
||||
def from_template(
|
||||
cls: Type[MessagePromptTemplateT], template: str, **kwargs: Any
|
||||
) -> MessagePromptTemplateT:
|
||||
prompt = PromptTemplate.from_template(template)
|
||||
return cls(prompt=prompt, **kwargs)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user