multiple: implement ls_params (#22621)

implement ls_params for ai21, fireworks, groq.
This commit is contained in:
ccurme
2024-06-06 12:51:37 -04:00
committed by GitHub
parent f26ab93df8
commit b57aa89f34
10 changed files with 70 additions and 61 deletions

View File

@@ -31,6 +31,7 @@ from langchain_core.callbacks import (
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
LangSmithParams,
agenerate_from_stream,
generate_from_stream,
)
@@ -363,6 +364,23 @@ class ChatFireworks(BaseChatModel):
params["max_tokens"] = self.max_tokens
return params
def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> LangSmithParams:
"""Get standard params for tracing."""
params = self._get_invocation_params(stop=stop, **kwargs)
ls_params = LangSmithParams(
ls_provider="fireworks",
ls_model_name=self.model_name,
ls_model_type="chat",
ls_temperature=params.get("temperature", self.temperature),
)
if ls_max_tokens := params.get("max_tokens", self.max_tokens):
ls_params["ls_max_tokens"] = ls_max_tokens
if ls_stop := stop or params.get("stop", None):
ls_params["ls_stop"] = ls_stop
return ls_params
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
overall_token_usage: dict = {}
system_fingerprint = None