mirror of
				https://github.com/hwchase17/langchain.git
				synced 2025-10-22 17:50:03 +00:00 
			
		
		
		
	**What is the reproduction code?**
```python
from langchain.chains import LLMChain, load_chain
from langchain.llms import Databricks
from langchain.prompts import PromptTemplate
def transform_output(response):
    # Extract the answer from the responses.
    return str(response["candidates"][0]["text"])
def transform_input(**request):
    full_prompt = f"""{request["prompt"]}
    Be Concise.
    """
    request["prompt"] = full_prompt
    return request
chat_model = Databricks(
    endpoint_name="llama2-13B-chat-Brambles",
    transform_input_fn=transform_input,
    transform_output_fn=transform_output,
    verbose=True,
)
print(f"Test chat model: {chat_model('What is Apache Spark')}") # This works
llm_chain = LLMChain(llm=chat_model, prompt=PromptTemplate.from_template("{chat_input}"))
llm_chain("colorful socks") # this works
llm_chain.save("databricks_llm_chain.yaml") # transform_input_fn and transform_output_fn are not serialized into the model yaml file
loaded_chain = load_chain("databricks_llm_chain.yaml") # The Databricks LLM is recreated with transform_input_fn=None, transform_output_fn=None.
loaded_chain("colorful socks") # Thus this errors. The transform_output_fn is needed to produce the correct output
```
Error:
```
 File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-6c34afab-3473-421d-877f-1ef18930ef4d/lib/python3.10/site-packages/pydantic/v1/main.py", line 341, in __init__
    raise validation_error
pydantic.v1.error_wrappers.ValidationError: 1 validation error for Generation
text
  str type expected (type=type_error.str)
 request payload: {'query': 'What is a databricks notebook?'}'}
```
**What does the error mean?**
When the LLM generates an answer, the answer is represented by a
`Generation` data object. The `Generation` object takes a `str` field
called `text`, e.g. `Generation(text="blah")`. However, the Databricks LLM
tried to put a non-str value into `text`, e.g.
`Generation(text={"candidates": [{"text": "blah"}]})`, so pydantic raises a
validation error.
**Why does the output format become incorrect after saving and loading
the Databricks LLM?**
The Databricks LLM does not support serializing `transform_input_fn` and
`transform_output_fn`, so they are not written to the model YAML file.
When the Databricks LLM is loaded, it is recreated with
`transform_input_fn=None` and `transform_output_fn=None`. Without
`transform_output_fn`, the output text is not unwrapped, which causes the
error above. Missing `transform_input_fn` additionally causes the extra
prompt "Be Concise." to be lost after saving and loading.
<!-- Thank you for contributing to LangChain!
Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!
Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.
See contribution guidelines for more information on how to write/run
tests, lint, etc:
https://python.langchain.com/docs/contributing/
If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.
If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
---------
Co-authored-by: Bagatur <baskaryan@gmail.com>
		
	
		
			
				
	
	
		
			483 lines
		
	
	
		
			17 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			483 lines
		
	
	
		
			17 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import os
 | |
| import warnings
 | |
| from abc import ABC, abstractmethod
 | |
| from typing import Any, Callable, Dict, List, Mapping, Optional
 | |
| 
 | |
| import requests
 | |
| from langchain_core.callbacks import CallbackManagerForLLMRun
 | |
| from langchain_core.language_models import LLM
 | |
| from langchain_core.pydantic_v1 import (
 | |
|     BaseModel,
 | |
|     Extra,
 | |
|     Field,
 | |
|     PrivateAttr,
 | |
|     root_validator,
 | |
|     validator,
 | |
| )
 | |
| 
 | |
# Names exported by ``from <this module> import *``: only the Databricks LLM.
__all__ = ["Databricks"]
 | |
| 
 | |
| 
 | |
class _DatabricksClientBase(BaseModel, ABC):
    """Abstract JSON-over-HTTP client for talking to Databricks.

    ``api_url`` and ``api_token`` are required fields; concrete subclasses
    implement :meth:`post`.
    """

    api_url: str
    api_token: str

    def request(self, method: str, url: str, request: Any) -> Any:
        """Send ``request`` as a JSON body and return the decoded JSON reply.

        Args:
            method: HTTP verb, e.g. ``"GET"`` or ``"POST"``.
            url: Absolute URL to call.
            request: JSON-serializable payload (``None`` for no body).

        Raises:
            ValueError: if the response status is not OK.
        """
        auth_headers = {"Authorization": f"Bearer {self.api_token}"}
        resp = requests.request(
            method=method, url=url, headers=auth_headers, json=request
        )
        # TODO: error handling and automatic retries
        if resp.ok:
            return resp.json()
        raise ValueError(f"HTTP {resp.status_code} error: {resp.text}")

    def _get(self, url: str) -> Any:
        """Issue a body-less GET to ``url``."""
        return self.request("GET", url, None)

    def _post(self, url: str, request: Any) -> Any:
        """POST ``request`` as JSON to ``url``."""
        return self.request("POST", url, request)

    @abstractmethod
    def post(
        self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
    ) -> Any:
        """Send ``request`` to the backing service; implemented by subclasses."""
        ...

    @property
    def llm(self) -> bool:
        """Whether generic LLM sampling params should be added to requests.

        The base client never opts in.
        """
        return False
 | |
