Mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-10 15:33:11 +00:00)
add kwargs to llm runnables (#8388)
@@ -101,13 +101,14 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         config: Optional[RunnableConfig] = None,
         *,
         stop: Optional[List[str]] = None,
+        **kwargs: Any,
     ) -> BaseMessageChunk:
         return cast(
             BaseMessageChunk,
             cast(
                 ChatGeneration,
                 self.generate_prompt(
-                    [self._convert_input(input)], stop=stop, **(config or {})
+                    [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
                 ).generations[0][0],
             ).message,
         )
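
This hunk threads caller-supplied keyword arguments from the Runnable-style invoke through to generate_prompt. A minimal, self-contained sketch of that forwarding pattern (toy names, not langchain's actual classes; the temperature keyword is purely illustrative):

from typing import Any, List, Optional

class ToyChatModel:
    def generate_prompt(
        self, prompts: List[str], stop: Optional[List[str]] = None, **kwargs: Any
    ) -> str:
        # a real model would call its provider here; this just echoes what arrived
        return f"prompts={prompts} stop={stop} extra={kwargs}"

    def invoke(self, input: str, *, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        # mirrors the patched invoke: caller kwargs pass straight through
        return self.generate_prompt([input], stop=stop, **kwargs)

print(ToyChatModel().invoke("hi", stop=["\n"], temperature=0.2))
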
@@ -118,15 +119,16 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         config: Optional[RunnableConfig] = None,
         *,
         stop: Optional[List[str]] = None,
+        **kwargs: Any,
     ) -> BaseMessageChunk:
         if type(self)._agenerate == BaseChatModel._agenerate:
             # model doesn't implement async generation, so use default implementation
             return await asyncio.get_running_loop().run_in_executor(
-                None, partial(self.invoke, input, config, stop=stop)
+                None, partial(self.invoke, input, config, stop=stop, **kwargs)
             )

         llm_result = await self.agenerate_prompt(
-            [self._convert_input(input)], stop=stop, **(config or {})
+            [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
         )
         return cast(
             BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message
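
When a subclass has no native async implementation, the patched fallback bundles the original arguments plus **kwargs into functools.partial and runs the sync path in the default executor. A runnable sketch of that pattern outside langchain (slow_invoke and the temperature keyword are invented for illustration):

import asyncio
from functools import partial
from typing import Any, List, Optional

def slow_invoke(text: str, *, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
    # stand-in for a blocking model call
    return f"echo:{text} stop={stop} extra={kwargs}"

async def ainvoke(text: str, **kwargs: Any) -> str:
    # run_in_executor only passes positional args, so partial carries the keywords
    return await asyncio.get_running_loop().run_in_executor(
        None, partial(slow_invoke, text, stop=["\n"], **kwargs)
    )

print(asyncio.run(ainvoke("hi", temperature=0.1)))
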
@@ -213,10 +213,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         config: Optional[RunnableConfig] = None,
         *,
         stop: Optional[List[str]] = None,
+        **kwargs: Any,
     ) -> str:
         return (
             self.generate_prompt(
-                [self._convert_input(input)], stop=stop, **(config or {})
+                [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
             )
             .generations[0][0]
             .text
@@ -228,15 +229,16 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         config: Optional[RunnableConfig] = None,
         *,
         stop: Optional[List[str]] = None,
+        **kwargs: Any,
     ) -> str:
         if type(self)._agenerate == BaseLLM._agenerate:
             # model doesn't implement async invoke, so use default implementation
             return await asyncio.get_running_loop().run_in_executor(
-                None, partial(self.invoke, input, config, stop=stop)
+                None, partial(self.invoke, input, config, stop=stop, **kwargs)
             )

         llm_result = await self.agenerate_prompt(
-            [self._convert_input(input)], stop=stop, **(config or {})
+            [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
         )
         return llm_result.generations[0][0].text

@@ -245,6 +247,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         inputs: List[LanguageModelInput],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
         max_concurrency: Optional[int] = None,
+        **kwargs: Any,
     ) -> List[str]:
         config = self._get_config_list(config, len(inputs))

@@ -254,6 +257,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 callbacks=[c.get("callbacks") for c in config],
                 tags=[c.get("tags") for c in config],
                 metadata=[c.get("metadata") for c in config],
+                **kwargs,
             )
             return [g[0].text for g in llm_result.generations]
         else:
@@ -264,7 +268,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             return [
                 output
                 for batch in batches
-                for output in self.batch(batch, config=config)
+                for output in self.batch(batch, config=config, **kwargs)
             ]

     async def abatch(
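
In this branch batch recurses per sub-batch and now forwards **kwargs on every recursive call; the nested comprehension flattens the per-batch outputs. A toy version of that flattening (fake_batch is an invented stand-in for the recursive self.batch call):

from typing import Any, List

def fake_batch(batch: List[str], **kwargs: Any) -> List[str]:
    # stand-in for self.batch(batch, config=config, **kwargs)
    return [f"{item}:{sorted(kwargs)}" for item in batch]

batches = [["a", "b"], ["c"]]
# same nested-comprehension flattening as the else branch above
flat = [output for batch in batches for output in fake_batch(batch, temperature=0.0)]
print(flat)
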
@@ -272,6 +276,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         inputs: List[LanguageModelInput],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
         max_concurrency: Optional[int] = None,
+        **kwargs: Any,
     ) -> List[str]:
         if type(self)._agenerate == BaseLLM._agenerate:
             # model doesn't implement async batch, so use default implementation
@@ -287,6 +292,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 callbacks=[c.get("callbacks") for c in config],
                 tags=[c.get("tags") for c in config],
                 metadata=[c.get("metadata") for c in config],
+                **kwargs,
             )
             return [g[0].text for g in llm_result.generations]
         else:
@@ -297,7 +303,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             return [
                 output
                 for batch in batches
-                for output in await self.abatch(batch, config=config)
+                for output in await self.abatch(batch, config=config, **kwargs)
             ]

     def stream(
@@ -306,15 +312,17 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         config: Optional[RunnableConfig] = None,
         *,
         stop: Optional[List[str]] = None,
+        **kwargs: Any,
     ) -> Iterator[str]:
         if type(self)._stream == BaseLLM._stream:
             # model doesn't implement streaming, so use default implementation
-            yield self.invoke(input, config=config, stop=stop)
+            yield self.invoke(input, config=config, stop=stop, **kwargs)
         else:
             prompt = self._convert_input(input).to_string()
             config = config or {}
             params = self.dict()
             params["stop"] = stop
+            params = {**params, **kwargs}
             options = {"stop": stop}
             callback_manager = CallbackManager.configure(
                 config.get("callbacks"),
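
In the streaming path the per-call kwargs are merged over the model's serialized params with a dict unpack, so the right-hand dict wins on duplicate keys. A tiny standalone check of that merge semantics (all values invented):

params = {"model_name": "stub", "stop": None, "temperature": 0.7}
kwargs = {"temperature": 0.0, "max_tokens": 32}
merged = {**params, **kwargs}  # later entries override earlier ones
assert merged == {"model_name": "stub", "stop": None, "temperature": 0.0, "max_tokens": 32}
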
@@ -330,7 +338,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             )
             try:
                 generation: Optional[GenerationChunk] = None
-                for chunk in self._stream(prompt, stop=stop, run_manager=run_manager):
+                for chunk in self._stream(
+                    prompt, stop=stop, run_manager=run_manager, **kwargs
+                ):
                     yield chunk.text
                     if generation is None:
                         generation = chunk
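
With kwargs now forwarded into self._stream, a concrete subclass's streaming generator can see per-call settings. A toy generator illustrating the handoff (no langchain types involved; the temperature keyword is an invented pass-through):

from typing import Any, Iterator, List, Optional

def _stream(prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> Iterator[str]:
    # a real implementation would yield provider chunks; this tags tokens instead
    for token in prompt.split():
        yield f"{token}@{kwargs.get('temperature', 'default')}"

def stream(prompt: str, *, stop: Optional[List[str]] = None, **kwargs: Any) -> Iterator[str]:
    # mirrors the patched loop: every chunk request carries the caller's kwargs
    yield from _stream(prompt, stop=stop, **kwargs)

print(list(stream("hello streaming world", temperature=0.0)))
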
@@ -349,15 +359,17 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         config: Optional[RunnableConfig] = None,
         *,
         stop: Optional[List[str]] = None,
+        **kwargs: Any,
     ) -> AsyncIterator[str]:
         if type(self)._astream == BaseLLM._astream:
             # model doesn't implement streaming, so use default implementation
-            yield await self.ainvoke(input, config=config, stop=stop)
+            yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
         else:
             prompt = self._convert_input(input).to_string()
             config = config or {}
             params = self.dict()
             params["stop"] = stop
+            params = {**params, **kwargs}
             options = {"stop": stop}
             callback_manager = AsyncCallbackManager.configure(
                 config.get("callbacks"),
@@ -374,7 +386,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             try:
                 generation: Optional[GenerationChunk] = None
                 async for chunk in self._astream(
-                    prompt, stop=stop, run_manager=run_manager
+                    prompt, stop=stop, run_manager=run_manager, **kwargs
                 ):
                     yield chunk.text
                     if generation is None: