mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-08 14:05:16 +00:00
Feat bedrock cohere support (#11230)
**Description:** Added support for the Cohere command model via Bedrock. With this change it is now possible to use the `cohere.command-text-v14` model via the Bedrock API. About streaming: the Cohere model outputs 2 additional chunks at the end of the text being generated via streaming: a chunk containing the text `<EOS_TOKEN>`, and a chunk indicating the end of the stream. In this implementation I chose to ignore both chunks. An alternative solution could be to replace `<EOS_TOKEN>` with `\n`. Tests: manually tested that the new model works with both `llm.generate()` and `llm.stream()`. Tested with the `temperature`, `p` and `stop` parameters. **Issue:** #11181 **Dependencies:** No new dependencies **Tag maintainer:** @baskaryan **Twitter handle:** mangelino --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
37f2f71156
commit
2f83350eac
@ -65,6 +65,7 @@ class LLMInputOutputAdapter:
|
|||||||
provider_to_output_key_map = {
|
provider_to_output_key_map = {
|
||||||
"anthropic": "completion",
|
"anthropic": "completion",
|
||||||
"amazon": "outputText",
|
"amazon": "outputText",
|
||||||
|
"cohere": "text",
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -74,7 +75,7 @@ class LLMInputOutputAdapter:
|
|||||||
input_body = {**model_kwargs}
|
input_body = {**model_kwargs}
|
||||||
if provider == "anthropic":
|
if provider == "anthropic":
|
||||||
input_body["prompt"] = _human_assistant_format(prompt)
|
input_body["prompt"] = _human_assistant_format(prompt)
|
||||||
elif provider == "ai21":
|
elif provider == "ai21" or provider == "cohere":
|
||||||
input_body["prompt"] = prompt
|
input_body["prompt"] = prompt
|
||||||
elif provider == "amazon":
|
elif provider == "amazon":
|
||||||
input_body = dict()
|
input_body = dict()
|
||||||
@ -98,6 +99,8 @@ class LLMInputOutputAdapter:
|
|||||||
|
|
||||||
if provider == "ai21":
|
if provider == "ai21":
|
||||||
return response_body.get("completions")[0].get("data").get("text")
|
return response_body.get("completions")[0].get("data").get("text")
|
||||||
|
elif provider == "cohere":
|
||||||
|
return response_body.get("generations")[0].get("text")
|
||||||
else:
|
else:
|
||||||
return response_body.get("results")[0].get("outputText")
|
return response_body.get("results")[0].get("outputText")
|
||||||
|
|
||||||
@ -119,6 +122,12 @@ class LLMInputOutputAdapter:
|
|||||||
chunk = event.get("chunk")
|
chunk = event.get("chunk")
|
||||||
if chunk:
|
if chunk:
|
||||||
chunk_obj = json.loads(chunk.get("bytes").decode())
|
chunk_obj = json.loads(chunk.get("bytes").decode())
|
||||||
|
if provider == "cohere" and (
|
||||||
|
chunk_obj["is_finished"]
|
||||||
|
or chunk_obj[cls.provider_to_output_key_map[provider]]
|
||||||
|
== "<EOS_TOKEN>"
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
# chunk obj format varies with provider
|
# chunk obj format varies with provider
|
||||||
yield GenerationChunk(
|
yield GenerationChunk(
|
||||||
@ -159,6 +168,7 @@ class BedrockBase(BaseModel, ABC):
|
|||||||
"anthropic": "stop_sequences",
|
"anthropic": "stop_sequences",
|
||||||
"amazon": "stopSequences",
|
"amazon": "stopSequences",
|
||||||
"ai21": "stop_sequences",
|
"ai21": "stop_sequences",
|
||||||
|
"cohere": "stop_sequences",
|
||||||
}
|
}
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
@ -259,9 +269,10 @@ class BedrockBase(BaseModel, ABC):
|
|||||||
|
|
||||||
# stop sequence from _generate() overrides
|
# stop sequence from _generate() overrides
|
||||||
# stop sequences in the class attribute
|
# stop sequences in the class attribute
|
||||||
_model_kwargs[
|
_model_kwargs[self.provider_stop_sequence_key_name_map.get(provider)] = stop
|
||||||
self.provider_stop_sequence_key_name_map.get(provider),
|
|
||||||
] = stop
|
if provider == "cohere":
|
||||||
|
_model_kwargs["stream"] = True
|
||||||
|
|
||||||
params = {**_model_kwargs, **kwargs}
|
params = {**_model_kwargs, **kwargs}
|
||||||
input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
|
input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
|
||||||
|
Loading…
Reference in New Issue
Block a user