Mirror of https://github.com/hwchase17/langchain.git (synced 2025-05-22 15:38:06 +00:00)

support kwargs (#5990)

commit 704d56e241, parent b934677a81
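The change threads arbitrary keyword arguments from the public entry points (`__call__`, `generate`, `generate_prompt`, `predict`, and their async counterparts) down to each provider's `_generate`/`_call`, where they are merged into the request parameters. A rough usage sketch of what this enables, assuming Anthropic credentials are configured; the parameter names are provider-specific and purely illustrative:

# Illustrative sketch, not part of the commit: per-call kwargs are merged into
# the provider request alongside the model's default parameters.
from langchain.chat_models import ChatAnthropic
from langchain.llms import Anthropic
from langchain.schema import HumanMessage

llm = Anthropic()
# kwargs flow through __call__ -> generate -> _generate for this one request.
print(llm("Tell me a joke", temperature=0.1))

chat = ChatAnthropic()
# Same idea for chat models: extra kwargs land in the Anthropic request params.
result = chat.generate([[HumanMessage(content="Tell me a joke")]], temperature=0.1)
print(result.generations[0][0].text)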
@@ -2,7 +2,7 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import List, Optional, Sequence, Set
+from typing import Any, List, Optional, Sequence, Set
 
 from pydantic import BaseModel
 
@@ -36,6 +36,7 @@ class BaseLanguageModel(BaseModel, ABC):
         prompts: List[PromptValue],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Take in a list of prompt values and return an LLMResult."""
 
@@ -45,26 +46,39 @@ class BaseLanguageModel(BaseModel, ABC):
         prompts: List[PromptValue],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Take in a list of prompt values and return an LLMResult."""
 
     @abstractmethod
-    def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+    def predict(
+        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+    ) -> str:
         """Predict text from text."""
 
     @abstractmethod
     def predict_messages(
-        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+        self,
+        messages: List[BaseMessage],
+        *,
+        stop: Optional[Sequence[str]] = None,
+        **kwargs: Any,
     ) -> BaseMessage:
         """Predict message from messages."""
 
     @abstractmethod
-    async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+    async def apredict(
+        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+    ) -> str:
         """Predict text from text."""
 
     @abstractmethod
     async def apredict_messages(
-        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+        self,
+        messages: List[BaseMessage],
+        *,
+        stop: Optional[Sequence[str]] = None,
+        **kwargs: Any,
     ) -> BaseMessage:
         """Predict message from messages."""
 

@@ -94,9 +94,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         prompt = self._convert_messages_to_prompt(messages)
-        params: Dict[str, Any] = {"prompt": prompt, **self._default_params}
+        params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
         if stop:
             params["stop_sequences"] = stop
 
@@ -121,9 +122,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         prompt = self._convert_messages_to_prompt(messages)
-        params: Dict[str, Any] = {"prompt": prompt, **self._default_params}
+        params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
         if stop:
             params["stop_sequences"] = stop
 
@@ -64,6 +64,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Top Level call"""
 
@@ -82,7 +83,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         )
         try:
             results = [
-                self._generate(m, stop=stop, run_manager=run_manager)
+                self._generate(m, stop=stop, run_manager=run_manager, **kwargs)
                 if new_arg_supported
                 else self._generate(m, stop=stop)
                 for m in messages
@@ -103,6 +104,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Top Level call"""
         params = self.dict()
@@ -121,7 +123,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         try:
             results = await asyncio.gather(
                 *[
-                    self._agenerate(m, stop=stop, run_manager=run_manager)
+                    self._agenerate(m, stop=stop, run_manager=run_manager, **kwargs)
                     if new_arg_supported
                     else self._agenerate(m, stop=stop)
                     for m in messages
@@ -143,18 +145,22 @@ class BaseChatModel(BaseLanguageModel, ABC):
         prompts: List[PromptValue],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         prompt_messages = [p.to_messages() for p in prompts]
-        return self.generate(prompt_messages, stop=stop, callbacks=callbacks)
+        return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)
 
     async def agenerate_prompt(
         self,
         prompts: List[PromptValue],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         prompt_messages = [p.to_messages() for p in prompts]
-        return await self.agenerate(prompt_messages, stop=stop, callbacks=callbacks)
+        return await self.agenerate(
+            prompt_messages, stop=stop, callbacks=callbacks, **kwargs
+        )
 
     @abstractmethod
     def _generate(
@@ -162,6 +168,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         """Top Level call"""
 
@@ -171,6 +178,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         """Top Level call"""
 
@@ -193,18 +201,25 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> BaseMessage:
-        result = await self.agenerate([messages], stop=stop, callbacks=callbacks)
+        result = await self.agenerate(
+            [messages], stop=stop, callbacks=callbacks, **kwargs
+        )
         generation = result.generations[0][0]
         if isinstance(generation, ChatGeneration):
             return generation.message
         else:
             raise ValueError("Unexpected generation type")
 
-    def call_as_llm(self, message: str, stop: Optional[List[str]] = None) -> str:
-        return self.predict(message, stop=stop)
+    def call_as_llm(
+        self, message: str, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> str:
+        return self.predict(message, stop=stop, **kwargs)
 
-    def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+    def predict(
+        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+    ) -> str:
         if stop is None:
             _stop = None
         else:
@@ -213,30 +228,42 @@ class BaseChatModel(BaseLanguageModel, ABC):
         return result.content
 
     def predict_messages(
-        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+        self,
+        messages: List[BaseMessage],
+        *,
+        stop: Optional[Sequence[str]] = None,
+        **kwargs: Any,
     ) -> BaseMessage:
         if stop is None:
             _stop = None
         else:
             _stop = list(stop)
-        return self(messages, stop=_stop)
+        return self(messages, stop=_stop, **kwargs)
 
-    async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+    async def apredict(
+        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+    ) -> str:
         if stop is None:
             _stop = None
         else:
             _stop = list(stop)
-        result = await self._call_async([HumanMessage(content=text)], stop=_stop)
+        result = await self._call_async(
+            [HumanMessage(content=text)], stop=_stop, **kwargs
+        )
         return result.content
 
     async def apredict_messages(
-        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+        self,
+        messages: List[BaseMessage],
+        *,
+        stop: Optional[Sequence[str]] = None,
+        **kwargs: Any,
    ) -> BaseMessage:
         if stop is None:
             _stop = None
         else:
             _stop = list(stop)
-        return await self._call_async(messages, stop=_stop)
+        return await self._call_async(messages, stop=_stop, **kwargs)
 
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
@@ -261,8 +288,9 @@ class SimpleChatModel(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
-        output_str = self._call(messages, stop=stop, run_manager=run_manager)
+        output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
         message = AIMessage(content=output_str)
         generation = ChatGeneration(message=message)
         return ChatResult(generations=[generation])
@@ -273,6 +301,7 @@ class SimpleChatModel(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Simpler interface."""
 
@@ -281,6 +310,9 @@ class SimpleChatModel(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
-        func = partial(self._generate, messages, stop=stop, run_manager=run_manager)
+        func = partial(
+            self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
+        )
         return await asyncio.get_event_loop().run_in_executor(None, func)
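To show where the forwarded arguments end up on the chat side, a minimal hypothetical subclass written against the new `SimpleChatModel._call` signature could look like this (the class is an illustration, not part of the commit):

# Hypothetical example built on the signatures introduced above.
from typing import Any, List, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.schema import BaseMessage


class EchoChatModel(SimpleChatModel):
    """Echoes the last message; per-call kwargs arrive in `_call`."""

    @property
    def _llm_type(self) -> str:
        return "echo-chat"

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # Anything passed to generate()/predict() as extra kwargs shows up here.
        suffix = f" (kwargs={sorted(kwargs)})" if kwargs else ""
        return messages[-1].content + suffix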
@@ -280,6 +280,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         prompt = _messages_to_prompt_dict(messages)
 
@@ -291,6 +292,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
             top_p=self.top_p,
             top_k=self.top_k,
             candidate_count=self.n,
+            **kwargs,
         )
 
         return _response_to_result(response, stop)
@@ -300,6 +302,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         prompt = _messages_to_prompt_dict(messages)
 

@@ -302,8 +302,10 @@ class ChatOpenAI(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         message_dicts, params = self._create_message_dicts(messages, stop)
+        params = {**params, **kwargs}
         if self.streaming:
             inner_completion = ""
             role = "assistant"
@@ -348,8 +350,10 @@ class ChatOpenAI(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         message_dicts, params = self._create_message_dicts(messages, stop)
+        params = {**params, **kwargs}
         if self.streaming:
             inner_completion = ""
             role = "assistant"

@@ -42,6 +42,7 @@ class PromptLayerChatOpenAI(ChatOpenAI):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any
     ) -> ChatResult:
         """Call ChatOpenAI generate and then call PromptLayer API to log the request."""
         from promptlayer.utils import get_api_key, promptlayer_api_request
@@ -54,6 +55,7 @@ class PromptLayerChatOpenAI(ChatOpenAI):
             response_dict, params = super()._create_message_dicts(
                 [generation.message], stop
             )
+            params = {**params, **kwargs}
             pl_request_id = promptlayer_api_request(
                 "langchain.PromptLayerChatOpenAI",
                 "langchain",
@@ -79,6 +81,7 @@ class PromptLayerChatOpenAI(ChatOpenAI):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any
     ) -> ChatResult:
         """Call ChatOpenAI agenerate and then call PromptLayer to log."""
         from promptlayer.utils import get_api_key, promptlayer_api_request_async
@@ -91,6 +94,7 @@ class PromptLayerChatOpenAI(ChatOpenAI):
             response_dict, params = super()._create_message_dicts(
                 [generation.message], stop
             )
+            params = {**params, **kwargs}
             pl_request_id = await promptlayer_api_request_async(
                 "langchain.PromptLayerChatOpenAI.async",
                 "langchain",

@@ -1,6 +1,6 @@
 """Wrapper around Google VertexAI chat-based models."""
 from dataclasses import dataclass, field
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 from pydantic import root_validator
 
@@ -93,6 +93,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         """Generate next turn in the conversation.
 
@@ -119,7 +120,8 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
 
         history = _parse_chat_history(messages[:-1])
         context = history.system_message.content if history.system_message else None
-        chat = self.client.start_chat(context=context, **self._default_params)
+        params = {**self._default_params, **kwargs}
+        chat = self.client.start_chat(context=context, **params)
         for pair in history.history:
             chat._history.append((pair.question.content, pair.answer.content))
         response = chat.send_message(question.content, **self._default_params)
@@ -131,6 +133,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         raise NotImplementedError(
             """Vertex AI doesn't support async requests at the moment."""

@@ -2,7 +2,7 @@
 from __future__ import annotations
 
 import json
-from typing import TYPE_CHECKING, List, Optional, cast
+from typing import TYPE_CHECKING, Any, List, Optional, cast
 
 from pydantic import Field, root_validator
 
@@ -42,6 +42,7 @@ class JsonFormer(HuggingFacePipeline):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         jsonformer = import_jsonformer()
         from transformers import Text2TextGenerationPipeline

@@ -1,7 +1,7 @@
 """Experimental implementation of RELLM wrapped LLM."""
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, List, Optional, cast
+from typing import TYPE_CHECKING, Any, List, Optional, cast
 
 from pydantic import Field, root_validator
 
@@ -47,6 +47,7 @@ class RELLM(HuggingFacePipeline):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         rellm = import_rellm()
         from transformers import Text2TextGenerationPipeline

@@ -112,6 +112,7 @@ class AI21(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to AI21's complete endpoint.
 
@@ -140,10 +141,11 @@ class AI21(LLM):
             base_url = "https://api.ai21.com/studio/v1/experimental"
         else:
             base_url = "https://api.ai21.com/studio/v1"
+        params = {**self._default_params, **kwargs}
         response = requests.post(
             url=f"{base_url}/{self.model}/complete",
             headers={"Authorization": f"Bearer {self.ai21_api_key}"},
-            json={"prompt": prompt, "stopSequences": stop, **self._default_params},
+            json={"prompt": prompt, "stopSequences": stop, **params},
         )
         if response.status_code != 200:
             optional_detail = response.json().get("error")

@@ -206,6 +206,7 @@ class AlephAlpha(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to Aleph Alpha's completion endpoint.
 
@@ -232,6 +233,7 @@ class AlephAlpha(LLM):
             params["stop_sequences"] = self.stop_sequences
         else:
             params["stop_sequences"] = stop
+        params = {**params, **kwargs}
         request = CompletionRequest(prompt=Prompt.from_text(prompt), **params)
         response = self.client.complete(model=self.model, request=request)
         text = response.completions[0].completion

@@ -162,6 +162,7 @@ class Anthropic(LLM, _AnthropicCommon):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         r"""Call out to Anthropic's completion endpoint.
 
@@ -181,11 +182,12 @@ class Anthropic(LLM, _AnthropicCommon):
 
         """
         stop = self._get_anthropic_stop(stop)
+        params = {**self._default_params, **kwargs}
         if self.streaming:
             stream_resp = self.client.completion_stream(
                 prompt=self._wrap_prompt(prompt),
                 stop_sequences=stop,
-                **self._default_params,
+                **params,
             )
             current_completion = ""
             for data in stream_resp:
@@ -197,7 +199,7 @@ class Anthropic(LLM, _AnthropicCommon):
         response = self.client.completion(
             prompt=self._wrap_prompt(prompt),
             stop_sequences=stop,
-            **self._default_params,
+            **params,
         )
         return response["completion"]
 
@@ -206,14 +208,16 @@ class Anthropic(LLM, _AnthropicCommon):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to Anthropic's completion endpoint asynchronously."""
         stop = self._get_anthropic_stop(stop)
+        params = {**self._default_params, **kwargs}
         if self.streaming:
             stream_resp = await self.client.acompletion_stream(
                 prompt=self._wrap_prompt(prompt),
                 stop_sequences=stop,
-                **self._default_params,
+                **params,
             )
             current_completion = ""
             async for data in stream_resp:
@@ -225,7 +229,7 @@ class Anthropic(LLM, _AnthropicCommon):
         response = await self.client.acompletion(
             prompt=self._wrap_prompt(prompt),
             stop_sequences=stop,
-            **self._default_params,
+            **params,
         )
         return response["completion"]
 

@@ -88,6 +88,7 @@ class Anyscale(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to Anyscale Service endpoint.
         Args:

@@ -105,6 +105,7 @@ class Aviary(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to Aviary
         Args:

@@ -87,6 +87,7 @@ class Banana(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call to Banana endpoint."""
         try:
@@ -97,6 +98,7 @@ class Banana(LLM):
                 "Please install it with `pip install banana-dev`."
             )
         params = self.model_kwargs or {}
+        params = {**params, **kwargs}
         api_key = self.banana_api_key
         model_key = self.model_key
         model_inputs = {
@@ -113,6 +113,7 @@ class BaseLLM(BaseLanguageModel, ABC):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompts."""
 
@@ -122,6 +123,7 @@ class BaseLLM(BaseLanguageModel, ABC):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompts."""
 
@@ -130,24 +132,29 @@ class BaseLLM(BaseLanguageModel, ABC):
         prompts: List[PromptValue],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         prompt_strings = [p.to_string() for p in prompts]
-        return self.generate(prompt_strings, stop=stop, callbacks=callbacks)
+        return self.generate(prompt_strings, stop=stop, callbacks=callbacks, **kwargs)
 
     async def agenerate_prompt(
         self,
         prompts: List[PromptValue],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         prompt_strings = [p.to_string() for p in prompts]
-        return await self.agenerate(prompt_strings, stop=stop, callbacks=callbacks)
+        return await self.agenerate(
+            prompt_strings, stop=stop, callbacks=callbacks, **kwargs
+        )
 
     def generate(
         self,
         prompts: List[str],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
         # If string is passed in directly no errors will be raised but outputs will
@@ -183,9 +190,11 @@ class BaseLLM(BaseLanguageModel, ABC):
         )
         try:
             output = (
-                self._generate(prompts, stop=stop, run_manager=run_manager)
+                self._generate(
+                    prompts, stop=stop, run_manager=run_manager, **kwargs
+                )
                 if new_arg_supported
-                else self._generate(prompts, stop=stop)
+                else self._generate(prompts, stop=stop, **kwargs)
             )
         except (KeyboardInterrupt, Exception) as e:
             run_manager.on_llm_error(e)
@@ -202,9 +211,11 @@ class BaseLLM(BaseLanguageModel, ABC):
         )
         try:
             new_results = (
-                self._generate(missing_prompts, stop=stop, run_manager=run_manager)
+                self._generate(
+                    missing_prompts, stop=stop, run_manager=run_manager, **kwargs
+                )
                 if new_arg_supported
-                else self._generate(missing_prompts, stop=stop)
+                else self._generate(missing_prompts, stop=stop, **kwargs)
             )
         except (KeyboardInterrupt, Exception) as e:
             run_manager.on_llm_error(e)
@@ -227,6 +238,7 @@ class BaseLLM(BaseLanguageModel, ABC):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
         params = self.dict()
@@ -255,9 +267,11 @@ class BaseLLM(BaseLanguageModel, ABC):
         )
         try:
             output = (
-                await self._agenerate(prompts, stop=stop, run_manager=run_manager)
+                await self._agenerate(
+                    prompts, stop=stop, run_manager=run_manager, **kwargs
+                )
                 if new_arg_supported
-                else await self._agenerate(prompts, stop=stop)
+                else await self._agenerate(prompts, stop=stop, **kwargs)
             )
         except (KeyboardInterrupt, Exception) as e:
             await run_manager.on_llm_error(e, verbose=self.verbose)
@@ -275,10 +289,10 @@ class BaseLLM(BaseLanguageModel, ABC):
         try:
             new_results = (
                 await self._agenerate(
-                    missing_prompts, stop=stop, run_manager=run_manager
+                    missing_prompts, stop=stop, run_manager=run_manager, **kwargs
                 )
                 if new_arg_supported
-                else await self._agenerate(missing_prompts, stop=stop)
+                else await self._agenerate(missing_prompts, stop=stop, **kwargs)
             )
         except (KeyboardInterrupt, Exception) as e:
             await run_manager.on_llm_error(e)
@@ -297,7 +311,11 @@ class BaseLLM(BaseLanguageModel, ABC):
         return LLMResult(generations=generations, llm_output=llm_output, run=run_info)
 
     def __call__(
-        self, prompt: str, stop: Optional[List[str]] = None, callbacks: Callbacks = None
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> str:
         """Check Cache and run the LLM on the given prompt and input."""
         if not isinstance(prompt, str):
@@ -307,52 +325,70 @@ class BaseLLM(BaseLanguageModel, ABC):
                 "`generate` instead."
             )
         return (
-            self.generate([prompt], stop=stop, callbacks=callbacks)
+            self.generate([prompt], stop=stop, callbacks=callbacks, **kwargs)
             .generations[0][0]
             .text
         )
 
     async def _call_async(
-        self, prompt: str, stop: Optional[List[str]] = None, callbacks: Callbacks = None
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        callbacks: Callbacks = None,
+        **kwargs: Any,
    ) -> str:
         """Check Cache and run the LLM on the given prompt and input."""
-        result = await self.agenerate([prompt], stop=stop, callbacks=callbacks)
+        result = await self.agenerate(
+            [prompt], stop=stop, callbacks=callbacks, **kwargs
+        )
         return result.generations[0][0].text
 
-    def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+    def predict(
+        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+    ) -> str:
         if stop is None:
             _stop = None
         else:
             _stop = list(stop)
-        return self(text, stop=_stop)
+        return self(text, stop=_stop, **kwargs)
 
     def predict_messages(
-        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+        self,
+        messages: List[BaseMessage],
+        *,
+        stop: Optional[Sequence[str]] = None,
+        **kwargs: Any,
    ) -> BaseMessage:
         text = get_buffer_string(messages)
         if stop is None:
             _stop = None
         else:
             _stop = list(stop)
-        content = self(text, stop=_stop)
+        content = self(text, stop=_stop, **kwargs)
         return AIMessage(content=content)
 
-    async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+    async def apredict(
+        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+    ) -> str:
         if stop is None:
             _stop = None
         else:
             _stop = list(stop)
-        return await self._call_async(text, stop=_stop)
+        return await self._call_async(text, stop=_stop, **kwargs)
 
     async def apredict_messages(
-        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+        self,
+        messages: List[BaseMessage],
+        *,
+        stop: Optional[Sequence[str]] = None,
+        **kwargs: Any,
    ) -> BaseMessage:
         text = get_buffer_string(messages)
         if stop is None:
             _stop = None
         else:
             _stop = list(stop)
-        content = await self._call_async(text, stop=_stop)
+        content = await self._call_async(text, stop=_stop, **kwargs)
         return AIMessage(content=content)
 
     @property
@@ -422,6 +458,7 @@ class LLM(BaseLLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
    ) -> str:
         """Run the LLM on the given prompt and input."""
 
@@ -430,6 +467,7 @@ class LLM(BaseLLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
    ) -> str:
         """Run the LLM on the given prompt and input."""
         raise NotImplementedError("Async generation not implemented for this LLM.")
@@ -439,6 +477,7 @@ class LLM(BaseLLM):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
    ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
         # TODO: add caching here.
@@ -446,9 +485,9 @@ class LLM(BaseLLM):
         new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
         for prompt in prompts:
             text = (
-                self._call(prompt, stop=stop, run_manager=run_manager)
+                self._call(prompt, stop=stop, run_manager=run_manager, **kwargs)
                 if new_arg_supported
-                else self._call(prompt, stop=stop)
+                else self._call(prompt, stop=stop, **kwargs)
             )
             generations.append([Generation(text=text)])
         return LLMResult(generations=generations)
@@ -458,15 +497,16 @@ class LLM(BaseLLM):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
    ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
         generations = []
         new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
         for prompt in prompts:
             text = (
-                await self._acall(prompt, stop=stop, run_manager=run_manager)
+                await self._acall(prompt, stop=stop, run_manager=run_manager, **kwargs)
                 if new_arg_supported
-                else await self._acall(prompt, stop=stop)
+                else await self._acall(prompt, stop=stop, **kwargs)
             )
             generations.append([Generation(text=text)])
         return LLMResult(generations=generations)
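The same pattern applies to plain LLMs: `BaseLLM.generate` forwards `**kwargs` into `_generate`, which hands them on to `_call`. A hypothetical custom LLM written against the new signature (illustration only; `repeat` is an invented parameter):

# Hypothetical example built on the new LLM._call signature shown above.
from typing import Any, List, Mapping, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM


class ShoutLLM(LLM):
    """Uppercases the prompt; demonstrates per-call kwargs reaching `_call`."""

    @property
    def _llm_type(self) -> str:
        return "shout"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {}

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # e.g. ShoutLLM()("hi", repeat=3) -> "HI HI HI"
        repeat = int(kwargs.get("repeat", 1))
        return " ".join([prompt.upper()] * repeat)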
@ -54,6 +54,7 @@ class Baseten(LLM):
|
|||||||
prompt: str,
|
prompt: str,
|
||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Call to Baseten deployed model endpoint."""
|
"""Call to Baseten deployed model endpoint."""
|
||||||
try:
|
try:
|
||||||
|
@ -251,10 +251,12 @@ class Beam(LLM):
|
|||||||
prompt: str,
|
prompt: str,
|
||||||
stop: Optional[list] = None,
|
stop: Optional[list] = None,
|
||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Call to Beam."""
|
"""Call to Beam."""
|
||||||
url = "https://apps.beam.cloud/" + self.app_id if self.app_id else self.url
|
url = "https://apps.beam.cloud/" + self.app_id if self.app_id else self.url
|
||||||
payload = {"prompt": prompt, "max_length": self.max_length}
|
payload = {"prompt": prompt, "max_length": self.max_length}
|
||||||
|
payload.update(kwargs)
|
||||||
headers = {
|
headers = {
|
||||||
"Accept": "*/*",
|
"Accept": "*/*",
|
||||||
"Accept-Encoding": "gzip, deflate",
|
"Accept-Encoding": "gzip, deflate",
|
||||||
|
@ -155,6 +155,7 @@ class Bedrock(LLM):
|
|||||||
prompt: str,
|
prompt: str,
|
||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Call out to Bedrock service model.
|
"""Call out to Bedrock service model.
|
||||||
|
|
||||||
@ -173,10 +174,8 @@ class Bedrock(LLM):
|
|||||||
_model_kwargs = self.model_kwargs or {}
|
_model_kwargs = self.model_kwargs or {}
|
||||||
|
|
||||||
provider = self.model_id.split(".")[0]
|
provider = self.model_id.split(".")[0]
|
||||||
|
params = {**_model_kwargs, **kwargs}
|
||||||
input_body = LLMInputOutputAdapter.prepare_input(
|
input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
|
||||||
provider, prompt, _model_kwargs
|
|
||||||
)
|
|
||||||
body = json.dumps(input_body)
|
body = json.dumps(input_body)
|
||||||
accept = "application/json"
|
accept = "application/json"
|
||||||
contentType = "application/json"
|
contentType = "application/json"
|
||||||
|
@ -88,6 +88,7 @@ class CerebriumAI(LLM):
|
|||||||
prompt: str,
|
prompt: str,
|
||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Call to CerebriumAI endpoint."""
|
"""Call to CerebriumAI endpoint."""
|
||||||
try:
|
try:
|
||||||
@ -100,7 +101,9 @@ class CerebriumAI(LLM):
|
|||||||
|
|
||||||
params = self.model_kwargs or {}
|
params = self.model_kwargs or {}
|
||||||
response = model_api_request(
|
response = model_api_request(
|
||||||
self.endpoint_url, {"prompt": prompt, **params}, self.cerebriumai_api_key
|
self.endpoint_url,
|
||||||
|
{"prompt": prompt, **params, **kwargs},
|
||||||
|
self.cerebriumai_api_key,
|
||||||
)
|
)
|
||||||
text = response["data"]["result"]
|
text = response["data"]["result"]
|
||||||
if stop is not None:
|
if stop is not None:
|
||||||
|
@ -145,6 +145,7 @@ class Cohere(LLM):
|
|||||||
prompt: str,
|
prompt: str,
|
||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Call out to Cohere's generate endpoint.
|
"""Call out to Cohere's generate endpoint.
|
||||||
|
|
||||||
@ -167,7 +168,7 @@ class Cohere(LLM):
|
|||||||
params["stop_sequences"] = self.stop
|
params["stop_sequences"] = self.stop
|
||||||
else:
|
else:
|
||||||
params["stop_sequences"] = stop
|
params["stop_sequences"] = stop
|
||||||
|
params = {**params, **kwargs}
|
||||||
response = completion_with_retry(
|
response = completion_with_retry(
|
||||||
self, model=self.model, prompt=prompt, **params
|
self, model=self.model, prompt=prompt, **params
|
||||||
)
|
)
|
||||||
|
@ -81,6 +81,7 @@ class CTransformers(LLM):
|
|||||||
prompt: str,
|
prompt: str,
|
||||||
stop: Optional[Sequence[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generate text from a prompt.
|
"""Generate text from a prompt.
|
||||||
|
|
||||||
|
@ -303,12 +303,14 @@ class Databricks(LLM):
|
|||||||
prompt: str,
|
prompt: str,
|
||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Queries the LLM endpoint with the given prompt and stop sequence."""
|
"""Queries the LLM endpoint with the given prompt and stop sequence."""
|
||||||
|
|
||||||
# TODO: support callbacks
|
# TODO: support callbacks
|
||||||
|
|
||||||
request = {"prompt": prompt, "stop": stop}
|
request = {"prompt": prompt, "stop": stop}
|
||||||
|
request.update(kwargs)
|
||||||
if self.model_kwargs:
|
if self.model_kwargs:
|
||||||
request.update(self.model_kwargs)
|
request.update(self.model_kwargs)
|
||||||
|
|
||||||
|
@ -66,6 +66,7 @@ class DeepInfra(LLM):
|
|||||||
prompt: str,
|
prompt: str,
|
||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Call out to DeepInfra's inference API endpoint.
|
"""Call out to DeepInfra's inference API endpoint.
|
||||||
|
|
||||||
@ -82,6 +83,7 @@ class DeepInfra(LLM):
|
|||||||
response = di("Tell me a joke.")
|
response = di("Tell me a joke.")
|
||||||
"""
|
"""
|
||||||
_model_kwargs = self.model_kwargs or {}
|
_model_kwargs = self.model_kwargs or {}
|
||||||
|
_model_kwargs = {**_model_kwargs, **kwargs}
|
||||||
# HTTP headers for authorization
|
# HTTP headers for authorization
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"bearer {self.deepinfra_api_token}",
|
"Authorization": f"bearer {self.deepinfra_api_token}",
|
||||||
|
@ -24,6 +24,7 @@ class FakeListLLM(LLM):
|
|||||||
prompt: str,
|
prompt: str,
|
||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Return next response"""
|
"""Return next response"""
|
||||||
response = self.responses[self.i]
|
response = self.responses[self.i]
|
||||||
@ -35,6 +36,7 @@ class FakeListLLM(LLM):
|
|||||||
prompt: str,
|
prompt: str,
|
||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Return next response"""
|
"""Return next response"""
|
||||||
response = self.responses[self.i]
|
response = self.responses[self.i]
|
||||||
|
@ -87,6 +87,7 @@ class ForefrontAI(LLM):
|
|||||||
prompt: str,
|
prompt: str,
|
||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Call out to ForefrontAI's complete endpoint.
|
"""Call out to ForefrontAI's complete endpoint.
|
||||||
|
|
||||||
@ -108,7 +109,7 @@ class ForefrontAI(LLM):
|
|||||||
"Authorization": f"Bearer {self.forefrontai_api_key}",
|
"Authorization": f"Bearer {self.forefrontai_api_key}",
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
},
|
},
|
||||||
json={"text": prompt, **self._default_params},
|
json={"text": prompt, **self._default_params, **kwargs},
|
||||||
)
|
)
|
||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
text = response_json["result"][0]["completion"]
|
text = response_json["result"][0]["completion"]
|
||||||
|
@ -134,6 +134,7 @@ class GooglePalm(BaseLLM, BaseModel):
|
|||||||
prompts: List[str],
|
prompts: List[str],
|
||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> LLMResult:
|
) -> LLMResult:
|
||||||
generations = []
|
generations = []
|
||||||
for prompt in prompts:
|
for prompt in prompts:
|
||||||
@ -147,6 +148,7 @@ class GooglePalm(BaseLLM, BaseModel):
|
|||||||
top_k=self.top_k,
|
top_k=self.top_k,
|
||||||
max_output_tokens=self.max_output_tokens,
|
max_output_tokens=self.max_output_tokens,
|
||||||
candidate_count=self.n,
|
candidate_count=self.n,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt_generations = []
|
prompt_generations = []
|
||||||
@ -163,6 +165,7 @@ class GooglePalm(BaseLLM, BaseModel):
|
|||||||
prompts: List[str],
|
prompts: List[str],
|
||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> LLMResult:
|
) -> LLMResult:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@@ -137,6 +137,7 @@ class GooseAI(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call the GooseAI API."""
         params = self._default_params
@@ -145,6 +146,8 @@ class GooseAI(LLM):
                 raise ValueError("`stop` found in both the input and default params.")
             params["stop"] = stop
 
+        params = {**params, **kwargs}
+
         response = self.client.create(engine=self.model_name, prompt=prompt, **params)
         text = response.choices[0].text
         return text
@@ -183,6 +183,7 @@ class GPT4All(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         r"""Call out to GPT4All's generate method.
 
@@ -203,7 +204,8 @@ class GPT4All(LLM):
         if run_manager:
             text_callback = partial(run_manager.on_llm_new_token, verbose=self.verbose)
         text = ""
-        for token in self.client.generate(prompt, **self._default_params()):
+        params = {**self._default_params(), **kwargs}
+        for token in self.client.generate(prompt, **params):
             if text_callback:
                 text_callback(token)
             text += token
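Most hunks in this commit reduce to the same dict merge. A minimal sketch of the merge semantics, with made-up parameter names rather than anything taken from the diff: call-time keyword arguments override the wrapper's configured defaults because they sit on the right-hand side of the unpacking.

    # Illustrative only; "temperature" / "max_tokens" are hypothetical keys.
    defaults = {"temperature": 0.7, "max_tokens": 256}   # wrapper's configured params
    call_time_kwargs = {"temperature": 0.1}              # e.g. from llm(prompt, temperature=0.1)
    params = {**defaults, **call_time_kwargs}            # right-hand side wins on key collisions
    assert params == {"temperature": 0.1, "max_tokens": 256}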
@@ -96,6 +96,7 @@ class HuggingFaceEndpoint(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to HuggingFace Hub's inference endpoint.
 
@@ -114,7 +115,8 @@ class HuggingFaceEndpoint(LLM):
         _model_kwargs = self.model_kwargs or {}
 
         # payload samples
-        parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}
+        params = {**_model_kwargs, **kwargs}
+        parameter_payload = {"inputs": prompt, "parameters": params}
 
         # HTTP headers for authorization
         headers = {
@@ -91,6 +91,7 @@ class HuggingFaceHub(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to HuggingFace Hub's inference endpoint.
 
@@ -107,7 +108,8 @@ class HuggingFaceHub(LLM):
                 response = hf("Tell me a joke.")
         """
         _model_kwargs = self.model_kwargs or {}
-        response = self.client(inputs=prompt, params=_model_kwargs)
+        params = {**_model_kwargs, **kwargs}
+        response = self.client(inputs=prompt, params=params)
         if "error" in response:
             raise ValueError(f"Error raised by inference API: {response['error']}")
         if self.client.task == "text-generation":
@@ -164,6 +164,7 @@ class HuggingFacePipeline(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         response = self.pipeline(prompt)
         if self.pipeline.task == "text-generation":
@@ -113,6 +113,7 @@ class HuggingFaceTextGenInference(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         if stop is None:
             stop = self.stop_sequences
@@ -130,6 +131,7 @@ class HuggingFaceTextGenInference(LLM):
             temperature=self.temperature,
             repetition_penalty=self.repetition_penalty,
             seed=self.seed,
+            **kwargs,
         )
         # remove stop sequences from the end of the generated text
         for stop_seq in stop:
@@ -60,6 +60,7 @@ class HumanInputLLM(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """
         Displays the prompt to the user and returns their input as a response.
@@ -200,6 +200,7 @@ class LlamaCpp(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call the Llama model and return the output.
 
@@ -227,6 +228,7 @@ class LlamaCpp(LLM):
             return combined_text_output
         else:
             params = self._get_parameters(stop)
+            params = {**params, **kwargs}
             result = self.client(prompt=prompt, **params)
             return result["choices"][0]["text"]
 
@@ -48,13 +48,15 @@ class ManifestWrapper(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to LLM through Manifest."""
         if stop is not None and len(stop) != 1:
             raise NotImplementedError(
                 f"Manifest currently only supports a single stop token, got {stop}"
             )
-        kwargs = self.llm_kwargs or {}
+        params = self.llm_kwargs or {}
+        params = {**params, **kwargs}
         if stop is not None:
-            kwargs["stop_token"] = stop
-        return self.client.run(prompt, **kwargs)
+            params["stop_token"] = stop
+        return self.client.run(prompt, **params)
@@ -76,9 +76,11 @@ class Modal(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call to Modal endpoint."""
         params = self.model_kwargs or {}
+        params = {**params, **kwargs}
         response = requests.post(
             url=self.endpoint_url,
             headers={
@@ -102,6 +102,7 @@ class MosaicML(LLM):
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         is_retry: bool = False,
+        **kwargs: Any,
     ) -> str:
         """Call out to a MosaicML LLM inference endpoint.
 
@@ -123,6 +124,7 @@ class MosaicML(LLM):
 
         payload = {"input_strings": [prompt]}
         payload.update(_model_kwargs)
+        payload.update(kwargs)
 
         # HTTP headers for authorization
         headers = {
@@ -117,6 +117,7 @@ class NLPCloud(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to NLPCloud's create endpoint.
 
@@ -141,7 +142,6 @@ class NLPCloud(LLM):
             end_sequence = stop[0]
         else:
             end_sequence = None
-        response = self.client.generation(
-            prompt, end_sequence=end_sequence, **self._default_params
-        )
+        params = {**self._default_params, **kwargs}
+        response = self.client.generation(prompt, end_sequence=end_sequence, **params)
         return response["generated_text"]
@@ -273,6 +273,7 @@ class BaseOpenAI(BaseLLM):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Call out to OpenAI's endpoint with k unique prompts.
 
@@ -290,6 +291,7 @@ class BaseOpenAI(BaseLLM):
         """
         # TODO: write a unit test for this
         params = self._invocation_params
+        params = {**params, **kwargs}
         sub_prompts = self.get_sub_prompts(params, prompts, stop)
         choices = []
         token_usage: Dict[str, int] = {}
@@ -326,9 +328,11 @@ class BaseOpenAI(BaseLLM):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Call out to OpenAI's endpoint async with k unique prompts."""
         params = self._invocation_params
+        params = {**params, **kwargs}
         sub_prompts = self.get_sub_prompts(params, prompts, stop)
         choices = []
         token_usage: Dict[str, int] = {}
@@ -771,8 +775,10 @@ class OpenAIChat(BaseLLM):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> LLMResult:
         messages, params = self._get_chat_params(prompts, stop)
+        params = {**params, **kwargs}
         if self.streaming:
             response = ""
             params["stream"] = True
@@ -804,8 +810,10 @@ class OpenAIChat(BaseLLM):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> LLMResult:
         messages, params = self._get_chat_params(prompts, stop)
+        params = {**params, **kwargs}
         if self.streaming:
             response = ""
             params["stream"] = True
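A hedged usage sketch of what the OpenAI hunks above enable, assuming the public generate() plumbing updated earlier in this commit forwards **kwargs down to these _generate/_agenerate implementations; logit_bias is an illustrative OpenAI completion argument chosen for the example, not something this diff adds.

    from langchain.llms import OpenAI

    llm = OpenAI(model_name="text-davinci-003", temperature=0)
    # The extra keyword argument applies to this call only; it is merged over the
    # configured invocation params via params = {**params, **kwargs}.
    result = llm.generate(["Tell me a joke."], logit_bias={"50256": -100})
    print(result.generations[0][0].text)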
@@ -137,9 +137,11 @@ class Petals(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call the Petals API."""
         params = self._default_params
+        params = {**params, **kwargs}
         inputs = self.tokenizer(prompt, return_tensors="pt")["input_ids"]
         outputs = self.client.generate(inputs, **params)
         text = self.tokenizer.decode(outputs[0])
@@ -87,6 +87,7 @@ class PipelineAI(LLM, BaseModel):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call to Pipeline Cloud endpoint."""
         try:
@@ -98,6 +99,7 @@ class PipelineAI(LLM, BaseModel):
             )
         client = PipelineCloud(token=self.pipeline_api_key)
         params = self.pipeline_kwargs or {}
+        params = {**params, **kwargs}
 
         run = client.run_pipeline(self.pipeline_key, [prompt, params])
         try:
@@ -91,6 +91,7 @@ class PredictionGuard(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to Prediction Guard's model API.
         Args:
@@ -117,6 +118,7 @@ class PredictionGuard(LLM):
             output=self.output,
             temperature=params["temperature"],
             max_tokens=params["max_tokens"],
+            **kwargs,
         )
         text = response["choices"][0]["text"]
 
@@ -1,6 +1,6 @@
 """PromptLayer wrapper."""
 import datetime
-from typing import List, Optional
+from typing import Any, List, Optional
 
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
@@ -42,6 +42,7 @@ class PromptLayerOpenAI(OpenAI):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Call OpenAI generate and then call PromptLayer API to log the request."""
         from promptlayer.utils import get_api_key, promptlayer_api_request
@@ -56,11 +57,12 @@ class PromptLayerOpenAI(OpenAI):
                 "text": generation.text,
                 "llm_output": generated_responses.llm_output,
             }
+            params = {**self._identifying_params, **kwargs}
             pl_request_id = promptlayer_api_request(
                 "langchain.PromptLayerOpenAI",
                 "langchain",
                 [prompt],
-                self._identifying_params,
+                params,
                 self.pl_tags,
                 resp,
                 request_start_time,
@@ -81,6 +83,7 @@ class PromptLayerOpenAI(OpenAI):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> LLMResult:
         from promptlayer.utils import get_api_key, promptlayer_api_request_async
 
@@ -94,11 +97,12 @@ class PromptLayerOpenAI(OpenAI):
                 "text": generation.text,
                 "llm_output": generated_responses.llm_output,
             }
+            params = {**self._identifying_params, **kwargs}
             pl_request_id = await promptlayer_api_request_async(
                 "langchain.PromptLayerOpenAI.async",
                 "langchain",
                 [prompt],
-                self._identifying_params,
+                params,
                 self.pl_tags,
                 resp,
                 request_start_time,
@@ -147,6 +151,7 @@ class PromptLayerOpenAIChat(OpenAIChat):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Call OpenAI generate and then call PromptLayer API to log the request."""
         from promptlayer.utils import get_api_key, promptlayer_api_request
@@ -161,11 +166,12 @@ class PromptLayerOpenAIChat(OpenAIChat):
                 "text": generation.text,
                 "llm_output": generated_responses.llm_output,
             }
+            params = {**self._identifying_params, **kwargs}
             pl_request_id = promptlayer_api_request(
                 "langchain.PromptLayerOpenAIChat",
                 "langchain",
                 [prompt],
-                self._identifying_params,
+                params,
                 self.pl_tags,
                 resp,
                 request_start_time,
@@ -186,6 +192,7 @@ class PromptLayerOpenAIChat(OpenAIChat):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> LLMResult:
         from promptlayer.utils import get_api_key, promptlayer_api_request_async
 
@@ -199,11 +206,12 @@ class PromptLayerOpenAIChat(OpenAIChat):
                 "text": generation.text,
                 "llm_output": generated_responses.llm_output,
             }
+            params = {**self._identifying_params, **kwargs}
             pl_request_id = await promptlayer_api_request_async(
                 "langchain.PromptLayerOpenAIChat.async",
                 "langchain",
                 [prompt],
-                self._identifying_params,
+                params,
                 self.pl_tags,
                 resp,
                 request_start_time,
@@ -85,6 +85,7 @@ class Replicate(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call to replicate endpoint."""
         try:
@@ -110,6 +111,6 @@ class Replicate(LLM):
         first_input_name = input_properties[0][0]
 
         inputs = {first_input_name: prompt, **self.input}
-        iterator = replicate_python.run(self.model, input={**inputs})
+        iterator = replicate_python.run(self.model, input={**inputs, **kwargs})
 
         return "".join([output for output in iterator])
@@ -210,6 +210,7 @@ class RWKV(LLM, BaseModel):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         r"""RWKV generation
 
@@ -207,6 +207,7 @@ class SagemakerEndpoint(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to Sagemaker inference endpoint.
 
@@ -223,6 +224,7 @@ class SagemakerEndpoint(LLM):
                 response = se("Tell me a joke.")
         """
         _model_kwargs = self.model_kwargs or {}
+        _model_kwargs = {**_model_kwargs, **kwargs}
         _endpoint_kwargs = self.endpoint_kwargs or {}
 
         body = self.content_handler.transform_input(prompt, _model_kwargs)
@@ -214,5 +214,8 @@ class SelfHostedPipeline(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
-        return self.client(pipeline=self.pipeline_ref, prompt=prompt, stop=stop)
+        return self.client(
+            pipeline=self.pipeline_ref, prompt=prompt, stop=stop, **kwargs
+        )
@@ -207,5 +207,8 @@ class SelfHostedHuggingFaceLLM(SelfHostedPipeline):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
-        return self.client(pipeline=self.pipeline_ref, prompt=prompt, stop=stop)
+        return self.client(
+            pipeline=self.pipeline_ref, prompt=prompt, stop=stop, **kwargs
+        )
@@ -86,6 +86,7 @@ class StochasticAI(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to StochasticAI's complete endpoint.
 
@@ -102,6 +103,7 @@ class StochasticAI(LLM):
                 response = StochasticAI("Tell me a joke.")
         """
         params = self.model_kwargs or {}
+        params = {**params, **kwargs}
         response_post = requests.post(
             url=self.api_url,
             json={"prompt": prompt, "params": params},
@@ -50,8 +50,11 @@ class _VertexAICommon(BaseModel):
         }
         return {**base_params}
 
-    def _predict(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-        res = self.client.predict(prompt, **self._default_params)
+    def _predict(
+        self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> str:
+        params = {**self._default_params, **kwargs}
+        res = self.client.predict(prompt, **params)
         return self._enforce_stop_words(res.text, stop)
 
     def _enforce_stop_words(self, text: str, stop: Optional[List[str]] = None) -> str:
@@ -100,6 +103,7 @@ class VertexAI(_VertexAICommon, LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call Vertex model to get predictions based on the prompt.
 
@@ -111,4 +115,4 @@ class VertexAI(_VertexAICommon, LLM):
         Returns:
             The string generated by the model.
         """
-        return self._predict(prompt, stop)
+        return self._predict(prompt, stop, **kwargs)
@@ -118,6 +118,7 @@ class Writer(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to Writer's completions endpoint.
 
@@ -141,7 +142,7 @@ class Writer(LLM):
             f"/organization/{self.writer_org_id}"
             f"/model/{self.model_id}/completions"
         )
+        params = {**self._default_params, **kwargs}
         response = requests.post(
             url=base_url,
             headers={
@@ -149,7 +150,7 @@ class Writer(LLM):
                 "Content-Type": "application/json",
                 "Accept": "application/json",
             },
-            json={"prompt": prompt, **self._default_params},
+            json={"prompt": prompt, **params},
         )
         text = response.text
         if stop is not None:
@@ -20,6 +20,7 @@ class FakeListLLM(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Increment counter, and then return response in that index."""
         self.i += 1
@@ -38,6 +38,7 @@ class FakeListLLM(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Increment counter, and then return response in that index."""
         self.i += 1
@@ -1,5 +1,5 @@
 """Test HyDE."""
-from typing import List, Optional
+from typing import Any, List, Optional
 
 import numpy as np
 
@@ -36,6 +36,7 @@ class FakeLLM(BaseLLM):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> LLMResult:
         return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])
 
@@ -44,6 +45,7 @@ class FakeLLM(BaseLLM):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> LLMResult:
         return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])
 
@@ -15,6 +15,7 @@ class FakeLLM(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Return `foo` if longer than 10000 words, else `bar`."""
         if len(prompt) > 10000:
@@ -17,6 +17,7 @@ class FakeChatModel(SimpleChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         return "fake response"
 
@@ -25,6 +26,7 @@ class FakeChatModel(SimpleChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         output_str = "fake response"
         message = AIMessage(content=output_str)
@@ -34,6 +34,7 @@ class FakeLLM(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         if self.sequential_responses:
             return self._get_next_response_in_sequence
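The test fixtures above show the _call signature that implementations are expected to adopt after this change. A minimal, hypothetical subclass for wrappers maintained outside this repository (EchoLLM is not part of the commit), assuming langchain at this revision:

    from typing import Any, List, Mapping, Optional

    from langchain.callbacks.manager import CallbackManagerForLLMRun
    from langchain.llms.base import LLM


    class EchoLLM(LLM):
        """Hypothetical example: echoes the prompt while accepting call-time kwargs."""

        @property
        def _llm_type(self) -> str:
            return "echo"

        @property
        def _identifying_params(self) -> Mapping[str, Any]:
            return {}

        def _call(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> str:
            # Call-time kwargs arrive here; a real wrapper would merge them into its
            # request parameters, e.g. params = {**configured_defaults, **kwargs}.
            return prompt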