Compare commits

...

3 Commits

Author               SHA1        Message                                  Date
William Fu-Hinthorn  6ae65b94a5  lint                                     2023-11-02 19:19:15 -07:00
William Fu-Hinthorn  331915f9f0  Merge branch 'master' into wfh/ossinvoc  2023-11-02 19:17:34 -07:00
William Fu-Hinthorn  ab5bede903  tmp                                      2023-11-01 07:53:49 -07:00
2 changed files with 23 additions and 5 deletions

View File

@@ -216,6 +216,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 "Must be a PromptValue, str, or list of BaseMessages."
             )

+    @property
+    def _invocation_params(self) -> Dict[str, Any]:
+        """Get the parameters used to invoke the model."""
+        return self.dict()
+
     def invoke(
         self,
         input: LanguageModelInput,
@@ -366,7 +371,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         else:
             prompt = self._convert_input(input).to_string()
             config = config or {}
-            params = self.dict()
+            params = self._invocation_params
             params["stop"] = stop
             params = {**params, **kwargs}
             options = {"stop": stop}
@@ -417,7 +422,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         else:
             prompt = self._convert_input(input).to_string()
             config = config or {}
-            params = self.dict()
+            params = self._invocation_params
             params["stop"] = stop
             params = {**params, **kwargs}
             options = {"stop": stop}
@@ -623,7 +628,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
        ] * len(prompts)
        run_name_list = [cast(Optional[str], run_name)] * len(prompts)
-       params = self.dict()
+       params = self._invocation_params
        params["stop"] = stop
        options = {"stop": stop}
        (
@@ -787,7 +792,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
        ] * len(prompts)
        run_name_list = [cast(Optional[str], run_name)] * len(prompts)
-       params = self.dict()
+       params = self._invocation_params
        params["stop"] = stop
        options = {"stop": stop}
        (
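The hunks above give BaseLLM a single _invocation_params property (defaulting to self.dict()) and switch the invoke, stream, generate, and agenerate paths to read it instead of calling self.dict() directly, so subclasses can enrich what gets reported to callbacks and tracing. Below is a minimal sketch of how a custom subclass could use that hook once this change lands; the EchoLLM class, its model_version field, and the echo behaviour are illustrative and not part of this diff.

from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM


class EchoLLM(LLM):
    """Toy LLM that returns the prompt unchanged (illustrative only)."""

    model_version: str = "echo-v1"

    @property
    def _llm_type(self) -> str:
        return "echo"

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        # Start from the base serialization (self.dict(), per the property
        # added in this diff) and layer on extra metadata worth surfacing
        # to callbacks and tracing.
        params = super()._invocation_params
        params["model_version"] = self.model_version
        return params

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        return prompt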

View File

@@ -2,7 +2,7 @@ from __future__ import annotations

 import importlib.util
 import logging
-from typing import Any, List, Mapping, Optional
+from typing import Any, Dict, List, Mapping, Optional

 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import BaseLLM
@@ -185,6 +185,19 @@ class HuggingFacePipeline(BaseLLM):
     def _llm_type(self) -> str:
         return "huggingface_pipeline"

+    @property
+    def _invocation_params(self) -> Dict[str, Any]:
+        params = super()._invocation_params
+        try:
+            params["model_config"] = self.pipeline.model.config
+        except NameError as e:
+            logger.warning(
+                "Unable to get model config in invocation params."
+                f" Received error:\n\n{e}"
+            )
+            pass
+        return params
+
     def _generate(
         self,
         prompts: List[str],
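With the override above, the underlying transformers model config rides along in the params reported to callbacks when the LLM is invoked. A hedged usage sketch, assuming transformers is installed and a small checkpoint such as "gpt2" can be loaded; from_model_id and invoke are existing APIs, only the inspection below is illustrative.

from langchain.llms import HuggingFacePipeline

# Assumes transformers is installed and the "gpt2" checkpoint is available.
llm = HuggingFacePipeline.from_model_id(
    model_id="gpt2",
    task="text-generation",
    pipeline_kwargs={"max_new_tokens": 16},
)

params = llm._invocation_params
# With this diff, params now carries the underlying model config in
# addition to the fields coming from self.dict().
print(params.get("model_config"))
print(llm.invoke("Hello", stop=["\n"]))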