mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-16 15:04:13 +00:00
multiple: implement ls_params (#22621)
implement ls_params for ai21, fireworks, groq.
This commit is contained in:
@@ -31,6 +31,7 @@ from langchain_core.callbacks import (
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
LangSmithParams,
|
||||
agenerate_from_stream,
|
||||
generate_from_stream,
|
||||
)
|
||||
@@ -363,6 +364,23 @@ class ChatFireworks(BaseChatModel):
|
||||
params["max_tokens"] = self.max_tokens
|
||||
return params
|
||||
|
||||
def _get_ls_params(
|
||||
self, stop: Optional[List[str]] = None, **kwargs: Any
|
||||
) -> LangSmithParams:
|
||||
"""Get standard params for tracing."""
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
ls_params = LangSmithParams(
|
||||
ls_provider="fireworks",
|
||||
ls_model_name=self.model_name,
|
||||
ls_model_type="chat",
|
||||
ls_temperature=params.get("temperature", self.temperature),
|
||||
)
|
||||
if ls_max_tokens := params.get("max_tokens", self.max_tokens):
|
||||
ls_params["ls_max_tokens"] = ls_max_tokens
|
||||
if ls_stop := stop or params.get("stop", None):
|
||||
ls_params["ls_stop"] = ls_stop
|
||||
return ls_params
|
||||
|
||||
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
|
||||
overall_token_usage: dict = {}
|
||||
system_fingerprint = None
|
||||
|
Reference in New Issue
Block a user