| 
 | |
| 
 | |
| def _transform_completions(response: Dict[str, Any]) -> str:
 | |
|     return response["choices"][0]["text"]
 | |
| 
 | |
| 
 | |
| def _transform_llama2_chat(response: Dict[str, Any]) -> str:
 | |
|     return response["candidates"][0]["text"]
 | |
| 
 | |
| 
 | |
| def _transform_chat(response: Dict[str, Any]) -> str:
 | |
|     return response["choices"][0]["message"]["content"]
 | |
| 
 | |
| 
 | |
class _DatabricksServingEndpointClient(_DatabricksClientBase):
    """An API client that talks to a Databricks serving endpoint."""

    host: str
    endpoint_name: str
    databricks_uri: str
    client: Any = None
    external_or_foundation: bool = False
    task: Optional[str] = None

    def __init__(self, **data: Any):
        """Create the MLflow deployments client and resolve endpoint metadata.

        Looks up the endpoint to decide whether it is an external or
        foundation-model endpoint and, when ``task`` was not given, takes the
        task from the endpoint description.

        Raises:
            ImportError: if ``mlflow`` is not installed.
        """
        super().__init__(**data)

        try:
            from mlflow.deployments import get_deploy_client

            self.client = get_deploy_client(self.databricks_uri)
        except ImportError as e:
            raise ImportError(
                "Failed to create the client. "
                "Please install mlflow with `pip install mlflow`."
            ) from e

        endpoint = self.client.get_endpoint(self.endpoint_name)
        self.external_or_foundation = endpoint.get("endpoint_type", "").lower() in (
            "external_model",
            "foundation_model_api",
        )
        if self.task is None:
            self.task = endpoint.get("task")

    @property
    def llm(self) -> bool:
        """Whether the endpoint task accepts the generic LLM sampling params."""
        return self.task in ("llm/v1/chat", "llm/v1/completions", "llama2/chat")

    @root_validator(pre=True)
    def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Derive ``api_url`` from ``host`` and ``endpoint_name`` unless given."""
        if "api_url" not in values:
            host = values["host"]
            endpoint_name = values["endpoint_name"]
            api_url = f"https://{host}/serving-endpoints/{endpoint_name}/invocations"
            values["api_url"] = api_url
        return values

    def post(
        self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
    ) -> Any:
        """Query the endpoint and return the (optionally transformed) result.

        Args:
            request: JSON-compatible request payload.
            transform_output_fn: Optional callable applied to the raw
                prediction. When provided it always takes precedence over the
                built-in task-specific transforms, on every endpoint type.

        Returns:
            The transformed output, or the raw prediction when no transform
            applies.
        """
        if self.external_or_foundation:
            resp = self.client.predict(endpoint=self.endpoint_name, inputs=request)
            if transform_output_fn:
                return transform_output_fn(resp)

            if self.task == "llm/v1/chat":
                return _transform_chat(resp)
            elif self.task == "llm/v1/completions":
                return _transform_completions(resp)

            return resp

        # See https://docs.databricks.com/machine-learning/model-serving/score-model-serving-endpoints.html
        wrapped_request = {"dataframe_records": [request]}
        response = self.client.predict(
            endpoint=self.endpoint_name, inputs=wrapped_request
        )
        preds = response["predictions"]
        # For a single-record query, the result is not a list.
        pred = preds[0] if isinstance(preds, list) else preds
        # A user-supplied transform wins over the built-in one, mirroring the
        # external/foundation branch above; previously it was silently ignored
        # for "llama2/chat" endpoints.
        if transform_output_fn:
            return transform_output_fn(pred)
        if self.task == "llama2/chat":
            return _transform_llama2_chat(pred)
        return pred
 | |
| 
 | |
| 
 | |
class _DatabricksClusterDriverProxyClient(_DatabricksClientBase):
    """An API client that talks to a Databricks cluster driver proxy app."""

    host: str
    cluster_id: str
    cluster_driver_port: str

    @root_validator(pre=True)
    def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build ``api_url`` from host, cluster id and driver port unless given."""
        if "api_url" in values:
            return values
        values["api_url"] = (
            f"https://{values['host']}/driver-proxy-api/o/0/"
            f"{values['cluster_id']}/{values['cluster_driver_port']}"
        )
        return values

    def post(
        self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
    ) -> Any:
        """POST ``request`` to the driver proxy app and return its reply.

        The decoded JSON reply is passed through ``transform_output_fn`` when
        one is supplied.
        """
        raw = self._post(self.api_url, request)
        if not transform_output_fn:
            return raw
        return transform_output_fn(raw)
 | |
