community: sambastudio chat model integration minor fix (#27238)

**Description:** sambastudio chat model integration minor fix
 fix default params
 fix usage metadata when streaming
This commit is contained in:
Jorge Piedrahita Ortiz 2024-10-15 12:24:36 -05:00 committed by GitHub
parent fead4749b9
commit 12fea5b868
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -174,10 +174,10 @@ class ChatSambaNovaCloud(BaseChatModel):
temperature: float = Field(default=0.7) temperature: float = Field(default=0.7)
"""model temperature""" """model temperature"""
top_p: Optional[float] = Field() top_p: Optional[float] = Field(default=None)
"""model top p""" """model top p"""
top_k: Optional[int] = Field() top_k: Optional[int] = Field(default=None)
"""model top k""" """model top k"""
stream_options: dict = Field(default={"include_usage": True}) stream_options: dict = Field(default={"include_usage": True})
@ -593,7 +593,7 @@ class ChatSambaStudio(BaseChatModel):
streaming_url: str = Field(default="", exclude=True) streaming_url: str = Field(default="", exclude=True)
"""SambaStudio streaming Url""" """SambaStudio streaming Url"""
model: Optional[str] = Field() model: Optional[str] = Field(default=None)
"""The name of the model or expert to use (for CoE endpoints)""" """The name of the model or expert to use (for CoE endpoints)"""
streaming: bool = Field(default=False) streaming: bool = Field(default=False)
@ -605,16 +605,16 @@ class ChatSambaStudio(BaseChatModel):
temperature: Optional[float] = Field(default=0.7) temperature: Optional[float] = Field(default=0.7)
"""model temperature""" """model temperature"""
top_p: Optional[float] = Field() top_p: Optional[float] = Field(default=None)
"""model top p""" """model top p"""
top_k: Optional[int] = Field() top_k: Optional[int] = Field(default=None)
"""model top k""" """model top k"""
do_sample: Optional[bool] = Field() do_sample: Optional[bool] = Field(default=None)
"""whether to do sampling""" """whether to do sampling"""
process_prompt: Optional[bool] = Field() process_prompt: Optional[bool] = Field(default=True)
"""whether process prompt (for CoE generic v1 and v2 endpoints)""" """whether process prompt (for CoE generic v1 and v2 endpoints)"""
stream_options: dict = Field(default={"include_usage": True}) stream_options: dict = Field(default={"include_usage": True})
@ -1012,6 +1012,16 @@ class ChatSambaStudio(BaseChatModel):
"system_fingerprint": data["system_fingerprint"], "system_fingerprint": data["system_fingerprint"],
"created": data["created"], "created": data["created"],
} }
if data.get("usage") is not None:
content = ""
id = data["id"]
metadata = {
"finish_reason": finish_reason,
"usage": data.get("usage"),
"model_name": data["model"],
"system_fingerprint": data["system_fingerprint"],
"created": data["created"],
}
yield AIMessageChunk( yield AIMessageChunk(
content=content, content=content,
id=id, id=id,