mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-18 18:53:10 +00:00
Add additional VertexAI Params (#5837)
## Changes - Added the `stop` param to the `_VertexAICommon` class so it can be set at llm initialization ## Example Usage ```python VertexAI( # ... temperature=0.15, max_output_tokens=128, top_p=1, top_k=40, stop=["\n```"], ) ``` ## Possible Reviewers - @hwchase17 - @agola11
This commit is contained in:
parent
76fcd96dae
commit
9f4b720a63
@ -29,6 +29,8 @@ class _VertexAICommon(BaseModel):
|
||||
top_k: int = 40
|
||||
"How the model selects tokens for output, the next token is selected from "
|
||||
"among the top-k most probable tokens."
|
||||
stop: Optional[List[str]] = None
|
||||
"Optional list of stop words to use when generating."
|
||||
project: Optional[str] = None
|
||||
"The default GCP project to use when making Vertex API calls."
|
||||
location: str = "us-central1"
|
||||
@ -48,11 +50,13 @@ class _VertexAICommon(BaseModel):
|
||||
}
|
||||
return {**base_params}
|
||||
|
||||
def _predict(self, prompt: str, stop: Optional[List[str]]) -> str:
|
||||
def _predict(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
res = self.client.predict(prompt, **self._default_params)
|
||||
return self._enforce_stop_words(res.text, stop)
|
||||
|
||||
def _enforce_stop_words(self, text: str, stop: Optional[List[str]]) -> str:
|
||||
def _enforce_stop_words(self, text: str, stop: Optional[List[str]] = None) -> str:
|
||||
if stop is None and self.stop is not None:
|
||||
stop = self.stop
|
||||
if stop:
|
||||
return enforce_stop_tokens(text, stop)
|
||||
return text
|
||||
|
Loading…
Reference in New Issue
Block a user