explicitly check openllm return type (#10560)

cc @aarnphm
This commit is contained in:
Bagatur 2023-09-13 14:13:15 -07:00 committed by GitHub
parent 85e05fa5d6
commit 49694f6a3f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -265,16 +265,19 @@ class OpenLLM(LLM):
self._identifying_params["model_name"], **copied self._identifying_params["model_name"], **copied
) )
if self._client: if self._client:
o = self._client.query(prompt, **config.model_dump(flatten=True)) res = self._client.query(prompt, **config.model_dump(flatten=True))
if isinstance(o, dict) and "text" in o:
return o["text"]
return o
else: else:
assert self._runner is not None assert self._runner is not None
o = self._runner(prompt, **config.model_dump(flatten=True)) res = self._runner(prompt, **config.model_dump(flatten=True))
if isinstance(o, dict) and "text" in o: if isinstance(res, dict) and "text" in res:
return o["text"] return res["text"]
return o elif isinstance(res, str):
return res
else:
raise ValueError(
"Expected result to be a dict with key 'text' or a string. "
f"Received {res}"
)
async def _acall( async def _acall(
self, self,
@ -297,12 +300,9 @@ class OpenLLM(LLM):
self._identifying_params["model_name"], **copied self._identifying_params["model_name"], **copied
) )
if self._client: if self._client:
o = await self._client.acall( res = await self._client.acall(
"generate", prompt, **config.model_dump(flatten=True) "generate", prompt, **config.model_dump(flatten=True)
) )
if isinstance(o, dict) and "text" in o:
return o["text"]
return o
else: else:
assert self._runner is not None assert self._runner is not None
( (
@ -313,9 +313,16 @@ class OpenLLM(LLM):
generated_result = await self._runner.generate.async_run( generated_result = await self._runner.generate.async_run(
prompt, **generate_kwargs prompt, **generate_kwargs
) )
o = self._runner.llm.postprocess_generate( res = self._runner.llm.postprocess_generate(
prompt, generated_result, **postprocess_kwargs prompt, generated_result, **postprocess_kwargs
) )
if isinstance(o, dict) and "text" in o:
return o["text"] if isinstance(res, dict) and "text" in res:
return o return res["text"]
elif isinstance(res, str):
return res
else:
raise ValueError(
"Expected result to be a dict with key 'text' or a string. "
f"Received {res}"
)