| 
 | |
| 
 | |
| def get_repl_context() -> Any:
 | |
|     """Gets the notebook REPL context if running inside a Databricks notebook.
 | |
|     Returns None otherwise.
 | |
|     """
 | |
|     try:
 | |
|         from dbruntime.databricks_repl_context import get_context
 | |
| 
 | |
|         return get_context()
 | |
|     except ImportError:
 | |
|         raise ImportError(
 | |
|             "Cannot access dbruntime, not running inside a Databricks notebook."
 | |
|         )
 | |
| 
 | |
| 
 | |
def get_default_host() -> str:
    """Get the default Databricks workspace hostname.

    The value comes from the ``DATABRICKS_HOST`` environment variable, or else
    from the notebook REPL context's ``browserHostName``. A leading
    ``https://``/``http://`` scheme and any trailing ``/`` are removed.

    Raises:
        ValueError: if the hostname cannot be automatically determined.
    """
    host = os.getenv("DATABRICKS_HOST")
    if not host:
        try:
            host = get_repl_context().browserHostName
            if not host:
                raise ValueError("context doesn't contain browserHostName.")
        except Exception as e:
            raise ValueError(
                "host was not set and cannot be automatically inferred. Set "
                f"environment variable 'DATABRICKS_HOST'. Received error: {e}"
            ) from e
    # TODO: support Databricks CLI profile
    # Remove the scheme as a *prefix*. ``str.lstrip("https://")`` treats its
    # argument as a set of characters and would also eat leading host letters
    # (e.g. "https://sample.cloud.databricks.com" -> "ample.cloud...").
    for scheme in ("https://", "http://"):
        if host.startswith(scheme):
            host = host[len(scheme) :]
            break
    return host.rstrip("/")
 | |
| 
 | |
| 
 | |
def get_default_api_token() -> str:
    """Get the default Databricks personal access token.

    The value comes from the ``DATABRICKS_TOKEN`` environment variable, or else
    from the notebook REPL context's ``apiToken``.

    Raises:
        ValueError: if the token cannot be automatically determined. The
            underlying error is chained as ``__cause__``.
    """
    if api_token := os.getenv("DATABRICKS_TOKEN"):
        return api_token
    try:
        api_token = get_repl_context().apiToken
        if not api_token:
            raise ValueError("context doesn't contain apiToken.")
    except Exception as e:
        raise ValueError(
            "api_token was not set and cannot be automatically inferred. Set "
            f"environment variable 'DATABRICKS_TOKEN'. Received error: {e}"
        ) from e
    # TODO: support Databricks CLI profile
    return api_token
 | |
| 
 | |
| 
 | |
