Track ChatFireworks time to first_token (#11672)

Erick Friis 2023-10-11 13:37:03 -07:00 committed by GitHub
parent 2c1e735403
commit 28ee6a7c12
2 changed files with 16 additions and 2 deletions


@@ -96,7 +96,10 @@ class ChatFireworks(BaseChatModel):
         try:
             import fireworks.client
         except ImportError as e:
-            raise ImportError("") from e
+            raise ImportError(
+                "Could not import fireworks-ai python package. "
+                "Please install it with `pip install fireworks-ai`."
+            ) from e
         fireworks_api_key = get_from_dict_or_env(
             values, "fireworks_api_key", "FIREWORKS_API_KEY"
         )
@@ -194,6 +197,8 @@ class ChatFireworks(BaseChatModel):
             )
             default_chunk_class = chunk.__class__
             yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
+            if run_manager:
+                run_manager.on_llm_new_token(chunk.content, chunk=chunk)
 
     async def _astream(
         self,
@@ -221,6 +226,8 @@ class ChatFireworks(BaseChatModel):
             )
             default_chunk_class = chunk.__class__
             yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
+            if run_manager:
+                await run_manager.on_llm_new_token(token=chunk.content, chunk=chunk)
 
     def completion_with_retry(
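As a hedged illustration of how the new run_manager.on_llm_new_token calls can be consumed, here is a minimal callback-handler sketch that records time to first token. The class name FirstTokenTimer and its timing logic are illustrative assumptions, not part of this commit; it relies only on langchain's BaseCallbackHandler hooks.

import time
from typing import Any, Dict, List, Optional

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import BaseMessage


class FirstTokenTimer(BaseCallbackHandler):
    """Illustrative handler (not part of this commit): records the delay between
    the model call starting and the first streamed token arriving."""

    def __init__(self) -> None:
        self.start: Optional[float] = None
        self.time_to_first_token: Optional[float] = None

    def _mark_start(self) -> None:
        self.start = time.perf_counter()
        self.time_to_first_token = None

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        # Completion-style models such as Fireworks report their start here.
        self._mark_start()

    def on_chat_model_start(
        self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], **kwargs: Any
    ) -> None:
        # Chat models such as ChatFireworks report their start here instead.
        self._mark_start()

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # Fired once per streamed chunk by the run_manager calls added in this diff;
        # the first invocation marks the first token.
        if self.start is not None and self.time_to_first_token is None:
            self.time_to_first_token = time.perf_counter() - self.start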


@@ -45,7 +45,10 @@ class Fireworks(LLM):
         try:
             import fireworks.client
         except ImportError as e:
-            raise ImportError("") from e
+            raise ImportError(
+                "Could not import fireworks-ai python package. "
+                "Please install it with `pip install fireworks-ai`."
+            ) from e
         fireworks_api_key = get_from_dict_or_env(
             values, "fireworks_api_key", "FIREWORKS_API_KEY"
         )
@@ -113,6 +116,8 @@ class Fireworks(LLM):
         ):
             chunk = _stream_response_to_generation_chunk(stream_resp)
             yield chunk
+            if run_manager:
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
 
     async def _astream(
         self,
@@ -132,6 +137,8 @@ class Fireworks(LLM):
         ):
             chunk = _stream_response_to_generation_chunk(stream_resp)
             yield chunk
+            if run_manager:
+                await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
 
     def stream(
         self,
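A usage sketch under the same caveats: the callbacks= constructor argument, reading FIREWORKS_API_KEY from the environment, and the import path reflect the langchain conventions of this era and are assumptions; the prompt is illustrative.

from langchain.chat_models import ChatFireworks

timer = FirstTokenTimer()
chat = ChatFireworks(callbacks=[timer])  # FIREWORKS_API_KEY is read from the environment

# .stream() drives _stream(), which after this change calls
# run_manager.on_llm_new_token for every chunk, so the handler above
# sees the first token as soon as it arrives.
for _chunk in chat.stream("Write one sentence about fireworks."):
    pass

print(f"time to first token: {timer.time_to_first_token:.3f}s")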