support kwargs (#5990)

Mirror of https://github.com/hwchase17/langchain.git (synced 2025-05-22 07:27:45 +00:00)
Commit: 704d56e241
Parent: b934677a81
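For orientation, a hedged usage sketch of what the commit enables: keyword arguments supplied at call time are now forwarded through `generate()`/`__call__` and merged into the provider request, with call-time values taking precedence over the wrapper's configured defaults. The example assumes the OpenAI wrapper and an `OPENAI_API_KEY` in the environment; the model name and parameter values are illustrative, not part of the commit.

```python
# Hedged usage sketch: per-call kwargs now flow down to the provider request.
# Assumes OPENAI_API_KEY is set; model name and parameters are illustrative.
from langchain.llms import OpenAI

llm = OpenAI(model_name="text-davinci-003", temperature=0.0)

# Per-call overrides that previously had no supported path to the API request:
text = llm("Tell me a joke.", stop=["\n\n"], temperature=0.9, max_tokens=64)
print(text)
```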
@@ -2,7 +2,7 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
-from typing import List, Optional, Sequence, Set
+from typing import Any, List, Optional, Sequence, Set

 from pydantic import BaseModel

@@ -36,6 +36,7 @@ class BaseLanguageModel(BaseModel, ABC):
 prompts: List[PromptValue],
 stop: Optional[List[str]] = None,
 callbacks: Callbacks = None,
+**kwargs: Any,
 ) -> LLMResult:
 """Take in a list of prompt values and return an LLMResult."""

@@ -45,26 +46,39 @@ class BaseLanguageModel(BaseModel, ABC):
 prompts: List[PromptValue],
 stop: Optional[List[str]] = None,
 callbacks: Callbacks = None,
+**kwargs: Any,
 ) -> LLMResult:
 """Take in a list of prompt values and return an LLMResult."""

 @abstractmethod
-def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+def predict(
+self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+) -> str:
 """Predict text from text."""

 @abstractmethod
 def predict_messages(
-self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+self,
+messages: List[BaseMessage],
+*,
+stop: Optional[Sequence[str]] = None,
+**kwargs: Any,
 ) -> BaseMessage:
 """Predict message from messages."""

 @abstractmethod
-async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+async def apredict(
+self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+) -> str:
 """Predict text from text."""

 @abstractmethod
 async def apredict_messages(
-self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+self,
+messages: List[BaseMessage],
+*,
+stop: Optional[Sequence[str]] = None,
+**kwargs: Any,
 ) -> BaseMessage:
 """Predict message from messages."""
@@ -94,9 +94,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
 messages: List[BaseMessage],
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> ChatResult:
 prompt = self._convert_messages_to_prompt(messages)
-params: Dict[str, Any] = {"prompt": prompt, **self._default_params}
+params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
 if stop:
 params["stop_sequences"] = stop

@@ -121,9 +122,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
 messages: List[BaseMessage],
 stop: Optional[List[str]] = None,
 run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> ChatResult:
 prompt = self._convert_messages_to_prompt(messages)
-params: Dict[str, Any] = {"prompt": prompt, **self._default_params}
+params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
 if stop:
 params["stop_sequences"] = stop
@@ -64,6 +64,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
 messages: List[List[BaseMessage]],
 stop: Optional[List[str]] = None,
 callbacks: Callbacks = None,
+**kwargs: Any,
 ) -> LLMResult:
 """Top Level call"""

@@ -82,7 +83,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
 )
 try:
 results = [
-self._generate(m, stop=stop, run_manager=run_manager)
+self._generate(m, stop=stop, run_manager=run_manager, **kwargs)
 if new_arg_supported
 else self._generate(m, stop=stop)
 for m in messages

@@ -103,6 +104,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
 messages: List[List[BaseMessage]],
 stop: Optional[List[str]] = None,
 callbacks: Callbacks = None,
+**kwargs: Any,
 ) -> LLMResult:
 """Top Level call"""
 params = self.dict()

@@ -121,7 +123,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
 try:
 results = await asyncio.gather(
 *[
-self._agenerate(m, stop=stop, run_manager=run_manager)
+self._agenerate(m, stop=stop, run_manager=run_manager, **kwargs)
 if new_arg_supported
 else self._agenerate(m, stop=stop)
 for m in messages

@@ -143,18 +145,22 @@ class BaseChatModel(BaseLanguageModel, ABC):
 prompts: List[PromptValue],
 stop: Optional[List[str]] = None,
 callbacks: Callbacks = None,
+**kwargs: Any,
 ) -> LLMResult:
 prompt_messages = [p.to_messages() for p in prompts]
-return self.generate(prompt_messages, stop=stop, callbacks=callbacks)
+return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)

 async def agenerate_prompt(
 self,
 prompts: List[PromptValue],
 stop: Optional[List[str]] = None,
 callbacks: Callbacks = None,
+**kwargs: Any,
 ) -> LLMResult:
 prompt_messages = [p.to_messages() for p in prompts]
-return await self.agenerate(prompt_messages, stop=stop, callbacks=callbacks)
+return await self.agenerate(
+prompt_messages, stop=stop, callbacks=callbacks, **kwargs
+)

 @abstractmethod
 def _generate(

@@ -162,6 +168,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
 messages: List[BaseMessage],
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> ChatResult:
 """Top Level call"""

@@ -171,6 +178,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
 messages: List[BaseMessage],
 stop: Optional[List[str]] = None,
 run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> ChatResult:
 """Top Level call"""

@@ -193,18 +201,25 @@ class BaseChatModel(BaseLanguageModel, ABC):
 messages: List[BaseMessage],
 stop: Optional[List[str]] = None,
 callbacks: Callbacks = None,
+**kwargs: Any,
 ) -> BaseMessage:
-result = await self.agenerate([messages], stop=stop, callbacks=callbacks)
+result = await self.agenerate(
+[messages], stop=stop, callbacks=callbacks, **kwargs
+)
 generation = result.generations[0][0]
 if isinstance(generation, ChatGeneration):
 return generation.message
 else:
 raise ValueError("Unexpected generation type")

-def call_as_llm(self, message: str, stop: Optional[List[str]] = None) -> str:
-return self.predict(message, stop=stop)
+def call_as_llm(
+self, message: str, stop: Optional[List[str]] = None, **kwargs: Any
+) -> str:
+return self.predict(message, stop=stop, **kwargs)

-def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+def predict(
+self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+) -> str:
 if stop is None:
 _stop = None
 else:

@@ -213,30 +228,42 @@ class BaseChatModel(BaseLanguageModel, ABC):
 return result.content

 def predict_messages(
-self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+self,
+messages: List[BaseMessage],
+*,
+stop: Optional[Sequence[str]] = None,
+**kwargs: Any,
 ) -> BaseMessage:
 if stop is None:
 _stop = None
 else:
 _stop = list(stop)
-return self(messages, stop=_stop)
+return self(messages, stop=_stop, **kwargs)

-async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+async def apredict(
+self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+) -> str:
 if stop is None:
 _stop = None
 else:
 _stop = list(stop)
-result = await self._call_async([HumanMessage(content=text)], stop=_stop)
+result = await self._call_async(
+[HumanMessage(content=text)], stop=_stop, **kwargs
+)
 return result.content

 async def apredict_messages(
-self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+self,
+messages: List[BaseMessage],
+*,
+stop: Optional[Sequence[str]] = None,
+**kwargs: Any,
 ) -> BaseMessage:
 if stop is None:
 _stop = None
 else:
 _stop = list(stop)
-return await self._call_async(messages, stop=_stop)
+return await self._call_async(messages, stop=_stop, **kwargs)

 @property
 def _identifying_params(self) -> Mapping[str, Any]:

@@ -261,8 +288,9 @@ class SimpleChatModel(BaseChatModel):
 messages: List[BaseMessage],
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> ChatResult:
-output_str = self._call(messages, stop=stop, run_manager=run_manager)
+output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
 message = AIMessage(content=output_str)
 generation = ChatGeneration(message=message)
 return ChatResult(generations=[generation])

@@ -273,6 +301,7 @@ class SimpleChatModel(BaseChatModel):
 messages: List[BaseMessage],
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Simpler interface."""

@@ -281,6 +310,9 @@ class SimpleChatModel(BaseChatModel):
 messages: List[BaseMessage],
 stop: Optional[List[str]] = None,
 run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> ChatResult:
-func = partial(self._generate, messages, stop=stop, run_manager=run_manager)
+func = partial(
+self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
+)
 return await asyncio.get_event_loop().run_in_executor(None, func)
@@ -280,6 +280,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
 messages: List[BaseMessage],
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> ChatResult:
 prompt = _messages_to_prompt_dict(messages)

@@ -291,6 +292,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
 top_p=self.top_p,
 top_k=self.top_k,
 candidate_count=self.n,
+**kwargs,
 )

 return _response_to_result(response, stop)

@@ -300,6 +302,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
 messages: List[BaseMessage],
 stop: Optional[List[str]] = None,
 run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> ChatResult:
 prompt = _messages_to_prompt_dict(messages)
@@ -302,8 +302,10 @@ class ChatOpenAI(BaseChatModel):
 messages: List[BaseMessage],
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> ChatResult:
 message_dicts, params = self._create_message_dicts(messages, stop)
+params = {**params, **kwargs}
 if self.streaming:
 inner_completion = ""
 role = "assistant"

@@ -348,8 +350,10 @@ class ChatOpenAI(BaseChatModel):
 messages: List[BaseMessage],
 stop: Optional[List[str]] = None,
 run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> ChatResult:
 message_dicts, params = self._create_message_dicts(messages, stop)
+params = {**params, **kwargs}
 if self.streaming:
 inner_completion = ""
 role = "assistant"
@@ -42,6 +42,7 @@ class PromptLayerChatOpenAI(ChatOpenAI):
 messages: List[BaseMessage],
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any
 ) -> ChatResult:
 """Call ChatOpenAI generate and then call PromptLayer API to log the request."""
 from promptlayer.utils import get_api_key, promptlayer_api_request

@@ -54,6 +55,7 @@ class PromptLayerChatOpenAI(ChatOpenAI):
 response_dict, params = super()._create_message_dicts(
 [generation.message], stop
 )
+params = {**params, **kwargs}
 pl_request_id = promptlayer_api_request(
 "langchain.PromptLayerChatOpenAI",
 "langchain",

@@ -79,6 +81,7 @@ class PromptLayerChatOpenAI(ChatOpenAI):
 messages: List[BaseMessage],
 stop: Optional[List[str]] = None,
 run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+**kwargs: Any
 ) -> ChatResult:
 """Call ChatOpenAI agenerate and then call PromptLayer to log."""
 from promptlayer.utils import get_api_key, promptlayer_api_request_async

@@ -91,6 +94,7 @@ class PromptLayerChatOpenAI(ChatOpenAI):
 response_dict, params = super()._create_message_dicts(
 [generation.message], stop
 )
+params = {**params, **kwargs}
 pl_request_id = await promptlayer_api_request_async(
 "langchain.PromptLayerChatOpenAI.async",
 "langchain",
@@ -1,6 +1,6 @@
 """Wrapper around Google VertexAI chat-based models."""
 from dataclasses import dataclass, field
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional

 from pydantic import root_validator

@@ -93,6 +93,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
 messages: List[BaseMessage],
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> ChatResult:
 """Generate next turn in the conversation.

@@ -119,7 +120,8 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):

 history = _parse_chat_history(messages[:-1])
 context = history.system_message.content if history.system_message else None
-chat = self.client.start_chat(context=context, **self._default_params)
+params = {**self._default_params, **kwargs}
+chat = self.client.start_chat(context=context, **params)
 for pair in history.history:
 chat._history.append((pair.question.content, pair.answer.content))
 response = chat.send_message(question.content, **self._default_params)

@@ -131,6 +133,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
 messages: List[BaseMessage],
 stop: Optional[List[str]] = None,
 run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> ChatResult:
 raise NotImplementedError(
 """Vertex AI doesn't support async requests at the moment."""
@@ -2,7 +2,7 @@
 from __future__ import annotations

 import json
-from typing import TYPE_CHECKING, List, Optional, cast
+from typing import TYPE_CHECKING, Any, List, Optional, cast

 from pydantic import Field, root_validator

@@ -42,6 +42,7 @@ class JsonFormer(HuggingFacePipeline):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 jsonformer = import_jsonformer()
 from transformers import Text2TextGenerationPipeline

@@ -1,7 +1,7 @@
 """Experimental implementation of RELLM wrapped LLM."""
 from __future__ import annotations

-from typing import TYPE_CHECKING, List, Optional, cast
+from typing import TYPE_CHECKING, Any, List, Optional, cast

 from pydantic import Field, root_validator

@@ -47,6 +47,7 @@ class RELLM(HuggingFacePipeline):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 rellm = import_rellm()
 from transformers import Text2TextGenerationPipeline
@@ -112,6 +112,7 @@ class AI21(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Call out to AI21's complete endpoint.

@@ -140,10 +141,11 @@ class AI21(LLM):
 base_url = "https://api.ai21.com/studio/v1/experimental"
 else:
 base_url = "https://api.ai21.com/studio/v1"
+params = {**self._default_params, **kwargs}
 response = requests.post(
 url=f"{base_url}/{self.model}/complete",
 headers={"Authorization": f"Bearer {self.ai21_api_key}"},
-json={"prompt": prompt, "stopSequences": stop, **self._default_params},
+json={"prompt": prompt, "stopSequences": stop, **params},
 )
 if response.status_code != 200:
 optional_detail = response.json().get("error")

@@ -206,6 +206,7 @@ class AlephAlpha(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Call out to Aleph Alpha's completion endpoint.

@@ -232,6 +233,7 @@ class AlephAlpha(LLM):
 params["stop_sequences"] = self.stop_sequences
 else:
 params["stop_sequences"] = stop
+params = {**params, **kwargs}
 request = CompletionRequest(prompt=Prompt.from_text(prompt), **params)
 response = self.client.complete(model=self.model, request=request)
 text = response.completions[0].completion
@@ -162,6 +162,7 @@ class Anthropic(LLM, _AnthropicCommon):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 r"""Call out to Anthropic's completion endpoint.

@@ -181,11 +182,12 @@ class Anthropic(LLM, _AnthropicCommon):

 """
 stop = self._get_anthropic_stop(stop)
+params = {**self._default_params, **kwargs}
 if self.streaming:
 stream_resp = self.client.completion_stream(
 prompt=self._wrap_prompt(prompt),
 stop_sequences=stop,
-**self._default_params,
+**params,
 )
 current_completion = ""
 for data in stream_resp:

@@ -197,7 +199,7 @@ class Anthropic(LLM, _AnthropicCommon):
 response = self.client.completion(
 prompt=self._wrap_prompt(prompt),
 stop_sequences=stop,
-**self._default_params,
+**params,
 )
 return response["completion"]

@@ -206,14 +208,16 @@ class Anthropic(LLM, _AnthropicCommon):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Call out to Anthropic's completion endpoint asynchronously."""
 stop = self._get_anthropic_stop(stop)
+params = {**self._default_params, **kwargs}
 if self.streaming:
 stream_resp = await self.client.acompletion_stream(
 prompt=self._wrap_prompt(prompt),
 stop_sequences=stop,
-**self._default_params,
+**params,
 )
 current_completion = ""
 async for data in stream_resp:

@@ -225,7 +229,7 @@ class Anthropic(LLM, _AnthropicCommon):
 response = await self.client.acompletion(
 prompt=self._wrap_prompt(prompt),
 stop_sequences=stop,
-**self._default_params,
+**params,
 )
 return response["completion"]
@@ -88,6 +88,7 @@ class Anyscale(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Call out to Anyscale Service endpoint.
 Args:

@@ -105,6 +105,7 @@ class Aviary(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Call out to Aviary
 Args:

@@ -87,6 +87,7 @@ class Banana(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Call to Banana endpoint."""
 try:

@@ -97,6 +98,7 @@ class Banana(LLM):
 "Please install it with `pip install banana-dev`."
 )
 params = self.model_kwargs or {}
+params = {**params, **kwargs}
 api_key = self.banana_api_key
 model_key = self.model_key
 model_inputs = {
@@ -113,6 +113,7 @@ class BaseLLM(BaseLanguageModel, ABC):
 prompts: List[str],
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> LLMResult:
 """Run the LLM on the given prompts."""

@@ -122,6 +123,7 @@ class BaseLLM(BaseLanguageModel, ABC):
 prompts: List[str],
 stop: Optional[List[str]] = None,
 run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> LLMResult:
 """Run the LLM on the given prompts."""

@@ -130,24 +132,29 @@ class BaseLLM(BaseLanguageModel, ABC):
 prompts: List[PromptValue],
 stop: Optional[List[str]] = None,
 callbacks: Callbacks = None,
+**kwargs: Any,
 ) -> LLMResult:
 prompt_strings = [p.to_string() for p in prompts]
-return self.generate(prompt_strings, stop=stop, callbacks=callbacks)
+return self.generate(prompt_strings, stop=stop, callbacks=callbacks, **kwargs)

 async def agenerate_prompt(
 self,
 prompts: List[PromptValue],
 stop: Optional[List[str]] = None,
 callbacks: Callbacks = None,
+**kwargs: Any,
 ) -> LLMResult:
 prompt_strings = [p.to_string() for p in prompts]
-return await self.agenerate(prompt_strings, stop=stop, callbacks=callbacks)
+return await self.agenerate(
+prompt_strings, stop=stop, callbacks=callbacks, **kwargs
+)

 def generate(
 self,
 prompts: List[str],
 stop: Optional[List[str]] = None,
 callbacks: Callbacks = None,
+**kwargs: Any,
 ) -> LLMResult:
 """Run the LLM on the given prompt and input."""
 # If string is passed in directly no errors will be raised but outputs will

@@ -183,9 +190,11 @@ class BaseLLM(BaseLanguageModel, ABC):
 )
 try:
 output = (
-self._generate(prompts, stop=stop, run_manager=run_manager)
+self._generate(
+prompts, stop=stop, run_manager=run_manager, **kwargs
+)
 if new_arg_supported
-else self._generate(prompts, stop=stop)
+else self._generate(prompts, stop=stop, **kwargs)
 )
 except (KeyboardInterrupt, Exception) as e:
 run_manager.on_llm_error(e)

@@ -202,9 +211,11 @@ class BaseLLM(BaseLanguageModel, ABC):
 )
 try:
 new_results = (
-self._generate(missing_prompts, stop=stop, run_manager=run_manager)
+self._generate(
+missing_prompts, stop=stop, run_manager=run_manager, **kwargs
+)
 if new_arg_supported
-else self._generate(missing_prompts, stop=stop)
+else self._generate(missing_prompts, stop=stop, **kwargs)
 )
 except (KeyboardInterrupt, Exception) as e:
 run_manager.on_llm_error(e)

@@ -227,6 +238,7 @@ class BaseLLM(BaseLanguageModel, ABC):
 prompts: List[str],
 stop: Optional[List[str]] = None,
 callbacks: Callbacks = None,
+**kwargs: Any,
 ) -> LLMResult:
 """Run the LLM on the given prompt and input."""
 params = self.dict()

@@ -255,9 +267,11 @@ class BaseLLM(BaseLanguageModel, ABC):
 )
 try:
 output = (
-await self._agenerate(prompts, stop=stop, run_manager=run_manager)
+await self._agenerate(
+prompts, stop=stop, run_manager=run_manager, **kwargs
+)
 if new_arg_supported
-else await self._agenerate(prompts, stop=stop)
+else await self._agenerate(prompts, stop=stop, **kwargs)
 )
 except (KeyboardInterrupt, Exception) as e:
 await run_manager.on_llm_error(e, verbose=self.verbose)

@@ -275,10 +289,10 @@ class BaseLLM(BaseLanguageModel, ABC):
 try:
 new_results = (
 await self._agenerate(
-missing_prompts, stop=stop, run_manager=run_manager
+missing_prompts, stop=stop, run_manager=run_manager, **kwargs
 )
 if new_arg_supported
-else await self._agenerate(missing_prompts, stop=stop)
+else await self._agenerate(missing_prompts, stop=stop, **kwargs)
 )
 except (KeyboardInterrupt, Exception) as e:
 await run_manager.on_llm_error(e)

@@ -297,7 +311,11 @@ class BaseLLM(BaseLanguageModel, ABC):
 return LLMResult(generations=generations, llm_output=llm_output, run=run_info)

 def __call__(
-self, prompt: str, stop: Optional[List[str]] = None, callbacks: Callbacks = None
+self,
+prompt: str,
+stop: Optional[List[str]] = None,
+callbacks: Callbacks = None,
+**kwargs: Any,
 ) -> str:
 """Check Cache and run the LLM on the given prompt and input."""
 if not isinstance(prompt, str):

@@ -307,52 +325,70 @@ class BaseLLM(BaseLanguageModel, ABC):
 "`generate` instead."
 )
 return (
-self.generate([prompt], stop=stop, callbacks=callbacks)
+self.generate([prompt], stop=stop, callbacks=callbacks, **kwargs)
 .generations[0][0]
 .text
 )

 async def _call_async(
-self, prompt: str, stop: Optional[List[str]] = None, callbacks: Callbacks = None
+self,
+prompt: str,
+stop: Optional[List[str]] = None,
+callbacks: Callbacks = None,
+**kwargs: Any,
 ) -> str:
 """Check Cache and run the LLM on the given prompt and input."""
-result = await self.agenerate([prompt], stop=stop, callbacks=callbacks)
+result = await self.agenerate(
+[prompt], stop=stop, callbacks=callbacks, **kwargs
+)
 return result.generations[0][0].text

-def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+def predict(
+self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+) -> str:
 if stop is None:
 _stop = None
 else:
 _stop = list(stop)
-return self(text, stop=_stop)
+return self(text, stop=_stop, **kwargs)

 def predict_messages(
-self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+self,
+messages: List[BaseMessage],
+*,
+stop: Optional[Sequence[str]] = None,
+**kwargs: Any,
 ) -> BaseMessage:
 text = get_buffer_string(messages)
 if stop is None:
 _stop = None
 else:
 _stop = list(stop)
-content = self(text, stop=_stop)
+content = self(text, stop=_stop, **kwargs)
 return AIMessage(content=content)

-async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+async def apredict(
+self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+) -> str:
 if stop is None:
 _stop = None
 else:
 _stop = list(stop)
-return await self._call_async(text, stop=_stop)
+return await self._call_async(text, stop=_stop, **kwargs)

 async def apredict_messages(
-self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+self,
+messages: List[BaseMessage],
+*,
+stop: Optional[Sequence[str]] = None,
+**kwargs: Any,
 ) -> BaseMessage:
 text = get_buffer_string(messages)
 if stop is None:
 _stop = None
 else:
 _stop = list(stop)
-content = await self._call_async(text, stop=_stop)
+content = await self._call_async(text, stop=_stop, **kwargs)
 return AIMessage(content=content)

 @property

@@ -422,6 +458,7 @@ class LLM(BaseLLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Run the LLM on the given prompt and input."""

@@ -430,6 +467,7 @@ class LLM(BaseLLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Run the LLM on the given prompt and input."""
 raise NotImplementedError("Async generation not implemented for this LLM.")

@@ -439,6 +477,7 @@ class LLM(BaseLLM):
 prompts: List[str],
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> LLMResult:
 """Run the LLM on the given prompt and input."""
 # TODO: add caching here.

@@ -446,9 +485,9 @@ class LLM(BaseLLM):
 new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
 for prompt in prompts:
 text = (
-self._call(prompt, stop=stop, run_manager=run_manager)
+self._call(prompt, stop=stop, run_manager=run_manager, **kwargs)
 if new_arg_supported
-else self._call(prompt, stop=stop)
+else self._call(prompt, stop=stop, **kwargs)
 )
 generations.append([Generation(text=text)])
 return LLMResult(generations=generations)

@@ -458,15 +497,16 @@ class LLM(BaseLLM):
 prompts: List[str],
 stop: Optional[List[str]] = None,
 run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> LLMResult:
 """Run the LLM on the given prompt and input."""
 generations = []
 new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
 for prompt in prompts:
 text = (
-await self._acall(prompt, stop=stop, run_manager=run_manager)
+await self._acall(prompt, stop=stop, run_manager=run_manager, **kwargs)
 if new_arg_supported
-else await self._acall(prompt, stop=stop)
+else await self._acall(prompt, stop=stop, **kwargs)
 )
 generations.append([Generation(text=text)])
 return LLMResult(generations=generations)
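The same contract now applies to plain LLM wrappers, and the provider-specific diffs below all follow one merge rule: call-time kwargs take precedence over configured defaults. A hedged sketch of a custom `LLM` subclass under this revision; the class name `ReverseLLM` and the `uppercase` knob are invented for illustration, and the import paths assume this version of the library.

```python
# Hedged sketch: a custom LLM whose _call accepts **kwargs and applies the same
# {**defaults, **kwargs} precedence used throughout this commit. Runs offline.
from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM


class ReverseLLM(LLM):
    """Toy wrapper that 'generates' by reversing the prompt."""

    default_params: Dict[str, Any] = {"uppercase": False}  # illustrative knob

    @property
    def _llm_type(self) -> str:
        return "reverse"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # Same merge rule as the real wrappers: call-time kwargs override defaults.
        params = {**self.default_params, **kwargs}
        text = prompt[::-1]
        return text.upper() if params["uppercase"] else text


if __name__ == "__main__":
    llm = ReverseLLM()
    print(llm("hello"))                   # -> "olleh"
    print(llm("hello", uppercase=True))   # kwarg forwarded through __call__ -> "OLLEH"
```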
@@ -54,6 +54,7 @@ class Baseten(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Call to Baseten deployed model endpoint."""
 try:

@@ -251,10 +251,12 @@ class Beam(LLM):
 prompt: str,
 stop: Optional[list] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Call to Beam."""
 url = "https://apps.beam.cloud/" + self.app_id if self.app_id else self.url
 payload = {"prompt": prompt, "max_length": self.max_length}
+payload.update(kwargs)
 headers = {
 "Accept": "*/*",
 "Accept-Encoding": "gzip, deflate",

@@ -155,6 +155,7 @@ class Bedrock(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Call out to Bedrock service model.

@@ -173,10 +174,8 @@ class Bedrock(LLM):
 _model_kwargs = self.model_kwargs or {}

 provider = self.model_id.split(".")[0]
-
-input_body = LLMInputOutputAdapter.prepare_input(
-provider, prompt, _model_kwargs
-)
+params = {**_model_kwargs, **kwargs}
+input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
 body = json.dumps(input_body)
 accept = "application/json"
 contentType = "application/json"
@@ -88,6 +88,7 @@ class CerebriumAI(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Call to CerebriumAI endpoint."""
 try:

@@ -100,7 +101,9 @@ class CerebriumAI(LLM):

 params = self.model_kwargs or {}
 response = model_api_request(
-self.endpoint_url, {"prompt": prompt, **params}, self.cerebriumai_api_key
+self.endpoint_url,
+{"prompt": prompt, **params, **kwargs},
+self.cerebriumai_api_key,
 )
 text = response["data"]["result"]
 if stop is not None:

@@ -145,6 +145,7 @@ class Cohere(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Call out to Cohere's generate endpoint.

@@ -167,7 +168,7 @@ class Cohere(LLM):
 params["stop_sequences"] = self.stop
 else:
 params["stop_sequences"] = stop
-
+params = {**params, **kwargs}
 response = completion_with_retry(
 self, model=self.model, prompt=prompt, **params
 )

@@ -81,6 +81,7 @@ class CTransformers(LLM):
 prompt: str,
 stop: Optional[Sequence[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Generate text from a prompt.
@@ -303,12 +303,14 @@ class Databricks(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Queries the LLM endpoint with the given prompt and stop sequence."""

 # TODO: support callbacks

 request = {"prompt": prompt, "stop": stop}
+request.update(kwargs)
 if self.model_kwargs:
 request.update(self.model_kwargs)

@@ -66,6 +66,7 @@ class DeepInfra(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Call out to DeepInfra's inference API endpoint.

@@ -82,6 +83,7 @@ class DeepInfra(LLM):
 response = di("Tell me a joke.")
 """
 _model_kwargs = self.model_kwargs or {}
+_model_kwargs = {**_model_kwargs, **kwargs}
 # HTTP headers for authorization
 headers = {
 "Authorization": f"bearer {self.deepinfra_api_token}",

@@ -24,6 +24,7 @@ class FakeListLLM(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Return next response"""
 response = self.responses[self.i]

@@ -35,6 +36,7 @@ class FakeListLLM(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Return next response"""
 response = self.responses[self.i]
@@ -87,6 +87,7 @@ class ForefrontAI(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Call out to ForefrontAI's complete endpoint.

@@ -108,7 +109,7 @@ class ForefrontAI(LLM):
 "Authorization": f"Bearer {self.forefrontai_api_key}",
 "Content-Type": "application/json",
 },
-json={"text": prompt, **self._default_params},
+json={"text": prompt, **self._default_params, **kwargs},
 )
 response_json = response.json()
 text = response_json["result"][0]["completion"]

@@ -134,6 +134,7 @@ class GooglePalm(BaseLLM, BaseModel):
 prompts: List[str],
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> LLMResult:
 generations = []
 for prompt in prompts:

@@ -147,6 +148,7 @@ class GooglePalm(BaseLLM, BaseModel):
 top_k=self.top_k,
 max_output_tokens=self.max_output_tokens,
 candidate_count=self.n,
+**kwargs,
 )

 prompt_generations = []

@@ -163,6 +165,7 @@ class GooglePalm(BaseLLM, BaseModel):
 prompts: List[str],
 stop: Optional[List[str]] = None,
 run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> LLMResult:
 raise NotImplementedError()

@@ -137,6 +137,7 @@ class GooseAI(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Call the GooseAI API."""
 params = self._default_params

@@ -145,6 +146,8 @@ class GooseAI(LLM):
 raise ValueError("`stop` found in both the input and default params.")
 params["stop"] = stop

+params = {**params, **kwargs}
+
 response = self.client.create(engine=self.model_name, prompt=prompt, **params)
 text = response.choices[0].text
 return text
@@ -183,6 +183,7 @@ class GPT4All(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 r"""Call out to GPT4All's generate method.

@@ -203,7 +204,8 @@ class GPT4All(LLM):
 if run_manager:
 text_callback = partial(run_manager.on_llm_new_token, verbose=self.verbose)
 text = ""
-for token in self.client.generate(prompt, **self._default_params()):
+params = {**self._default_params(), **kwargs}
+for token in self.client.generate(prompt, **params):
 if text_callback:
 text_callback(token)
 text += token

@@ -96,6 +96,7 @@ class HuggingFaceEndpoint(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Call out to HuggingFace Hub's inference endpoint.

@@ -114,7 +115,8 @@ class HuggingFaceEndpoint(LLM):
 _model_kwargs = self.model_kwargs or {}

 # payload samples
-parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}
+params = {**_model_kwargs, **kwargs}
+parameter_payload = {"inputs": prompt, "parameters": params}

 # HTTP headers for authorization
 headers = {

@@ -91,6 +91,7 @@ class HuggingFaceHub(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Call out to HuggingFace Hub's inference endpoint.

@@ -107,7 +108,8 @@ class HuggingFaceHub(LLM):
 response = hf("Tell me a joke.")
 """
 _model_kwargs = self.model_kwargs or {}
-response = self.client(inputs=prompt, params=_model_kwargs)
+params = {**_model_kwargs, **kwargs}
+response = self.client(inputs=prompt, params=params)
 if "error" in response:
 raise ValueError(f"Error raised by inference API: {response['error']}")
 if self.client.task == "text-generation":
@@ -164,6 +164,7 @@ class HuggingFacePipeline(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 response = self.pipeline(prompt)
 if self.pipeline.task == "text-generation":

@@ -113,6 +113,7 @@ class HuggingFaceTextGenInference(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 if stop is None:
 stop = self.stop_sequences

@@ -130,6 +131,7 @@ class HuggingFaceTextGenInference(LLM):
 temperature=self.temperature,
 repetition_penalty=self.repetition_penalty,
 seed=self.seed,
+**kwargs,
 )
 # remove stop sequences from the end of the generated text
 for stop_seq in stop:

@@ -60,6 +60,7 @@ class HumanInputLLM(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """
 Displays the prompt to the user and returns their input as a response.

@@ -200,6 +200,7 @@ class LlamaCpp(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Call the Llama model and return the output.

@@ -227,6 +228,7 @@ class LlamaCpp(LLM):
 return combined_text_output
 else:
 params = self._get_parameters(stop)
+params = {**params, **kwargs}
 result = self.client(prompt=prompt, **params)
 return result["choices"][0]["text"]
@@ -48,13 +48,15 @@ class ManifestWrapper(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Call out to LLM through Manifest."""
 if stop is not None and len(stop) != 1:
 raise NotImplementedError(
 f"Manifest currently only supports a single stop token, got {stop}"
 )
-kwargs = self.llm_kwargs or {}
+params = self.llm_kwargs or {}
+params = {**params, **kwargs}
 if stop is not None:
-kwargs["stop_token"] = stop
-return self.client.run(prompt, **kwargs)
+params["stop_token"] = stop
+return self.client.run(prompt, **params)

@@ -76,9 +76,11 @@ class Modal(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Call to Modal endpoint."""
 params = self.model_kwargs or {}
+params = {**params, **kwargs}
 response = requests.post(
 url=self.endpoint_url,
 headers={

@@ -102,6 +102,7 @@ class MosaicML(LLM):
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
 is_retry: bool = False,
+**kwargs: Any,
 ) -> str:
 """Call out to a MosaicML LLM inference endpoint.

@@ -123,6 +124,7 @@ class MosaicML(LLM):

 payload = {"input_strings": [prompt]}
 payload.update(_model_kwargs)
+payload.update(kwargs)

 # HTTP headers for authorization
 headers = {

@@ -117,6 +117,7 @@ class NLPCloud(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Call out to NLPCloud's create endpoint.

@@ -141,7 +142,6 @@ class NLPCloud(LLM):
 end_sequence = stop[0]
 else:
 end_sequence = None
-response = self.client.generation(
-prompt, end_sequence=end_sequence, **self._default_params
-)
+params = {**self._default_params, **kwargs}
+response = self.client.generation(prompt, end_sequence=end_sequence, **params)
 return response["generated_text"]
@@ -273,6 +273,7 @@ class BaseOpenAI(BaseLLM):
 prompts: List[str],
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> LLMResult:
 """Call out to OpenAI's endpoint with k unique prompts.

@@ -290,6 +291,7 @@ class BaseOpenAI(BaseLLM):
 """
 # TODO: write a unit test for this
 params = self._invocation_params
+params = {**params, **kwargs}
 sub_prompts = self.get_sub_prompts(params, prompts, stop)
 choices = []
 token_usage: Dict[str, int] = {}

@@ -326,9 +328,11 @@ class BaseOpenAI(BaseLLM):
 prompts: List[str],
 stop: Optional[List[str]] = None,
 run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> LLMResult:
 """Call out to OpenAI's endpoint async with k unique prompts."""
 params = self._invocation_params
+params = {**params, **kwargs}
 sub_prompts = self.get_sub_prompts(params, prompts, stop)
 choices = []
 token_usage: Dict[str, int] = {}

@@ -771,8 +775,10 @@ class OpenAIChat(BaseLLM):
 prompts: List[str],
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> LLMResult:
 messages, params = self._get_chat_params(prompts, stop)
+params = {**params, **kwargs}
 if self.streaming:
 response = ""
 params["stream"] = True

@@ -804,8 +810,10 @@ class OpenAIChat(BaseLLM):
 prompts: List[str],
 stop: Optional[List[str]] = None,
 run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> LLMResult:
 messages, params = self._get_chat_params(prompts, stop)
+params = {**params, **kwargs}
 if self.streaming:
 response = ""
 params["stream"] = True

@@ -137,9 +137,11 @@ class Petals(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Call the Petals API."""
 params = self._default_params
+params = {**params, **kwargs}
 inputs = self.tokenizer(prompt, return_tensors="pt")["input_ids"]
 outputs = self.client.generate(inputs, **params)
 text = self.tokenizer.decode(outputs[0])
@@ -87,6 +87,7 @@ class PipelineAI(LLM, BaseModel):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Call to Pipeline Cloud endpoint."""
 try:

@@ -98,6 +99,7 @@ class PipelineAI(LLM, BaseModel):
 )
 client = PipelineCloud(token=self.pipeline_api_key)
 params = self.pipeline_kwargs or {}
+params = {**params, **kwargs}

 run = client.run_pipeline(self.pipeline_key, [prompt, params])
 try:

@@ -91,6 +91,7 @@ class PredictionGuard(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Call out to Prediction Guard's model API.
 Args:

@@ -117,6 +118,7 @@ class PredictionGuard(LLM):
 output=self.output,
 temperature=params["temperature"],
 max_tokens=params["max_tokens"],
+**kwargs,
 )
 text = response["choices"][0]["text"]
@@ -1,6 +1,6 @@
 """PromptLayer wrapper."""
 import datetime
-from typing import List, Optional
+from typing import Any, List, Optional

 from langchain.callbacks.manager import (
 AsyncCallbackManagerForLLMRun,

@@ -42,6 +42,7 @@ class PromptLayerOpenAI(OpenAI):
 prompts: List[str],
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> LLMResult:
 """Call OpenAI generate and then call PromptLayer API to log the request."""
 from promptlayer.utils import get_api_key, promptlayer_api_request

@@ -56,11 +57,12 @@ class PromptLayerOpenAI(OpenAI):
 "text": generation.text,
 "llm_output": generated_responses.llm_output,
 }
+params = {**self._identifying_params, **kwargs}
 pl_request_id = promptlayer_api_request(
 "langchain.PromptLayerOpenAI",
 "langchain",
 [prompt],
-self._identifying_params,
+params,
 self.pl_tags,
 resp,
 request_start_time,

@@ -81,6 +83,7 @@ class PromptLayerOpenAI(OpenAI):
 prompts: List[str],
 stop: Optional[List[str]] = None,
 run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> LLMResult:
 from promptlayer.utils import get_api_key, promptlayer_api_request_async

@@ -94,11 +97,12 @@ class PromptLayerOpenAI(OpenAI):
 "text": generation.text,
 "llm_output": generated_responses.llm_output,
 }
+params = {**self._identifying_params, **kwargs}
 pl_request_id = await promptlayer_api_request_async(
 "langchain.PromptLayerOpenAI.async",
 "langchain",
 [prompt],
-self._identifying_params,
+params,
 self.pl_tags,
 resp,
 request_start_time,

@@ -147,6 +151,7 @@ class PromptLayerOpenAIChat(OpenAIChat):
 prompts: List[str],
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> LLMResult:
 """Call OpenAI generate and then call PromptLayer API to log the request."""
 from promptlayer.utils import get_api_key, promptlayer_api_request

@@ -161,11 +166,12 @@ class PromptLayerOpenAIChat(OpenAIChat):
 "text": generation.text,
 "llm_output": generated_responses.llm_output,
 }
+params = {**self._identifying_params, **kwargs}
 pl_request_id = promptlayer_api_request(
 "langchain.PromptLayerOpenAIChat",
 "langchain",
 [prompt],
-self._identifying_params,
+params,
 self.pl_tags,
 resp,
 request_start_time,

@@ -186,6 +192,7 @@ class PromptLayerOpenAIChat(OpenAIChat):
 prompts: List[str],
 stop: Optional[List[str]] = None,
 run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> LLMResult:
 from promptlayer.utils import get_api_key, promptlayer_api_request_async

@@ -199,11 +206,12 @@ class PromptLayerOpenAIChat(OpenAIChat):
 "text": generation.text,
 "llm_output": generated_responses.llm_output,
 }
+params = {**self._identifying_params, **kwargs}
 pl_request_id = await promptlayer_api_request_async(
 "langchain.PromptLayerOpenAIChat.async",
 "langchain",
 [prompt],
-self._identifying_params,
+params,
 self.pl_tags,
 resp,
 request_start_time,
@@ -85,6 +85,7 @@ class Replicate(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Call to replicate endpoint."""
 try:

@@ -110,6 +111,6 @@ class Replicate(LLM):
 first_input_name = input_properties[0][0]

 inputs = {first_input_name: prompt, **self.input}
-iterator = replicate_python.run(self.model, input={**inputs})
+iterator = replicate_python.run(self.model, input={**inputs, **kwargs})

 return "".join([output for output in iterator])

@@ -210,6 +210,7 @@ class RWKV(LLM, BaseModel):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 r"""RWKV generation

@@ -207,6 +207,7 @@ class SagemakerEndpoint(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Call out to Sagemaker inference endpoint.

@@ -223,6 +224,7 @@ class SagemakerEndpoint(LLM):
 response = se("Tell me a joke.")
 """
 _model_kwargs = self.model_kwargs or {}
+_model_kwargs = {**_model_kwargs, **kwargs}
 _endpoint_kwargs = self.endpoint_kwargs or {}

 body = self.content_handler.transform_input(prompt, _model_kwargs)
@@ -214,5 +214,8 @@ class SelfHostedPipeline(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
-return self.client(pipeline=self.pipeline_ref, prompt=prompt, stop=stop)
+return self.client(
+pipeline=self.pipeline_ref, prompt=prompt, stop=stop, **kwargs
+)

@@ -207,5 +207,8 @@ class SelfHostedHuggingFaceLLM(SelfHostedPipeline):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
-return self.client(pipeline=self.pipeline_ref, prompt=prompt, stop=stop)
+return self.client(
+pipeline=self.pipeline_ref, prompt=prompt, stop=stop, **kwargs
+)

@@ -86,6 +86,7 @@ class StochasticAI(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Call out to StochasticAI's complete endpoint.

@@ -102,6 +103,7 @@ class StochasticAI(LLM):
 response = StochasticAI("Tell me a joke.")
 """
 params = self.model_kwargs or {}
+params = {**params, **kwargs}
 response_post = requests.post(
 url=self.api_url,
 json={"prompt": prompt, "params": params},
@@ -50,8 +50,11 @@ class _VertexAICommon(BaseModel):
 }
 return {**base_params}

-def _predict(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-res = self.client.predict(prompt, **self._default_params)
+def _predict(
+self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
+) -> str:
+params = {**self._default_params, **kwargs}
+res = self.client.predict(prompt, **params)
 return self._enforce_stop_words(res.text, stop)

 def _enforce_stop_words(self, text: str, stop: Optional[List[str]] = None) -> str:

@@ -100,6 +103,7 @@ class VertexAI(_VertexAICommon, LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Call Vertex model to get predictions based on the prompt.

@@ -111,4 +115,4 @@ class VertexAI(_VertexAICommon, LLM):
 Returns:
 The string generated by the model.
 """
-return self._predict(prompt, stop)
+return self._predict(prompt, stop, **kwargs)

@@ -118,6 +118,7 @@ class Writer(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Call out to Writer's completions endpoint.

@@ -141,7 +142,7 @@ class Writer(LLM):
 f"/organization/{self.writer_org_id}"
 f"/model/{self.model_id}/completions"
 )
-
+params = {**self._default_params, **kwargs}
 response = requests.post(
 url=base_url,
 headers={

@@ -149,7 +150,7 @@ class Writer(LLM):
 "Content-Type": "application/json",
 "Accept": "application/json",
 },
-json={"prompt": prompt, **self._default_params},
+json={"prompt": prompt, **params},
 )
 text = response.text
 if stop is not None:
@@ -20,6 +20,7 @@ class FakeListLLM(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Increment counter, and then return response in that index."""
 self.i += 1

@@ -38,6 +38,7 @@ class FakeListLLM(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Increment counter, and then return response in that index."""
 self.i += 1

@@ -1,5 +1,5 @@
 """Test HyDE."""
-from typing import List, Optional
+from typing import Any, List, Optional

 import numpy as np

@@ -36,6 +36,7 @@ class FakeLLM(BaseLLM):
 prompts: List[str],
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> LLMResult:
 return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])

@@ -44,6 +45,7 @@ class FakeLLM(BaseLLM):
 prompts: List[str],
 stop: Optional[List[str]] = None,
 run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> LLMResult:
 return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])

@@ -15,6 +15,7 @@ class FakeLLM(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 """Return `foo` if longer than 10000 words, else `bar`."""
 if len(prompt) > 10000:

@@ -17,6 +17,7 @@ class FakeChatModel(SimpleChatModel):
 messages: List[BaseMessage],
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 return "fake response"

@@ -25,6 +26,7 @@ class FakeChatModel(SimpleChatModel):
 messages: List[BaseMessage],
 stop: Optional[List[str]] = None,
 run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> ChatResult:
 output_str = "fake response"
 message = AIMessage(content=output_str)

@@ -34,6 +34,7 @@ class FakeLLM(LLM):
 prompt: str,
 stop: Optional[List[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
+**kwargs: Any,
 ) -> str:
 if self.sequential_responses:
 return self._get_next_response_in_sequence