add llm kwargs

This commit is contained in:
Harrison Chase
2023-08-27 11:11:13 -07:00
parent 86a3e51440
commit 60818aeaaa
4 changed files with 16 additions and 9 deletions

View File

@@ -419,7 +419,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         llm_kwargs = prompts[0].llm_kwargs
         for prompt in prompts:
             if prompt.llm_kwargs != llm_kwargs:
-                raise ValueError("All prompt kwargs must be the same when calling in batch")
+                raise ValueError(
+                    "All prompt kwargs must be the same when calling in batch"
+                )
         kwargs = {**llm_kwargs, **kwargs}
         prompt_messages = [p.to_messages() for p in prompts]
         return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)
@@ -434,7 +436,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         llm_kwargs = prompts[0].llm_kwargs
         for prompt in prompts:
             if prompt.llm_kwargs != llm_kwargs:
-                raise ValueError("All prompt kwargs must be the same when calling in batch")
+                raise ValueError(
+                    "All prompt kwargs must be the same when calling in batch"
+                )
         kwargs = {**llm_kwargs, **kwargs}
         prompt_messages = [p.to_messages() for p in prompts]
         return await self.agenerate(

View File

@@ -470,7 +470,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         llm_kwargs = prompts[0].llm_kwargs
         for prompt in prompts:
             if prompt.llm_kwargs != llm_kwargs:
-                raise ValueError("All prompt kwargs must be the same when calling in batch")
+                raise ValueError(
+                    "All prompt kwargs must be the same when calling in batch"
+                )
         kwargs = {**llm_kwargs, **kwargs}
         prompt_strings = [p.to_string() for p in prompts]
         return self.generate(prompt_strings, stop=stop, callbacks=callbacks, **kwargs)
@@ -485,7 +487,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         llm_kwargs = prompts[0].llm_kwargs
         for prompt in prompts:
             if prompt.llm_kwargs != llm_kwargs:
-                raise ValueError("All prompt kwargs must be the same when calling in batch")
+                raise ValueError(
+                    "All prompt kwargs must be the same when calling in batch"
+                )
         kwargs = {**llm_kwargs, **kwargs}
         prompt_strings = [p.to_string() for p in prompts]
         return await self.agenerate(

View File

@@ -298,6 +298,7 @@ class ChatPromptValue(PromptValue):
 class BaseChatPromptTemplate(BasePromptTemplate, ABC):
     """Base class for chat prompt templates."""

+    llm_kwargs: dict = Field(default_factory=dict)
+
     @property
@@ -488,9 +489,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
     @classmethod
     def from_messages(
-        cls,
-        messages: Sequence[MessageLikeRepresentation],
-        **kwargs: Any
+        cls, messages: Sequence[MessageLikeRepresentation], **kwargs: Any
     ) -> ChatPromptTemplate:
         """Create a chat prompt template from a variety of message formats.

View File

@@ -3,9 +3,8 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List
from langchain.pydantic_v1 import Field
from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import Field
from langchain.schema.messages import BaseMessage
@@ -15,6 +14,7 @@ class PromptValue(Serializable, ABC):
     PromptValues can be converted to both LLM (pure text-generation) inputs and
     ChatModel inputs.
     """

+    llm_kwargs: dict = Field(default_factory=dict)
+
     @property