Mirror of https://github.com/hwchase17/langchain.git
community[patch]: OpenLLM Client Fixes + Added Timeout Parameter (#17478)
- OpenLLM was using an outdated method to extract the final text output from the openllm client invocation, which raised an error. This is now corrected.
- OpenLLM's `_identifying_params` read the openllm client configuration through outdated attributes, which also raised an error.
- Updated the docstring for OpenLLM.
- Added a `timeout` parameter that is passed to the underlying openllm client.
Parent: 1d2aa19aee
Commit: 43dc5d3416
@@ -72,7 +72,7 @@ class OpenLLM(LLM):
 
             from langchain_community.llms import OpenLLM
             llm = OpenLLM(server_url='http://localhost:3000')
-            llm("What is the difference between a duck and a goose?")
+            llm.invoke("What is the difference between a duck and a goose?")
     """
 
     model_name: Optional[str] = None
@@ -82,6 +82,8 @@ class OpenLLM(LLM):
     See 'openllm models' for all available model variants."""
     server_url: Optional[str] = None
     """Optional server URL that currently runs a LLMServer with 'openllm start'."""
+    timeout: int = 30
+    """"Time out for the openllm client"""
     server_type: ServerType = "http"
     """Optional server type. Either 'http' or 'grpc'."""
     embedded: bool = True
@@ -125,6 +127,7 @@ class OpenLLM(LLM):
         *,
         model_id: Optional[str] = None,
         server_url: Optional[str] = None,
+        timeout: int = 30,
         server_type: Literal["grpc", "http"] = "http",
         embedded: bool = True,
         **llm_kwargs: Any,
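Taken together with the new `timeout` attribute above, client-mode usage would look roughly like the sketch below; the server URL and the 60-second timeout are placeholder values.

```python
from langchain_community.llms import OpenLLM

# Connect to a running OpenLLM server ('openllm start ...'); `timeout` (seconds)
# is forwarded to the underlying openllm HTTP/gRPC client instead of relying on
# that client's default.
llm = OpenLLM(server_url="http://localhost:3000", timeout=60)

# The docstring example now uses invoke() rather than calling the LLM directly.
print(llm.invoke("What is the difference between a duck and a goose?"))
```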
@@ -149,11 +152,12 @@ class OpenLLM(LLM):
                 if server_type == "http"
                 else openllm.client.GrpcClient
             )
-            client = client_cls(server_url)
+            client = client_cls(server_url, timeout)
 
             super().__init__(
                 **{
                     "server_url": server_url,
+                    "timeout": timeout,
                     "server_type": server_type,
                     "llm_kwargs": llm_kwargs,
                 }
@@ -217,9 +221,9 @@ class OpenLLM(LLM):
     def _identifying_params(self) -> IdentifyingParams:
         """Get the identifying parameters."""
         if self._client is not None:
-            self.llm_kwargs.update(self._client._config())
-            model_name = self._client._metadata()["model_name"]
-            model_id = self._client._metadata()["model_id"]
+            self.llm_kwargs.update(self._client._config)
+            model_name = self._client._metadata.model_dump()["model_name"]
+            model_id = self._client._metadata.model_dump()["model_id"]
         else:
             if self._runner is None:
                 raise ValueError("Runner must be initialized.")
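For orientation, a small sketch of how the corrected property could be inspected on a client-backed instance; keys other than `model_name` and `model_id` are not shown in this hunk, and the values depend on the running server.

```python
from langchain_community.llms import OpenLLM

llm = OpenLLM(server_url="http://localhost:3000", timeout=30)

# With the fix, the client config is read as a property (_config) and the metadata
# as a model that is dumped (_metadata.model_dump()) before key access.
params = llm._identifying_params
print(params["model_name"], params["model_id"])
```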
@@ -265,9 +269,11 @@ class OpenLLM(LLM):
             self._identifying_params["model_name"], **copied
         )
         if self._client:
-            res = self._client.generate(
-                prompt, **config.model_dump(flatten=True)
-            ).responses[0]
+            res = (
+                self._client.generate(prompt, **config.model_dump(flatten=True))
+                .outputs[0]
+                .text
+            )
         else:
             assert self._runner is not None
             res = self._runner(prompt, **config.model_dump(flatten=True))
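The `_call` change reflects the newer openllm client response shape seen in this hunk: the result of `generate()` exposes completions through an `outputs` list whose entries carry the generated string in `.text`, rather than a `responses` list of plain strings. A minimal sketch of that access pattern against the raw client; the exact response classes and default generate arguments are assumptions here.

```python
import openllm

# Same client the wrapper builds internally: client_cls(server_url, timeout)
client = openllm.client.HTTPClient("http://localhost:3000", 30)

result = client.generate("What is the difference between a duck and a goose?")
# Assumed shape of the newer client response: a list of completions, each with .text
print(result.outputs[0].text)
```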