class Databricks(LLM):
    """Databricks serving endpoint or a cluster driver proxy app for LLM.

    It supports two endpoint types:

    * **Serving endpoint** (recommended for both production and development).
      We assume that an LLM was deployed to a serving endpoint.
      To wrap it as an LLM you must have "Can Query" permission to the endpoint.
      Set ``endpoint_name`` accordingly and do not set ``cluster_id`` and
      ``cluster_driver_port``.

      If the underlying model is a model registered by MLflow, the expected model
      signature is:

      * inputs::

          [{"name": "prompt", "type": "string"},
           {"name": "stop", "type": "list[string]"}]

      * outputs: ``[{"type": "string"}]``

      If the underlying model is an external or foundation model, the response from the
      endpoint is automatically transformed to the expected format unless
      ``transform_output_fn`` is provided.

    * **Cluster driver proxy app** (recommended for interactive development).
      One can load an LLM on a Databricks interactive cluster and start a local HTTP
      server on the driver node to serve the model at ``/`` using HTTP POST method
      with JSON input/output.
      Please use a port number between ``[3000, 8000]`` and let the server listen to
      the driver IP address or simply ``0.0.0.0`` instead of localhost only.
      To wrap it as an LLM you must have "Can Attach To" permission to the cluster.
      Set ``cluster_id`` and ``cluster_driver_port`` and do not set ``endpoint_name``.
      The expected server schema (using JSON schema) is:

      * inputs::

          {"type": "object",
           "properties": {
              "prompt": {"type": "string"},
              "stop": {"type": "array", "items": {"type": "string"}}},
           "required": ["prompt"]}

      * outputs: ``{"type": "string"}``

    If the endpoint model signature is different or you want to set extra params,
    you can use ``transform_input_fn`` and ``transform_output_fn`` to apply necessary
    transformations before and after the query.

    Note:
        ``transform_input_fn`` and ``transform_output_fn`` are not included in the
        saved parameters, so an instance re-created from a saved config has both
        set to ``None`` and they must be supplied again.
    """

    host: str = Field(default_factory=get_default_host)
    """Databricks workspace hostname.
    If not provided, the default value is determined by

    * the ``DATABRICKS_HOST`` environment variable if present, or
    * the hostname of the current Databricks workspace if running inside
      a Databricks notebook attached to an interactive cluster in "single user"
      or "no isolation shared" mode.
    """

    api_token: str = Field(default_factory=get_default_api_token)
    """Databricks personal access token.
    If not provided, the default value is determined by

    * the ``DATABRICKS_TOKEN`` environment variable if present, or
    * an automatically generated temporary token if running inside a Databricks
      notebook attached to an interactive cluster in "single user" or
      "no isolation shared" mode.
    """

    endpoint_name: Optional[str] = None
    """Name of the model serving endpoint.
    You must specify the endpoint name to connect to a model serving endpoint.
    You must not set both ``endpoint_name`` and ``cluster_id``.
    """

    cluster_id: Optional[str] = None
    """ID of the cluster if connecting to a cluster driver proxy app.
    If neither ``endpoint_name`` nor ``cluster_id`` is provided and the code runs
    inside a Databricks notebook attached to an interactive cluster in "single user"
    or "no isolation shared" mode, the current cluster ID is used as default.
    You must not set both ``endpoint_name`` and ``cluster_id``.
    """

    cluster_driver_port: Optional[str] = None
    """The port number used by the HTTP server running on the cluster driver node.
    The server should listen on the driver IP address or simply ``0.0.0.0`` to connect.
    We recommend the server using a port number between ``[3000, 8000]``.
    """

    model_kwargs: Optional[Dict[str, Any]] = None
    """
    Deprecated. Please use ``extra_params`` instead. Extra parameters to pass to
    the endpoint. Cannot be combined with a non-empty ``extra_params``.
    """

    transform_input_fn: Optional[Callable] = None
    """A function that transforms ``{prompt, stop, **kwargs}`` into a JSON-compatible
    request object that the endpoint accepts.
    For example, you can apply a prompt template to the input prompt.
    """

    transform_output_fn: Optional[Callable[..., str]] = None
    """A function that transforms the output from the endpoint to the generated text.
    """

    databricks_uri: str = "databricks"
    """The databricks URI. Only used when using a serving endpoint."""

    temperature: float = 0.0
    """The sampling temperature."""
    n: int = 1
    """The number of completion choices to generate."""
    stop: Optional[List[str]] = None
    """The stop sequence."""
    max_tokens: Optional[int] = None
    """The maximum number of tokens to generate."""
    extra_params: Dict[str, Any] = Field(default_factory=dict)
    """Any extra parameters to pass to the endpoint."""
    task: Optional[str] = None
    """The task of the endpoint. Only used when using a serving endpoint.
    If not provided, the task is automatically inferred from the endpoint.
    """

    _client: _DatabricksClientBase = PrivateAttr()

    class Config:
        extra = Extra.forbid
        underscore_attrs_are_private = True

    @property
    def _llm_params(self) -> Dict[str, Any]:
        """Generic sampling params sent to endpoints whose client reports ``llm``."""
        params: Dict[str, Any] = {
            "temperature": self.temperature,
            "n": self.n,
        }
        if self.stop:
            params["stop"] = self.stop
        if self.max_tokens is not None:
            params["max_tokens"] = self.max_tokens
        return params

    @validator("cluster_id", always=True)
    def set_cluster_id(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
        """Validate ``cluster_id`` against ``endpoint_name``, inferring it if unset.

        Raises:
            ValueError: if both are set, or neither is set and the cluster ID
                cannot be read from the notebook context.
        """
        # ``endpoint_name`` is declared earlier, so pydantic has already
        # processed it; ``.get`` avoids a raw KeyError if it failed validation.
        endpoint_name = values.get("endpoint_name")
        if v and endpoint_name:
            raise ValueError("Cannot set both endpoint_name and cluster_id.")
        elif endpoint_name:
            return None
        elif v:
            return v
        else:
            try:
                if v := get_repl_context().clusterId:
                    return v
                raise ValueError("Context doesn't contain clusterId.")
            except Exception as e:
                raise ValueError(
                    "Neither endpoint_name nor cluster_id was set. "
                    "And the cluster_id cannot be automatically determined. Received"
                    f" error: {e}"
                ) from e

    @validator("cluster_driver_port", always=True)
    def set_cluster_driver_port(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
        """Validate ``cluster_driver_port``; required unless ``endpoint_name`` is set.

        Raises:
            ValueError: if set together with ``endpoint_name``, missing when
                connecting to a cluster, or not a positive integer.
        """
        endpoint_name = values.get("endpoint_name")
        if v and endpoint_name:
            raise ValueError("Cannot set both endpoint_name and cluster_driver_port.")
        elif endpoint_name:
            return None
        elif v is None:
            raise ValueError(
                "Must set cluster_driver_port to connect to a cluster driver."
            )
        elif int(v) <= 0:
            raise ValueError(f"Invalid cluster_driver_port: {v}")
        else:
            return v

    @validator("model_kwargs", always=True)
    def set_model_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """Reject ``model_kwargs`` that would override ``prompt`` or ``stop``.

        Raises:
            ValueError: if ``model_kwargs`` contains ``prompt`` or ``stop``.
                (Explicit raises instead of ``assert``, which ``-O`` strips.)
        """
        if v:
            if "prompt" in v:
                raise ValueError("model_kwargs must not contain key 'prompt'")
            if "stop" in v:
                raise ValueError("model_kwargs must not contain key 'stop'")
        return v

    def __init__(self, **data: Any):
        """Validate fields and create the underlying endpoint or cluster client.

        Raises:
            ValueError: if both ``model_kwargs`` and a non-empty ``extra_params``
                are given, or if neither an endpoint nor a cluster is specified.
        """
        super().__init__(**data)
        # ``extra_params`` defaults to an empty dict (never None), so only a
        # non-empty value conflicts with the deprecated ``model_kwargs``.
        # Testing ``is not None`` here made ``model_kwargs`` unusable.
        if self.model_kwargs is not None and self.extra_params:
            raise ValueError("Cannot set both model_kwargs and extra_params.")
        elif self.model_kwargs is not None:
            warnings.warn(
                "model_kwargs is deprecated. Please use extra_params instead.",
                DeprecationWarning,
            )
        if self.endpoint_name:
            self._client = _DatabricksServingEndpointClient(
                host=self.host,
                api_token=self.api_token,
                endpoint_name=self.endpoint_name,
                databricks_uri=self.databricks_uri,
                task=self.task,
            )
        elif self.cluster_id and self.cluster_driver_port:
            self._client = _DatabricksClusterDriverProxyClient(
                host=self.host,
                api_token=self.api_token,
                cluster_id=self.cluster_id,
                cluster_driver_port=self.cluster_driver_port,
            )
        else:
            raise ValueError(
                "Must specify either endpoint_name or cluster_id/cluster_driver_port."
            )

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Return default params (the serializable configuration)."""
        return {
            "host": self.host,
            # "api_token": self.api_token,  # Never save the token
            "endpoint_name": self.endpoint_name,
            "cluster_id": self.cluster_id,
            "cluster_driver_port": self.cluster_driver_port,
            "databricks_uri": self.databricks_uri,
            "model_kwargs": self.model_kwargs,
            "temperature": self.temperature,
            "n": self.n,
            "stop": self.stop,
            "max_tokens": self.max_tokens,
            "extra_params": self.extra_params,
            "task": self.task,
            # TODO: Support saving transform_input_fn and transform_output_fn
            # "transform_input_fn": self.transform_input_fn,
            # "transform_output_fn": self.transform_output_fn,
        }

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return self._default_params

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "databricks"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Queries the LLM endpoint with the given prompt and stop sequence."""

        # TODO: support callbacks

        request: Dict[str, Any] = {"prompt": prompt}
        if self._client.llm:
            request.update(self._llm_params)
        request.update(self.model_kwargs or self.extra_params)
        request.update(kwargs)
        if stop:
            request["stop"] = stop

        if self.transform_input_fn:
            request = self.transform_input_fn(**request)

        return self._client.post(request, transform_output_fn=self.transform_output_fn)
 |