langchain[patch]: Support passing parameters to llms.Databricks and llms.Mlflow (#14100)
Before, we needed to use `params` to pass extra parameters:

```python
from langchain.llms import Databricks

Databricks(..., params={"temperature": 0.0})
```

Now, we can specify extra parameters directly:

```python
from langchain.llms import Databricks

Databricks(..., temperature=0.0)
```
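The same change covers `llms.Mlflow`, per the PR title; a minimal sketch of the new style there (the target URI and endpoint name are placeholders, not taken from this PR):

```python
from langchain.llms import Mlflow

llm = Mlflow(
    target_uri="http://127.0.0.1:5000",  # placeholder MLflow deployments server
    endpoint="completions",              # placeholder endpoint name
    temperature=0.1,
    max_tokens=200,
    extra_params={"top_p": 0.9},         # anything without a dedicated field
)
```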
This commit is contained in:
parent 82102c99b3
commit 41ee3be95f
```diff
@@ -118,8 +118,10 @@ class ChatMlflow(BaseChatModel):
-            "stop": stop or self.stop,
             "max_tokens": self.max_tokens,
+            **self.extra_params,
             **kwargs,
         }
+        if stop := self.stop or stop:
+            data["stop"] = stop
         resp = self._client.predict(endpoint=self.endpoint, inputs=data)
         return ChatMlflow._create_chat_result(resp)
```
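The walrus-operator guard above only adds a `stop` key when a non-empty stop list is available, with the instance field taking precedence over the call argument. A standalone sketch of the same idiom (names are illustrative):

```python
from typing import Any, Dict, List, Optional

def build_payload(
    field_stop: Optional[List[str]], arg_stop: Optional[List[str]]
) -> Dict[str, Any]:
    # Mirrors the hunk above: the field wins over the argument, and
    # None/empty values leave the key out of the payload entirely.
    data: Dict[str, Any] = {"prompt": "hello"}
    if stop := field_stop or arg_stop:
        data["stop"] = stop
    return data

assert "stop" not in build_payload(None, None)
assert build_payload(["\n"], None)["stop"] == ["\n"]
assert build_payload(["a"], ["b"])["stop"] == ["a"]  # field takes precedence
```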
```diff
@@ -46,6 +46,10 @@ class _DatabricksClientBase(BaseModel, ABC):
     ) -> Any:
         ...
 
+    @property
+    def llm(self) -> bool:
+        return False
+
 
 def _transform_completions(response: Dict[str, Any]) -> str:
     return response["choices"][0]["text"]
```
```diff
@@ -85,6 +89,10 @@ class _DatabricksServingEndpointClient(_DatabricksClientBase):
         )
         self.task = endpoint.get("task")
 
+    @property
+    def llm(self) -> bool:
+        return self.task in ("llm/v1/chat", "llm/v1/completions")
+
     @root_validator(pre=True)
     def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         if "api_url" not in values:
```
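Taken together, the two hunks above give every client an `llm` flag: the base class defaults to `False`, and the serving-endpoint client flips it on for the two `llm/v1` task types. A minimal sketch of the override pattern (class names and bodies simplified):

```python
from typing import Optional

class ClientSketch:
    """Stands in for _DatabricksClientBase: not an LLM endpoint by default."""

    @property
    def llm(self) -> bool:
        return False

class ServingEndpointSketch(ClientSketch):
    """Stands in for _DatabricksServingEndpointClient."""

    def __init__(self, task: Optional[str]) -> None:
        self.task = task

    @property
    def llm(self) -> bool:
        # Only these task types get the richer parameter handling in _call.
        return self.task in ("llm/v1/chat", "llm/v1/completions")

assert ClientSketch().llm is False
assert ServingEndpointSketch("llm/v1/completions").llm is True
assert ServingEndpointSketch("llm/v1/embeddings").llm is False
```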
```diff
@@ -137,8 +145,11 @@ class _DatabricksClusterDriverProxyClient(_DatabricksClientBase):
         values["api_url"] = api_url
         return values
 
-    def post(self, request: Any, transform: Optional[Callable[..., str]] = None) -> Any:
-        return self._post(self.api_url, request)
+    def post(
+        self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
+    ) -> Any:
+        resp = self._post(self.api_url, request)
+        return transform_output_fn(resp) if transform_output_fn else resp
 
 
 def get_repl_context() -> Any:
```
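Moving the output transform into `post` means callers hand over the hook instead of applying it after the fact. A sketch of the call shape — the completions transform comes from the hunk above, while the response dict and the free function standing in for the client method are made up:

```python
from typing import Any, Callable, Dict, Optional

def _transform_completions(response: Dict[str, Any]) -> str:
    return response["choices"][0]["text"]

def post(request: Any, transform_output_fn: Optional[Callable[..., str]] = None) -> Any:
    # Stand-in for self._post(self.api_url, request); the payload is illustrative.
    resp = {"choices": [{"text": "hello world"}]}
    return transform_output_fn(resp) if transform_output_fn else resp

assert post({}, transform_output_fn=_transform_completions) == "hello world"
assert post({}) == {"choices": [{"text": "hello world"}]}
```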
```diff
@@ -285,12 +296,10 @@ class Databricks(LLM):
     We recommend the server using a port number between ``[3000, 8000]``.
     """
 
-    params: Optional[Dict[str, Any]] = None
-    """Extra parameters to pass to the endpoint."""
-
     model_kwargs: Optional[Dict[str, Any]] = None
     """
-    Deprecated. Please use ``params`` instead. Extra parameters to pass to the endpoint.
+    Deprecated. Please use ``extra_params`` instead. Extra parameters to pass to
+    the endpoint.
     """
 
     transform_input_fn: Optional[Callable] = None
```
```diff
@@ -306,12 +315,34 @@ class Databricks(LLM):
     databricks_uri: str = "databricks"
     """The databricks URI. Only used when using a serving endpoint."""
 
+    temperature: float = 0.0
+    """The sampling temperature."""
+    n: int = 1
+    """The number of completion choices to generate."""
+    stop: Optional[List[str]] = None
+    """The stop sequence."""
+    max_tokens: Optional[int] = None
+    """The maximum number of tokens to generate."""
+    extra_params: Dict[str, Any] = Field(default_factory=dict)
+    """Any extra parameters to pass to the endpoint."""
+
     _client: _DatabricksClientBase = PrivateAttr()
 
     class Config:
         extra = Extra.forbid
         underscore_attrs_are_private = True
 
+    @property
+    def _llm_params(self) -> Dict[str, Any]:
+        params = {
+            "temperature": self.temperature,
+            "n": self.n,
+            "stop": self.stop,
+            "max_tokens": self.max_tokens,
+            **(self.model_kwargs or self.extra_params),
+        }
+        return params
+
     @validator("cluster_id", always=True)
     def set_cluster_id(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
         if v and values["endpoint_name"]:
```
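Note the `or` in `**(self.model_kwargs or self.extra_params)`: a non-None `model_kwargs` replaces `extra_params` wholesale rather than merging key by key. A worked sketch of what the property yields (all values illustrative):

```python
# Assume temperature=0.2, n=1, stop=None, max_tokens=None,
# model_kwargs=None and extra_params={"top_p": 0.9}.
llm_params = {
    "temperature": 0.2,
    "n": 1,
    "stop": None,
    "max_tokens": None,
    # extra_params is splatted in because model_kwargs is None:
    "top_p": 0.9,
}
```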
```diff
@@ -356,11 +387,11 @@ class Databricks(LLM):
 
     def __init__(self, **data: Any):
         super().__init__(**data)
-        if self.model_kwargs is not None and self.params is not None:
-            raise ValueError("Cannot set both model_kwargs and params.")
+        if self.model_kwargs is not None and self.extra_params is not None:
+            raise ValueError("Cannot set both model_kwargs and extra_params.")
         elif self.model_kwargs is not None:
             warnings.warn(
-                "model_kwargs is deprecated. Please use params instead.",
+                "model_kwargs is deprecated. Please use extra_params instead.",
                 DeprecationWarning,
             )
         if self.endpoint_name:
```
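The net effect: setting both fields raises, while the deprecated field alone still works but warns. A hedged sketch of the migration — this will not run without real Databricks credentials in the environment, and `my-endpoint` is a placeholder:

```python
from langchain.llms import Databricks

# Deprecated style: still accepted, but now emits a DeprecationWarning.
llm_old = Databricks(endpoint_name="my-endpoint", model_kwargs={"temperature": 0.1})

# New style: dedicated fields, plus extra_params for anything else.
llm_new = Databricks(
    endpoint_name="my-endpoint",
    temperature=0.1,
    extra_params={"top_p": 0.9},
)

# Passing both model_kwargs and extra_params at once raises ValueError.
```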
```diff
@@ -382,10 +413,6 @@ class Databricks(LLM):
                 "Must specify either endpoint_name or cluster_id/cluster_driver_port."
             )
 
-    @property
-    def _params(self) -> Optional[Dict[str, Any]]:
-        return self.model_kwargs or self.params
-
     @property
     def _default_params(self) -> Dict[str, Any]:
         """Return default params."""
```
```diff
@@ -397,7 +424,11 @@ class Databricks(LLM):
             "cluster_driver_port": self.cluster_driver_port,
             "databricks_uri": self.databricks_uri,
             "model_kwargs": self.model_kwargs,
-            "params": self.params,
+            "temperature": self.temperature,
+            "n": self.n,
+            "stop": self.stop,
+            "max_tokens": self.max_tokens,
+            "extra_params": self.extra_params,
             # TODO: Support saving transform_input_fn and transform_output_fn
             # "transform_input_fn": self.transform_input_fn,
             # "transform_output_fn": self.transform_output_fn,
```
```diff
@@ -423,17 +454,17 @@ class Databricks(LLM):
 
         # TODO: support callbacks
 
-        request = {"prompt": prompt, "stop": stop}
+        request: Dict[str, Any] = {"prompt": prompt}
+        if self._client.llm:
+            request.update(self._llm_params)
+            request.update(self.model_kwargs or self.extra_params)
+        else:
+            request.update(self.model_kwargs or self.extra_params)
         request.update(kwargs)
-        if self._params:
-            request.update(self._params)
+        if stop := self.stop or stop:
+            request["stop"] = stop
 
         if self.transform_input_fn:
             request = self.transform_input_fn(**request)
 
-        response = self._client.post(request)
-
-        if self.transform_output_fn:
-            response = self.transform_output_fn(response)
-
-        return response
+        return self._client.post(request, transform_output_fn=self.transform_output_fn)
```
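For a client whose task is `llm/v1/completions`, the rewritten `_call` builds a flat payload from `_llm_params`, `kwargs`, and the stop guard. Roughly what `self._client.post(...)` would receive, with illustrative values (temperature=0.5, max_tokens=128, caller-supplied stop list):

```python
request = {
    "prompt": "What is MLflow?",
    "temperature": 0.5,
    "n": 1,
    "stop": ["."],  # set by the `if stop := self.stop or stop:` guard
    "max_tokens": 128,
}
```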
```diff
@@ -5,7 +5,7 @@ from urllib.parse import urlparse
 
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models import LLM
-from langchain_core.pydantic_v1 import BaseModel, Extra, PrivateAttr
+from langchain_core.pydantic_v1 import BaseModel, Extra, Field, PrivateAttr
 
 
 # Ignoring type because below is valid pydantic code
```
```diff
@@ -41,7 +41,17 @@ class Mlflow(LLM):
     """The endpoint to use."""
     target_uri: str
     """The target URI to use."""
-    params: Optional[Params] = None
-    """Extra parameters such as `temperature`."""
+    temperature: float = 0.0
+    """The sampling temperature."""
+    n: int = 1
+    """The number of completion choices to generate."""
+    stop: Optional[List[str]] = None
+    """The stop sequence."""
+    max_tokens: Optional[int] = None
+    """The maximum number of tokens to generate."""
+    extra_params: Dict[str, Any] = Field(default_factory=dict)
+    """Any extra parameters to pass to the endpoint."""
+
     _client: Any = PrivateAttr()
 
```
```diff
@@ -71,13 +81,15 @@ class Mlflow(LLM):
 
     @property
     def _default_params(self) -> Dict[str, Any]:
-        params: Dict[str, Any] = {
+        return {
             "target_uri": self.target_uri,
             "endpoint": self.endpoint,
+            "temperature": self.temperature,
+            "n": self.n,
+            "stop": self.stop,
+            "max_tokens": self.max_tokens,
+            "extra_params": self.extra_params,
         }
-        if self.params:
-            params["params"] = self.params.dict()
-        return params
 
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
```
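`_default_params` now surfaces every generation field, so a saved configuration round-trips them instead of a single nested `params` object. For a fresh instance with all defaults it would evaluate to roughly the dict below (target URI and endpoint name are placeholders):

```python
default_params = {
    "target_uri": "http://127.0.0.1:5000",
    "endpoint": "completions",
    "temperature": 0.0,
    "n": 1,
    "stop": None,
    "max_tokens": None,
    "extra_params": {},
}
```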
```diff
@@ -92,10 +104,14 @@ class Mlflow(LLM):
     ) -> str:
         data: Dict[str, Any] = {
             "prompt": prompt,
-            **(self.params.dict() if self.params else {}),
+            "temperature": self.temperature,
+            "n": self.n,
+            "max_tokens": self.max_tokens,
+            **self.extra_params,
             **kwargs,
         }
-        if s := (stop or (self.params.stop if self.params else None)):
-            data["stop"] = s
+        if stop := self.stop or stop:
+            data["stop"] = stop
         resp = self._client.predict(endpoint=self.endpoint, inputs=data)
         return resp["choices"][0]["text"]
```