support kwargs (#5990)

Harrison Chase 2023-06-11 10:09:22 -07:00 committed by GitHub
parent b934677a81
commit 704d56e241
58 changed files with 289 additions and 88 deletions
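In practical terms, this change threads **kwargs from the public entry points (__call__, generate, predict, and their async variants) down to each provider's _generate/_call, so provider-specific parameters can be supplied per call instead of only at construction time. A minimal sketch of the intended usage, assuming the OpenAI wrapper of this era; the model name and the temperature override are illustrative, not taken from the commit:

from langchain.llms import OpenAI

llm = OpenAI(model_name="text-davinci-003")
# the call-time temperature is merged into the invocation params via **kwargs
# and overrides the instance-level value for this call only
print(llm("Tell me a joke", temperature=0.1))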

View File

@@ -2,7 +2,7 @@
 from __future__ import annotations
 from abc import ABC, abstractmethod
-from typing import List, Optional, Sequence, Set
+from typing import Any, List, Optional, Sequence, Set
 from pydantic import BaseModel
@@ -36,6 +36,7 @@ class BaseLanguageModel(BaseModel, ABC):
         prompts: List[PromptValue],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Take in a list of prompt values and return an LLMResult."""
@@ -45,26 +46,39 @@ class BaseLanguageModel(BaseModel, ABC):
         prompts: List[PromptValue],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Take in a list of prompt values and return an LLMResult."""
     @abstractmethod
-    def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+    def predict(
+        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+    ) -> str:
         """Predict text from text."""
     @abstractmethod
     def predict_messages(
-        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+        self,
+        messages: List[BaseMessage],
+        *,
+        stop: Optional[Sequence[str]] = None,
+        **kwargs: Any,
     ) -> BaseMessage:
         """Predict message from messages."""
     @abstractmethod
-    async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+    async def apredict(
+        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+    ) -> str:
         """Predict text from text."""
     @abstractmethod
     async def apredict_messages(
-        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+        self,
+        messages: List[BaseMessage],
+        *,
+        stop: Optional[Sequence[str]] = None,
+        **kwargs: Any,
    ) -> BaseMessage:
         """Predict message from messages."""

View File

@@ -94,9 +94,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         prompt = self._convert_messages_to_prompt(messages)
-        params: Dict[str, Any] = {"prompt": prompt, **self._default_params}
+        params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
         if stop:
             params["stop_sequences"] = stop
@@ -121,9 +122,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         prompt = self._convert_messages_to_prompt(messages)
-        params: Dict[str, Any] = {"prompt": prompt, **self._default_params}
+        params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
         if stop:
             params["stop_sequences"] = stop

View File

@@ -64,6 +64,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Top Level call"""
@@ -82,7 +83,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         )
         try:
             results = [
-                self._generate(m, stop=stop, run_manager=run_manager)
+                self._generate(m, stop=stop, run_manager=run_manager, **kwargs)
                 if new_arg_supported
                 else self._generate(m, stop=stop)
                 for m in messages
@@ -103,6 +104,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Top Level call"""
         params = self.dict()
@@ -121,7 +123,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         try:
             results = await asyncio.gather(
                 *[
-                    self._agenerate(m, stop=stop, run_manager=run_manager)
+                    self._agenerate(m, stop=stop, run_manager=run_manager, **kwargs)
                     if new_arg_supported
                     else self._agenerate(m, stop=stop)
                     for m in messages
@@ -143,18 +145,22 @@ class BaseChatModel(BaseLanguageModel, ABC):
         prompts: List[PromptValue],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         prompt_messages = [p.to_messages() for p in prompts]
-        return self.generate(prompt_messages, stop=stop, callbacks=callbacks)
+        return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)
     async def agenerate_prompt(
         self,
         prompts: List[PromptValue],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         prompt_messages = [p.to_messages() for p in prompts]
-        return await self.agenerate(prompt_messages, stop=stop, callbacks=callbacks)
+        return await self.agenerate(
+            prompt_messages, stop=stop, callbacks=callbacks, **kwargs
+        )
     @abstractmethod
     def _generate(
@@ -162,6 +168,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         """Top Level call"""
@@ -171,6 +178,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         """Top Level call"""
@@ -193,18 +201,25 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> BaseMessage:
-        result = await self.agenerate([messages], stop=stop, callbacks=callbacks)
+        result = await self.agenerate(
+            [messages], stop=stop, callbacks=callbacks, **kwargs
+        )
         generation = result.generations[0][0]
         if isinstance(generation, ChatGeneration):
             return generation.message
         else:
             raise ValueError("Unexpected generation type")
-    def call_as_llm(self, message: str, stop: Optional[List[str]] = None) -> str:
-        return self.predict(message, stop=stop)
+    def call_as_llm(
+        self, message: str, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> str:
+        return self.predict(message, stop=stop, **kwargs)
-    def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+    def predict(
+        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+    ) -> str:
         if stop is None:
             _stop = None
         else:
@@ -213,30 +228,42 @@ class BaseChatModel(BaseLanguageModel, ABC):
         return result.content
     def predict_messages(
-        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+        self,
+        messages: List[BaseMessage],
+        *,
+        stop: Optional[Sequence[str]] = None,
+        **kwargs: Any,
     ) -> BaseMessage:
         if stop is None:
             _stop = None
         else:
             _stop = list(stop)
-        return self(messages, stop=_stop)
+        return self(messages, stop=_stop, **kwargs)
-    async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+    async def apredict(
+        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+    ) -> str:
         if stop is None:
             _stop = None
         else:
             _stop = list(stop)
-        result = await self._call_async([HumanMessage(content=text)], stop=_stop)
+        result = await self._call_async(
+            [HumanMessage(content=text)], stop=_stop, **kwargs
+        )
         return result.content
     async def apredict_messages(
-        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+        self,
+        messages: List[BaseMessage],
+        *,
+        stop: Optional[Sequence[str]] = None,
+        **kwargs: Any,
    ) -> BaseMessage:
         if stop is None:
             _stop = None
         else:
             _stop = list(stop)
-        return await self._call_async(messages, stop=_stop)
+        return await self._call_async(messages, stop=_stop, **kwargs)
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
@@ -261,8 +288,9 @@ class SimpleChatModel(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
-        output_str = self._call(messages, stop=stop, run_manager=run_manager)
+        output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
         message = AIMessage(content=output_str)
         generation = ChatGeneration(message=message)
         return ChatResult(generations=[generation])
@@ -273,6 +301,7 @@ class SimpleChatModel(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Simpler interface."""
@@ -281,6 +310,9 @@ class SimpleChatModel(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
-        func = partial(self._generate, messages, stop=stop, run_manager=run_manager)
+        func = partial(
+            self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
+        )
         return await asyncio.get_event_loop().run_in_executor(None, func)
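For custom chat models, the new signatures mean any keyword passed to generate() reaches the subclass untouched. A minimal sketch against the API shown above; the EchoChat class and its prefix keyword are hypothetical, purely to illustrate the flow:

from typing import Any, List, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.schema import BaseMessage, HumanMessage


class EchoChat(SimpleChatModel):
    """Hypothetical model used only to show the kwargs pass-through."""

    @property
    def _llm_type(self) -> str:
        return "echo-chat"

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # kwargs given to generate() arrive here unchanged
        prefix = kwargs.get("prefix", "")
        return prefix + messages[-1].content


chat = EchoChat()
result = chat.generate([[HumanMessage(content="hi")]], prefix=">> ")
print(result.generations[0][0].text)  # ">> hi"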

View File

@@ -280,6 +280,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         prompt = _messages_to_prompt_dict(messages)
@@ -291,6 +292,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
            top_p=self.top_p,
            top_k=self.top_k,
            candidate_count=self.n,
+           **kwargs,
        )
         return _response_to_result(response, stop)
@@ -300,6 +302,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         prompt = _messages_to_prompt_dict(messages)

View File

@@ -302,8 +302,10 @@ class ChatOpenAI(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         message_dicts, params = self._create_message_dicts(messages, stop)
+        params = {**params, **kwargs}
         if self.streaming:
             inner_completion = ""
             role = "assistant"
@@ -348,8 +350,10 @@ class ChatOpenAI(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         message_dicts, params = self._create_message_dicts(messages, stop)
+        params = {**params, **kwargs}
         if self.streaming:
             inner_completion = ""
             role = "assistant"

View File

@@ -42,6 +42,7 @@ class PromptLayerChatOpenAI(ChatOpenAI):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any
     ) -> ChatResult:
         """Call ChatOpenAI generate and then call PromptLayer API to log the request."""
         from promptlayer.utils import get_api_key, promptlayer_api_request
@@ -54,6 +55,7 @@ class PromptLayerChatOpenAI(ChatOpenAI):
            response_dict, params = super()._create_message_dicts(
                [generation.message], stop
            )
+           params = {**params, **kwargs}
            pl_request_id = promptlayer_api_request(
                "langchain.PromptLayerChatOpenAI",
                "langchain",
@@ -79,6 +81,7 @@ class PromptLayerChatOpenAI(ChatOpenAI):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any
     ) -> ChatResult:
         """Call ChatOpenAI agenerate and then call PromptLayer to log."""
         from promptlayer.utils import get_api_key, promptlayer_api_request_async
@@ -91,6 +94,7 @@ class PromptLayerChatOpenAI(ChatOpenAI):
            response_dict, params = super()._create_message_dicts(
                [generation.message], stop
            )
+           params = {**params, **kwargs}
            pl_request_id = await promptlayer_api_request_async(
                "langchain.PromptLayerChatOpenAI.async",
                "langchain",

View File

@@ -1,6 +1,6 @@
 """Wrapper around Google VertexAI chat-based models."""
 from dataclasses import dataclass, field
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
 from pydantic import root_validator
@@ -93,6 +93,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         """Generate next turn in the conversation.
@@ -119,7 +120,8 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         history = _parse_chat_history(messages[:-1])
         context = history.system_message.content if history.system_message else None
-        chat = self.client.start_chat(context=context, **self._default_params)
+        params = {**self._default_params, **kwargs}
+        chat = self.client.start_chat(context=context, **params)
         for pair in history.history:
             chat._history.append((pair.question.content, pair.answer.content))
         response = chat.send_message(question.content, **self._default_params)
@@ -131,6 +133,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         raise NotImplementedError(
             """Vertex AI doesn't support async requests at the moment."""

View File

@@ -2,7 +2,7 @@
 from __future__ import annotations
 import json
-from typing import TYPE_CHECKING, List, Optional, cast
+from typing import TYPE_CHECKING, Any, List, Optional, cast
 from pydantic import Field, root_validator
@@ -42,6 +42,7 @@ class JsonFormer(HuggingFacePipeline):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         jsonformer = import_jsonformer()
         from transformers import Text2TextGenerationPipeline

View File

@@ -1,7 +1,7 @@
 """Experimental implementation of RELLM wrapped LLM."""
 from __future__ import annotations
-from typing import TYPE_CHECKING, List, Optional, cast
+from typing import TYPE_CHECKING, Any, List, Optional, cast
 from pydantic import Field, root_validator
@@ -47,6 +47,7 @@ class RELLM(HuggingFacePipeline):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         rellm = import_rellm()
         from transformers import Text2TextGenerationPipeline

View File

@@ -112,6 +112,7 @@ class AI21(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to AI21's complete endpoint.
@@ -140,10 +141,11 @@ class AI21(LLM):
             base_url = "https://api.ai21.com/studio/v1/experimental"
         else:
             base_url = "https://api.ai21.com/studio/v1"
+        params = {**self._default_params, **kwargs}
         response = requests.post(
             url=f"{base_url}/{self.model}/complete",
             headers={"Authorization": f"Bearer {self.ai21_api_key}"},
-            json={"prompt": prompt, "stopSequences": stop, **self._default_params},
+            json={"prompt": prompt, "stopSequences": stop, **params},
         )
         if response.status_code != 200:
             optional_detail = response.json().get("error")

View File

@@ -206,6 +206,7 @@ class AlephAlpha(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to Aleph Alpha's completion endpoint.
@@ -232,6 +233,7 @@ class AlephAlpha(LLM):
             params["stop_sequences"] = self.stop_sequences
         else:
             params["stop_sequences"] = stop
+        params = {**params, **kwargs}
         request = CompletionRequest(prompt=Prompt.from_text(prompt), **params)
         response = self.client.complete(model=self.model, request=request)
         text = response.completions[0].completion

View File

@@ -162,6 +162,7 @@ class Anthropic(LLM, _AnthropicCommon):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         r"""Call out to Anthropic's completion endpoint.
@@ -181,11 +182,12 @@ class Anthropic(LLM, _AnthropicCommon):
         """
         stop = self._get_anthropic_stop(stop)
+        params = {**self._default_params, **kwargs}
         if self.streaming:
             stream_resp = self.client.completion_stream(
                 prompt=self._wrap_prompt(prompt),
                 stop_sequences=stop,
-                **self._default_params,
+                **params,
             )
             current_completion = ""
             for data in stream_resp:
@@ -197,7 +199,7 @@ class Anthropic(LLM, _AnthropicCommon):
         response = self.client.completion(
             prompt=self._wrap_prompt(prompt),
             stop_sequences=stop,
-            **self._default_params,
+            **params,
         )
         return response["completion"]
@@ -206,14 +208,16 @@ class Anthropic(LLM, _AnthropicCommon):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to Anthropic's completion endpoint asynchronously."""
         stop = self._get_anthropic_stop(stop)
+        params = {**self._default_params, **kwargs}
         if self.streaming:
             stream_resp = await self.client.acompletion_stream(
                 prompt=self._wrap_prompt(prompt),
                 stop_sequences=stop,
-                **self._default_params,
+                **params,
             )
             current_completion = ""
             async for data in stream_resp:
@@ -225,7 +229,7 @@ class Anthropic(LLM, _AnthropicCommon):
         response = await self.client.acompletion(
             prompt=self._wrap_prompt(prompt),
             stop_sequences=stop,
-            **self._default_params,
+            **params,
         )
         return response["completion"]

View File

@@ -88,6 +88,7 @@ class Anyscale(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to Anyscale Service endpoint.
         Args:

View File

@@ -105,6 +105,7 @@ class Aviary(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to Aviary
         Args:

View File

@@ -87,6 +87,7 @@ class Banana(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call to Banana endpoint."""
         try:
@@ -97,6 +98,7 @@ class Banana(LLM):
                 "Please install it with `pip install banana-dev`."
             )
         params = self.model_kwargs or {}
+        params = {**params, **kwargs}
         api_key = self.banana_api_key
         model_key = self.model_key
         model_inputs = {

View File

@@ -113,6 +113,7 @@ class BaseLLM(BaseLanguageModel, ABC):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompts."""
@@ -122,6 +123,7 @@ class BaseLLM(BaseLanguageModel, ABC):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompts."""
@@ -130,24 +132,29 @@ class BaseLLM(BaseLanguageModel, ABC):
         prompts: List[PromptValue],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         prompt_strings = [p.to_string() for p in prompts]
-        return self.generate(prompt_strings, stop=stop, callbacks=callbacks)
+        return self.generate(prompt_strings, stop=stop, callbacks=callbacks, **kwargs)
     async def agenerate_prompt(
         self,
         prompts: List[PromptValue],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         prompt_strings = [p.to_string() for p in prompts]
-        return await self.agenerate(prompt_strings, stop=stop, callbacks=callbacks)
+        return await self.agenerate(
+            prompt_strings, stop=stop, callbacks=callbacks, **kwargs
+        )
     def generate(
         self,
         prompts: List[str],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
         # If string is passed in directly no errors will be raised but outputs will
@@ -183,9 +190,11 @@ class BaseLLM(BaseLanguageModel, ABC):
             )
             try:
                 output = (
-                    self._generate(prompts, stop=stop, run_manager=run_manager)
+                    self._generate(
+                        prompts, stop=stop, run_manager=run_manager, **kwargs
+                    )
                     if new_arg_supported
-                    else self._generate(prompts, stop=stop)
+                    else self._generate(prompts, stop=stop, **kwargs)
                 )
             except (KeyboardInterrupt, Exception) as e:
                 run_manager.on_llm_error(e)
@@ -202,9 +211,11 @@ class BaseLLM(BaseLanguageModel, ABC):
             )
             try:
                 new_results = (
-                    self._generate(missing_prompts, stop=stop, run_manager=run_manager)
+                    self._generate(
+                        missing_prompts, stop=stop, run_manager=run_manager, **kwargs
+                    )
                     if new_arg_supported
-                    else self._generate(missing_prompts, stop=stop)
+                    else self._generate(missing_prompts, stop=stop, **kwargs)
                 )
             except (KeyboardInterrupt, Exception) as e:
                 run_manager.on_llm_error(e)
@@ -227,6 +238,7 @@ class BaseLLM(BaseLanguageModel, ABC):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
         params = self.dict()
@@ -255,9 +267,11 @@ class BaseLLM(BaseLanguageModel, ABC):
             )
             try:
                 output = (
-                    await self._agenerate(prompts, stop=stop, run_manager=run_manager)
+                    await self._agenerate(
+                        prompts, stop=stop, run_manager=run_manager, **kwargs
+                    )
                     if new_arg_supported
-                    else await self._agenerate(prompts, stop=stop)
+                    else await self._agenerate(prompts, stop=stop, **kwargs)
                 )
             except (KeyboardInterrupt, Exception) as e:
                 await run_manager.on_llm_error(e, verbose=self.verbose)
@@ -275,10 +289,10 @@ class BaseLLM(BaseLanguageModel, ABC):
             try:
                 new_results = (
                     await self._agenerate(
-                        missing_prompts, stop=stop, run_manager=run_manager
+                        missing_prompts, stop=stop, run_manager=run_manager, **kwargs
                     )
                     if new_arg_supported
-                    else await self._agenerate(missing_prompts, stop=stop)
+                    else await self._agenerate(missing_prompts, stop=stop, **kwargs)
                 )
             except (KeyboardInterrupt, Exception) as e:
                 await run_manager.on_llm_error(e)
@@ -297,7 +311,11 @@ class BaseLLM(BaseLanguageModel, ABC):
         return LLMResult(generations=generations, llm_output=llm_output, run=run_info)
     def __call__(
-        self, prompt: str, stop: Optional[List[str]] = None, callbacks: Callbacks = None
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> str:
         """Check Cache and run the LLM on the given prompt and input."""
         if not isinstance(prompt, str):
@@ -307,52 +325,70 @@ class BaseLLM(BaseLanguageModel, ABC):
                 "`generate` instead."
             )
         return (
-            self.generate([prompt], stop=stop, callbacks=callbacks)
+            self.generate([prompt], stop=stop, callbacks=callbacks, **kwargs)
             .generations[0][0]
             .text
         )
     async def _call_async(
-        self, prompt: str, stop: Optional[List[str]] = None, callbacks: Callbacks = None
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        callbacks: Callbacks = None,
+        **kwargs: Any,
    ) -> str:
         """Check Cache and run the LLM on the given prompt and input."""
-        result = await self.agenerate([prompt], stop=stop, callbacks=callbacks)
+        result = await self.agenerate(
+            [prompt], stop=stop, callbacks=callbacks, **kwargs
+        )
         return result.generations[0][0].text
-    def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+    def predict(
+        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+    ) -> str:
         if stop is None:
             _stop = None
         else:
             _stop = list(stop)
-        return self(text, stop=_stop)
+        return self(text, stop=_stop, **kwargs)
     def predict_messages(
-        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+        self,
+        messages: List[BaseMessage],
+        *,
+        stop: Optional[Sequence[str]] = None,
+        **kwargs: Any,
     ) -> BaseMessage:
         text = get_buffer_string(messages)
         if stop is None:
             _stop = None
         else:
             _stop = list(stop)
-        content = self(text, stop=_stop)
+        content = self(text, stop=_stop, **kwargs)
         return AIMessage(content=content)
-    async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+    async def apredict(
+        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+    ) -> str:
         if stop is None:
             _stop = None
         else:
             _stop = list(stop)
-        return await self._call_async(text, stop=_stop)
+        return await self._call_async(text, stop=_stop, **kwargs)
     async def apredict_messages(
-        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+        self,
+        messages: List[BaseMessage],
+        *,
+        stop: Optional[Sequence[str]] = None,
+        **kwargs: Any,
    ) -> BaseMessage:
         text = get_buffer_string(messages)
         if stop is None:
             _stop = None
         else:
             _stop = list(stop)
-        content = await self._call_async(text, stop=_stop)
+        content = await self._call_async(text, stop=_stop, **kwargs)
         return AIMessage(content=content)
     @property
@@ -422,6 +458,7 @@ class LLM(BaseLLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Run the LLM on the given prompt and input."""
@@ -430,6 +467,7 @@ class LLM(BaseLLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Run the LLM on the given prompt and input."""
         raise NotImplementedError("Async generation not implemented for this LLM.")
@@ -439,6 +477,7 @@ class LLM(BaseLLM):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
         # TODO: add caching here.
@@ -446,9 +485,9 @@ class LLM(BaseLLM):
         new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
         for prompt in prompts:
             text = (
-                self._call(prompt, stop=stop, run_manager=run_manager)
+                self._call(prompt, stop=stop, run_manager=run_manager, **kwargs)
                 if new_arg_supported
-                else self._call(prompt, stop=stop)
+                else self._call(prompt, stop=stop, **kwargs)
             )
             generations.append([Generation(text=text)])
         return LLMResult(generations=generations)
@@ -458,15 +497,16 @@ class LLM(BaseLLM):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
         generations = []
         new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
         for prompt in prompts:
             text = (
-                await self._acall(prompt, stop=stop, run_manager=run_manager)
+                await self._acall(prompt, stop=stop, run_manager=run_manager, **kwargs)
                 if new_arg_supported
-                else await self._acall(prompt, stop=stop)
+                else await self._acall(prompt, stop=stop, **kwargs)
             )
             generations.append([Generation(text=text)])
         return LLMResult(generations=generations)
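The same flow applies to plain LLMs: anything passed to __call__ or predict() now reaches the subclass's _call. A minimal sketch; the StaticLLM class and its suffix keyword are hypothetical and only demonstrate where call-time kwargs end up:

from typing import Any, List, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM


class StaticLLM(LLM):
    """Hypothetical LLM used only to show the kwargs pass-through."""

    @property
    def _llm_type(self) -> str:
        return "static"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # keyword arguments given at call time travel
        # through __call__ -> generate -> _generate -> _call
        return prompt + kwargs.get("suffix", "")


llm = StaticLLM()
print(llm("hello", suffix="!"))  # "hello!"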

View File

@@ -54,6 +54,7 @@ class Baseten(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call to Baseten deployed model endpoint."""
         try:

View File

@@ -251,10 +251,12 @@ class Beam(LLM):
         prompt: str,
         stop: Optional[list] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call to Beam."""
         url = "https://apps.beam.cloud/" + self.app_id if self.app_id else self.url
         payload = {"prompt": prompt, "max_length": self.max_length}
+        payload.update(kwargs)
         headers = {
             "Accept": "*/*",
             "Accept-Encoding": "gzip, deflate",

View File

@@ -155,6 +155,7 @@ class Bedrock(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to Bedrock service model.
@@ -173,10 +174,8 @@ class Bedrock(LLM):
         _model_kwargs = self.model_kwargs or {}
         provider = self.model_id.split(".")[0]
-        input_body = LLMInputOutputAdapter.prepare_input(
-            provider, prompt, _model_kwargs
-        )
+        params = {**_model_kwargs, **kwargs}
+        input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
         body = json.dumps(input_body)
         accept = "application/json"
         contentType = "application/json"

View File

@@ -88,6 +88,7 @@ class CerebriumAI(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call to CerebriumAI endpoint."""
         try:
@@ -100,7 +101,9 @@ class CerebriumAI(LLM):
         params = self.model_kwargs or {}
         response = model_api_request(
-            self.endpoint_url, {"prompt": prompt, **params}, self.cerebriumai_api_key
+            self.endpoint_url,
+            {"prompt": prompt, **params, **kwargs},
+            self.cerebriumai_api_key,
         )
         text = response["data"]["result"]
         if stop is not None:

View File

@@ -145,6 +145,7 @@ class Cohere(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to Cohere's generate endpoint.
@@ -167,7 +168,7 @@ class Cohere(LLM):
             params["stop_sequences"] = self.stop
         else:
             params["stop_sequences"] = stop
+        params = {**params, **kwargs}
         response = completion_with_retry(
             self, model=self.model, prompt=prompt, **params
         )

View File

@@ -81,6 +81,7 @@ class CTransformers(LLM):
         prompt: str,
         stop: Optional[Sequence[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Generate text from a prompt.

View File

@@ -303,12 +303,14 @@ class Databricks(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Queries the LLM endpoint with the given prompt and stop sequence."""
         # TODO: support callbacks
         request = {"prompt": prompt, "stop": stop}
+        request.update(kwargs)
         if self.model_kwargs:
             request.update(self.model_kwargs)

View File

@@ -66,6 +66,7 @@ class DeepInfra(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to DeepInfra's inference API endpoint.
@@ -82,6 +83,7 @@ class DeepInfra(LLM):
             response = di("Tell me a joke.")
         """
         _model_kwargs = self.model_kwargs or {}
+        _model_kwargs = {**_model_kwargs, **kwargs}
         # HTTP headers for authorization
         headers = {
             "Authorization": f"bearer {self.deepinfra_api_token}",

View File

@@ -24,6 +24,7 @@ class FakeListLLM(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Return next response"""
         response = self.responses[self.i]
@@ -35,6 +36,7 @@ class FakeListLLM(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Return next response"""
         response = self.responses[self.i]

View File

@@ -87,6 +87,7 @@ class ForefrontAI(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to ForefrontAI's complete endpoint.
@@ -108,7 +109,7 @@ class ForefrontAI(LLM):
                "Authorization": f"Bearer {self.forefrontai_api_key}",
                "Content-Type": "application/json",
            },
-           json={"text": prompt, **self._default_params},
+           json={"text": prompt, **self._default_params, **kwargs},
        )
         response_json = response.json()
         text = response_json["result"][0]["completion"]

View File

@@ -134,6 +134,7 @@ class GooglePalm(BaseLLM, BaseModel):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> LLMResult:
         generations = []
         for prompt in prompts:
@@ -147,6 +148,7 @@ class GooglePalm(BaseLLM, BaseModel):
                top_k=self.top_k,
                max_output_tokens=self.max_output_tokens,
                candidate_count=self.n,
+               **kwargs,
            )
             prompt_generations = []
@@ -163,6 +165,7 @@ class GooglePalm(BaseLLM, BaseModel):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> LLMResult:
         raise NotImplementedError()

View File

@@ -137,6 +137,7 @@ class GooseAI(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call the GooseAI API."""
         params = self._default_params
@@ -145,6 +146,8 @@ class GooseAI(LLM):
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
+        params = {**params, **kwargs}
         response = self.client.create(engine=self.model_name, prompt=prompt, **params)
         text = response.choices[0].text
         return text

View File

@@ -183,6 +183,7 @@ class GPT4All(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         r"""Call out to GPT4All's generate method.
@@ -203,7 +204,8 @@ class GPT4All(LLM):
         if run_manager:
             text_callback = partial(run_manager.on_llm_new_token, verbose=self.verbose)
         text = ""
-        for token in self.client.generate(prompt, **self._default_params()):
+        params = {**self._default_params(), **kwargs}
+        for token in self.client.generate(prompt, **params):
             if text_callback:
                 text_callback(token)
             text += token

View File

@@ -96,6 +96,7 @@ class HuggingFaceEndpoint(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to HuggingFace Hub's inference endpoint.
@@ -114,7 +115,8 @@ class HuggingFaceEndpoint(LLM):
         _model_kwargs = self.model_kwargs or {}
         # payload samples
-        parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}
+        params = {**_model_kwargs, **kwargs}
+        parameter_payload = {"inputs": prompt, "parameters": params}
         # HTTP headers for authorization
         headers = {

View File

@@ -91,6 +91,7 @@ class HuggingFaceHub(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to HuggingFace Hub's inference endpoint.
@@ -107,7 +108,8 @@ class HuggingFaceHub(LLM):
             response = hf("Tell me a joke.")
         """
         _model_kwargs = self.model_kwargs or {}
-        response = self.client(inputs=prompt, params=_model_kwargs)
+        params = {**_model_kwargs, **kwargs}
+        response = self.client(inputs=prompt, params=params)
         if "error" in response:
             raise ValueError(f"Error raised by inference API: {response['error']}")
         if self.client.task == "text-generation":

View File

@@ -164,6 +164,7 @@ class HuggingFacePipeline(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         response = self.pipeline(prompt)
         if self.pipeline.task == "text-generation":

View File

@@ -113,6 +113,7 @@ class HuggingFaceTextGenInference(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         if stop is None:
             stop = self.stop_sequences
@@ -130,6 +131,7 @@ class HuggingFaceTextGenInference(LLM):
            temperature=self.temperature,
            repetition_penalty=self.repetition_penalty,
            seed=self.seed,
+           **kwargs,
        )
         # remove stop sequences from the end of the generated text
         for stop_seq in stop:

View File

@@ -60,6 +60,7 @@ class HumanInputLLM(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """
         Displays the prompt to the user and returns their input as a response.

View File

@@ -200,6 +200,7 @@ class LlamaCpp(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call the Llama model and return the output.
@@ -227,6 +228,7 @@ class LlamaCpp(LLM):
             return combined_text_output
         else:
             params = self._get_parameters(stop)
+            params = {**params, **kwargs}
             result = self.client(prompt=prompt, **params)
             return result["choices"][0]["text"]

View File

@@ -48,13 +48,15 @@ class ManifestWrapper(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to LLM through Manifest."""
         if stop is not None and len(stop) != 1:
             raise NotImplementedError(
                 f"Manifest currently only supports a single stop token, got {stop}"
             )
-        kwargs = self.llm_kwargs or {}
+        params = self.llm_kwargs or {}
+        params = {**params, **kwargs}
         if stop is not None:
-            kwargs["stop_token"] = stop
-        return self.client.run(prompt, **kwargs)
+            params["stop_token"] = stop
+        return self.client.run(prompt, **params)

View File

@@ -76,9 +76,11 @@ class Modal(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call to Modal endpoint."""
         params = self.model_kwargs or {}
+        params = {**params, **kwargs}
         response = requests.post(
             url=self.endpoint_url,
             headers={

View File

@@ -102,6 +102,7 @@ class MosaicML(LLM):
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         is_retry: bool = False,
+        **kwargs: Any,
     ) -> str:
         """Call out to a MosaicML LLM inference endpoint.
@@ -123,6 +124,7 @@ class MosaicML(LLM):
         payload = {"input_strings": [prompt]}
         payload.update(_model_kwargs)
+        payload.update(kwargs)
         # HTTP headers for authorization
         headers = {

View File

@@ -117,6 +117,7 @@ class NLPCloud(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to NLPCloud's create endpoint.
@@ -141,7 +142,6 @@ class NLPCloud(LLM):
             end_sequence = stop[0]
         else:
             end_sequence = None
-        response = self.client.generation(
-            prompt, end_sequence=end_sequence, **self._default_params
-        )
+        params = {**self._default_params, **kwargs}
+        response = self.client.generation(prompt, end_sequence=end_sequence, **params)
         return response["generated_text"]

View File

@ -273,6 +273,7 @@ class BaseOpenAI(BaseLLM):
prompts: List[str], prompts: List[str],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult: ) -> LLMResult:
"""Call out to OpenAI's endpoint with k unique prompts. """Call out to OpenAI's endpoint with k unique prompts.
@ -290,6 +291,7 @@ class BaseOpenAI(BaseLLM):
""" """
# TODO: write a unit test for this # TODO: write a unit test for this
params = self._invocation_params params = self._invocation_params
params = {**params, **kwargs}
sub_prompts = self.get_sub_prompts(params, prompts, stop) sub_prompts = self.get_sub_prompts(params, prompts, stop)
choices = [] choices = []
token_usage: Dict[str, int] = {} token_usage: Dict[str, int] = {}
@ -326,9 +328,11 @@ class BaseOpenAI(BaseLLM):
prompts: List[str], prompts: List[str],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult: ) -> LLMResult:
"""Call out to OpenAI's endpoint async with k unique prompts.""" """Call out to OpenAI's endpoint async with k unique prompts."""
params = self._invocation_params params = self._invocation_params
params = {**params, **kwargs}
sub_prompts = self.get_sub_prompts(params, prompts, stop) sub_prompts = self.get_sub_prompts(params, prompts, stop)
choices = [] choices = []
token_usage: Dict[str, int] = {} token_usage: Dict[str, int] = {}
@ -771,8 +775,10 @@ class OpenAIChat(BaseLLM):
prompts: List[str], prompts: List[str],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult: ) -> LLMResult:
messages, params = self._get_chat_params(prompts, stop) messages, params = self._get_chat_params(prompts, stop)
params = {**params, **kwargs}
if self.streaming: if self.streaming:
response = "" response = ""
params["stream"] = True params["stream"] = True
@ -804,8 +810,10 @@ class OpenAIChat(BaseLLM):
prompts: List[str], prompts: List[str],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult: ) -> LLMResult:
messages, params = self._get_chat_params(prompts, stop) messages, params = self._get_chat_params(prompts, stop)
params = {**params, **kwargs}
if self.streaming: if self.streaming:
response = "" response = ""
params["stream"] = True params["stream"] = True


@ -137,9 +137,11 @@ class Petals(LLM):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
"""Call the Petals API.""" """Call the Petals API."""
params = self._default_params params = self._default_params
params = {**params, **kwargs}
inputs = self.tokenizer(prompt, return_tensors="pt")["input_ids"] inputs = self.tokenizer(prompt, return_tensors="pt")["input_ids"]
outputs = self.client.generate(inputs, **params) outputs = self.client.generate(inputs, **params)
text = self.tokenizer.decode(outputs[0]) text = self.tokenizer.decode(outputs[0])


@ -87,6 +87,7 @@ class PipelineAI(LLM, BaseModel):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
"""Call to Pipeline Cloud endpoint.""" """Call to Pipeline Cloud endpoint."""
try: try:
@ -98,6 +99,7 @@ class PipelineAI(LLM, BaseModel):
) )
client = PipelineCloud(token=self.pipeline_api_key) client = PipelineCloud(token=self.pipeline_api_key)
params = self.pipeline_kwargs or {} params = self.pipeline_kwargs or {}
params = {**params, **kwargs}
run = client.run_pipeline(self.pipeline_key, [prompt, params]) run = client.run_pipeline(self.pipeline_key, [prompt, params])
try: try:
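
The merged params travel as the second element of the positional list handed to run_pipeline. A stub with invented names, only to show the payload shape:

def fake_run_pipeline(pipeline_key: str, inputs: list) -> dict:
    # stands in for PipelineCloud.run_pipeline; it only echoes what it received
    prompt, params = inputs
    return {"pipeline": pipeline_key, "prompt": prompt, "params": params}

configured = {"max_length": 128}             # self.pipeline_kwargs
merged = {**configured, "temperature": 0.5}  # plus call-time **kwargs
print(fake_run_pipeline("pipeline_abc123", ["Tell me a joke.", merged]))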


@ -91,6 +91,7 @@ class PredictionGuard(LLM):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
"""Call out to Prediction Guard's model API. """Call out to Prediction Guard's model API.
Args: Args:
@ -117,6 +118,7 @@ class PredictionGuard(LLM):
output=self.output, output=self.output,
temperature=params["temperature"], temperature=params["temperature"],
max_tokens=params["max_tokens"], max_tokens=params["max_tokens"],
**kwargs,
) )
text = response["choices"][0]["text"] text = response["choices"][0]["text"]
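
One general Python caveat, not specific to the Prediction Guard SDK: temperature and max_tokens are still passed explicitly, so supplying either of them again through **kwargs would raise a duplicate-keyword TypeError. A tiny demonstration:

def fake_create(prompt: str, temperature: float, max_tokens: int, **kwargs):
    return {"prompt": prompt, "temperature": temperature, "max_tokens": max_tokens, **kwargs}

fake_create("hi", temperature=0.75, max_tokens=256, top_p=0.9)  # extra kwargs pass through
try:
    fake_create("hi", temperature=0.75, max_tokens=256, **{"temperature": 0.1})
except TypeError as err:
    print(err)  # ... got multiple values for keyword argument 'temperature'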


@ -1,6 +1,6 @@
"""PromptLayer wrapper.""" """PromptLayer wrapper."""
import datetime import datetime
from typing import List, Optional from typing import Any, List, Optional
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
@ -42,6 +42,7 @@ class PromptLayerOpenAI(OpenAI):
prompts: List[str], prompts: List[str],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult: ) -> LLMResult:
"""Call OpenAI generate and then call PromptLayer API to log the request.""" """Call OpenAI generate and then call PromptLayer API to log the request."""
from promptlayer.utils import get_api_key, promptlayer_api_request from promptlayer.utils import get_api_key, promptlayer_api_request
@ -56,11 +57,12 @@ class PromptLayerOpenAI(OpenAI):
"text": generation.text, "text": generation.text,
"llm_output": generated_responses.llm_output, "llm_output": generated_responses.llm_output,
} }
params = {**self._identifying_params, **kwargs}
pl_request_id = promptlayer_api_request( pl_request_id = promptlayer_api_request(
"langchain.PromptLayerOpenAI", "langchain.PromptLayerOpenAI",
"langchain", "langchain",
[prompt], [prompt],
self._identifying_params, params,
self.pl_tags, self.pl_tags,
resp, resp,
request_start_time, request_start_time,
@ -81,6 +83,7 @@ class PromptLayerOpenAI(OpenAI):
prompts: List[str], prompts: List[str],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult: ) -> LLMResult:
from promptlayer.utils import get_api_key, promptlayer_api_request_async from promptlayer.utils import get_api_key, promptlayer_api_request_async
@ -94,11 +97,12 @@ class PromptLayerOpenAI(OpenAI):
"text": generation.text, "text": generation.text,
"llm_output": generated_responses.llm_output, "llm_output": generated_responses.llm_output,
} }
params = {**self._identifying_params, **kwargs}
pl_request_id = await promptlayer_api_request_async( pl_request_id = await promptlayer_api_request_async(
"langchain.PromptLayerOpenAI.async", "langchain.PromptLayerOpenAI.async",
"langchain", "langchain",
[prompt], [prompt],
self._identifying_params, params,
self.pl_tags, self.pl_tags,
resp, resp,
request_start_time, request_start_time,
@ -147,6 +151,7 @@ class PromptLayerOpenAIChat(OpenAIChat):
prompts: List[str], prompts: List[str],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult: ) -> LLMResult:
"""Call OpenAI generate and then call PromptLayer API to log the request.""" """Call OpenAI generate and then call PromptLayer API to log the request."""
from promptlayer.utils import get_api_key, promptlayer_api_request from promptlayer.utils import get_api_key, promptlayer_api_request
@ -161,11 +166,12 @@ class PromptLayerOpenAIChat(OpenAIChat):
"text": generation.text, "text": generation.text,
"llm_output": generated_responses.llm_output, "llm_output": generated_responses.llm_output,
} }
params = {**self._identifying_params, **kwargs}
pl_request_id = promptlayer_api_request( pl_request_id = promptlayer_api_request(
"langchain.PromptLayerOpenAIChat", "langchain.PromptLayerOpenAIChat",
"langchain", "langchain",
[prompt], [prompt],
self._identifying_params, params,
self.pl_tags, self.pl_tags,
resp, resp,
request_start_time, request_start_time,
@ -186,6 +192,7 @@ class PromptLayerOpenAIChat(OpenAIChat):
prompts: List[str], prompts: List[str],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult: ) -> LLMResult:
from promptlayer.utils import get_api_key, promptlayer_api_request_async from promptlayer.utils import get_api_key, promptlayer_api_request_async
@ -199,11 +206,12 @@ class PromptLayerOpenAIChat(OpenAIChat):
"text": generation.text, "text": generation.text,
"llm_output": generated_responses.llm_output, "llm_output": generated_responses.llm_output,
} }
params = {**self._identifying_params, **kwargs}
pl_request_id = await promptlayer_api_request_async( pl_request_id = await promptlayer_api_request_async(
"langchain.PromptLayerOpenAIChat.async", "langchain.PromptLayerOpenAIChat.async",
"langchain", "langchain",
[prompt], [prompt],
self._identifying_params, params,
self.pl_tags, self.pl_tags,
resp, resp,
request_start_time, request_start_time,
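
The point of the merge here is logging fidelity: the record sent to PromptLayer now reflects the values actually used for the call, not only the configured identifying params. Illustrative values:

identifying_params = {"model_name": "gpt-3.5-turbo", "temperature": 0.0}
call_time_kwargs = {"temperature": 0.5}

logged = {**identifying_params, **call_time_kwargs}
# Before this change the log would show temperature=0.0; now it shows 0.5.
assert logged["temperature"] == 0.5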


@ -85,6 +85,7 @@ class Replicate(LLM):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
"""Call to replicate endpoint.""" """Call to replicate endpoint."""
try: try:
@ -110,6 +111,6 @@ class Replicate(LLM):
first_input_name = input_properties[0][0] first_input_name = input_properties[0][0]
inputs = {first_input_name: prompt, **self.input} inputs = {first_input_name: prompt, **self.input}
iterator = replicate_python.run(self.model, input={**inputs}) iterator = replicate_python.run(self.model, input={**inputs, **kwargs})
return "".join([output for output in iterator]) return "".join([output for output in iterator])


@ -210,6 +210,7 @@ class RWKV(LLM, BaseModel):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
r"""RWKV generation r"""RWKV generation


@ -207,6 +207,7 @@ class SagemakerEndpoint(LLM):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
"""Call out to Sagemaker inference endpoint. """Call out to Sagemaker inference endpoint.
@ -223,6 +224,7 @@ class SagemakerEndpoint(LLM):
response = se("Tell me a joke.") response = se("Tell me a joke.")
""" """
_model_kwargs = self.model_kwargs or {} _model_kwargs = self.model_kwargs or {}
_model_kwargs = {**_model_kwargs, **kwargs}
_endpoint_kwargs = self.endpoint_kwargs or {} _endpoint_kwargs = self.endpoint_kwargs or {}
body = self.content_handler.transform_input(prompt, _model_kwargs) body = self.content_handler.transform_input(prompt, _model_kwargs)
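
The merged kwargs reach the content handler, which decides how they are serialized into the request body. A stand-in handler with an invented field name, not the module's real handler classes:

import json
from typing import Any, Dict

class FakeContentHandler:
    """Stand-in for a user-provided content handler; only shows where kwargs surface."""

    def transform_input(self, prompt: str, model_kwargs: Dict[str, Any]) -> bytes:
        # model_kwargs here is already {**self.model_kwargs, **call_time_kwargs}
        return json.dumps({"text_inputs": prompt, **model_kwargs}).encode("utf-8")

body = FakeContentHandler().transform_input(
    "Tell me a joke.", {"temperature": 0.2, "max_new_tokens": 64}
)
print(body)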


@ -214,5 +214,8 @@ class SelfHostedPipeline(LLM):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
return self.client(pipeline=self.pipeline_ref, prompt=prompt, stop=stop) return self.client(
pipeline=self.pipeline_ref, prompt=prompt, stop=stop, **kwargs
)


@ -207,5 +207,8 @@ class SelfHostedHuggingFaceLLM(SelfHostedPipeline):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
return self.client(pipeline=self.pipeline_ref, prompt=prompt, stop=stop) return self.client(
pipeline=self.pipeline_ref, prompt=prompt, stop=stop, **kwargs
)
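
For both self-hosted classes, extra keyword arguments now reach the callable executed on the remote hardware instead of being dropped. A stub standing in for that inference function:

from typing import Any, List, Optional

def fake_inference_fn(
    pipeline: Any, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
) -> str:
    # reports which extra keyword arguments made it through to the remote call
    return f"{prompt!r} ran with extras {sorted(kwargs)}"

print(fake_inference_fn(object(), "Hello", stop=None, max_new_tokens=32, do_sample=False))
# 'Hello' ran with extras ['do_sample', 'max_new_tokens']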


@ -86,6 +86,7 @@ class StochasticAI(LLM):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
"""Call out to StochasticAI's complete endpoint. """Call out to StochasticAI's complete endpoint.
@ -102,6 +103,7 @@ class StochasticAI(LLM):
response = StochasticAI("Tell me a joke.") response = StochasticAI("Tell me a joke.")
""" """
params = self.model_kwargs or {} params = self.model_kwargs or {}
params = {**params, **kwargs}
response_post = requests.post( response_post = requests.post(
url=self.api_url, url=self.api_url,
json={"prompt": prompt, "params": params}, json={"prompt": prompt, "params": params},


@ -50,8 +50,11 @@ class _VertexAICommon(BaseModel):
} }
return {**base_params} return {**base_params}
def _predict(self, prompt: str, stop: Optional[List[str]] = None) -> str: def _predict(
res = self.client.predict(prompt, **self._default_params) self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
) -> str:
params = {**self._default_params, **kwargs}
res = self.client.predict(prompt, **params)
return self._enforce_stop_words(res.text, stop) return self._enforce_stop_words(res.text, stop)
def _enforce_stop_words(self, text: str, stop: Optional[List[str]] = None) -> str: def _enforce_stop_words(self, text: str, stop: Optional[List[str]] = None) -> str:
@ -100,6 +103,7 @@ class VertexAI(_VertexAICommon, LLM):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
"""Call Vertex model to get predictions based on the prompt. """Call Vertex model to get predictions based on the prompt.
@ -111,4 +115,4 @@ class VertexAI(_VertexAICommon, LLM):
Returns: Returns:
The string generated by the model. The string generated by the model.
""" """
return self._predict(prompt, stop) return self._predict(prompt, stop, **kwargs)
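
Call-time kwargs are merged into the arguments of client.predict, while stop words are still enforced on the returned text afterwards. A stub client with illustrative parameter names:

from types import SimpleNamespace
from typing import Any

class StubVertexClient:
    def predict(self, prompt: str, **params: Any) -> SimpleNamespace:
        # echoes the parameters it received, mimicking a response object with .text
        return SimpleNamespace(text=f"{prompt} -> {sorted(params)}")

defaults = {"temperature": 0.0, "max_output_tokens": 128}  # self._default_params
merged = {**defaults, "temperature": 0.4}                  # plus call-time **kwargs
print(StubVertexClient().predict("Hello", **merged).text)
# Hello -> ['max_output_tokens', 'temperature']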


@ -118,6 +118,7 @@ class Writer(LLM):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
"""Call out to Writer's completions endpoint. """Call out to Writer's completions endpoint.
@ -141,7 +142,7 @@ class Writer(LLM):
f"/organization/{self.writer_org_id}" f"/organization/{self.writer_org_id}"
f"/model/{self.model_id}/completions" f"/model/{self.model_id}/completions"
) )
params = {**self._default_params, **kwargs}
response = requests.post( response = requests.post(
url=base_url, url=base_url,
headers={ headers={
@ -149,7 +150,7 @@ class Writer(LLM):
"Content-Type": "application/json", "Content-Type": "application/json",
"Accept": "application/json", "Accept": "application/json",
}, },
json={"prompt": prompt, **self._default_params}, json={"prompt": prompt, **params},
) )
text = response.text text = response.text
if stop is not None: if stop is not None:
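
With this change the merged params sit flat in the JSON body next to the prompt. An illustrative body only; the parameter names below stand in for self._default_params and are not taken from Writer's API reference:

import json

default_params = {"max_tokens": 256, "temperature": 0.7}
call_time_kwargs = {"temperature": 0.2}

body = {"prompt": "Write a product blurb.", **default_params, **call_time_kwargs}
print(json.dumps(body))
# {"prompt": "Write a product blurb.", "max_tokens": 256, "temperature": 0.2}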


@ -20,6 +20,7 @@ class FakeListLLM(LLM):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
"""Increment counter, and then return response in that index.""" """Increment counter, and then return response in that index."""
self.i += 1 self.i += 1


@ -38,6 +38,7 @@ class FakeListLLM(LLM):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
"""Increment counter, and then return response in that index.""" """Increment counter, and then return response in that index."""
self.i += 1 self.i += 1


@ -1,5 +1,5 @@
"""Test HyDE.""" """Test HyDE."""
from typing import List, Optional from typing import Any, List, Optional
import numpy as np import numpy as np
@ -36,6 +36,7 @@ class FakeLLM(BaseLLM):
prompts: List[str], prompts: List[str],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult: ) -> LLMResult:
return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]]) return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])
@ -44,6 +45,7 @@ class FakeLLM(BaseLLM):
prompts: List[str], prompts: List[str],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult: ) -> LLMResult:
return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]]) return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])


@ -15,6 +15,7 @@ class FakeLLM(LLM):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
"""Return `foo` if longer than 10000 words, else `bar`.""" """Return `foo` if longer than 10000 words, else `bar`."""
if len(prompt) > 10000: if len(prompt) > 10000:


@ -17,6 +17,7 @@ class FakeChatModel(SimpleChatModel):
messages: List[BaseMessage], messages: List[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
return "fake response" return "fake response"
@ -25,6 +26,7 @@ class FakeChatModel(SimpleChatModel):
messages: List[BaseMessage], messages: List[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult: ) -> ChatResult:
output_str = "fake response" output_str = "fake response"
message = AIMessage(content=output_str) message = AIMessage(content=output_str)


@ -34,6 +34,7 @@ class FakeLLM(LLM):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
if self.sequential_responses: if self.sequential_responses:
return self._get_next_response_in_sequence return self._get_next_response_in_sequence