Mirror of https://github.com/nomic-ai/gpt4all.git, synced 2025-07-04 19:17:59 +00:00
embed4all: small fixes related to nomic client local embeddings (#2213)

* actually submit larger batches with increased n_ctx
* fix crash when llama_tokenize returns no tokens

Signed-off-by: Jared Van Bortel <jared@nomic.ai>

parent 1e4c62027b
commit 459289b94c
gpt4all-backend/llamamodel.cpp

@@ -325,7 +325,7 @@ bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
     bool isEmbedding = is_embedding_arch(llama_model_arch(d_ptr->model));
     const int n_ctx_train = llama_n_ctx_train(d_ptr->model);
     if (isEmbedding) {
-        d_ptr->ctx_params.n_batch = n_ctx_train;
+        d_ptr->ctx_params.n_batch = n_ctx;
     } else {
         if (n_ctx > n_ctx_train) {
             std::cerr << "warning: model was trained on only " << n_ctx_train << " context tokens ("
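In effect, this first fix lets an embedding model's decode batch follow the caller's requested context size instead of being pinned to the training context, which is what "actually submit larger batches with increased n_ctx" refers to. A minimal standalone sketch of the before/after behavior, with hypothetical helper names (not functions from the gpt4all code):

#include <cassert>

// Before the change: for embedding models the batch size ignored the requested n_ctx.
static int n_batch_before(int /*n_ctx*/, int n_ctx_train) {
    return n_ctx_train;
}

// After the change: the requested n_ctx is honored, so raising n_ctx really does
// translate into larger batches handed to llama_decode.
static int n_batch_after(int n_ctx, int /*n_ctx_train*/) {
    return n_ctx;
}

int main() {
    const int n_ctx_train = 2048; // e.g. a model trained on 2048-token sequences
    const int n_ctx       = 8192; // caller asks for a larger context
    assert(n_batch_before(n_ctx, n_ctx_train) == 2048); // old: capped at training context
    assert(n_batch_after(n_ctx, n_ctx_train)  == 8192); // new: follows the request
    return 0;
}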
@@ -734,7 +734,7 @@ void LLamaModel::embedInternal(
 ) {
     typedef std::vector<LLModel::Token> TokenString;
     static constexpr int32_t atlasMaxLength = 8192;
-    static constexpr int chunkOverlap = 8; // Atlas overlaps n_batch-sized chunks of input by 8 tokens
+    static constexpr int chunkOverlap = 8; // Atlas overlaps chunks of input by 8 tokens
 
     const llama_token bos_token = llama_token_bos(d_ptr->model);
     const llama_token eos_token = llama_token_eos(d_ptr->model);
@@ -751,8 +751,12 @@ void LLamaModel::embedInternal(
 
         tokens.resize(text.length()+4);
         int32_t n_tokens = llama_tokenize(d_ptr->model, text.c_str(), text.length(), tokens.data(), tokens.size(), wantBOS, false);
-        assert(useEOS == (eos_token != -1 && tokens[n_tokens - 1] == eos_token));
-        tokens.resize(n_tokens - useEOS); // erase EOS/SEP
+        if (n_tokens) {
+            assert(useEOS == (eos_token != -1 && tokens[n_tokens - 1] == eos_token));
+            tokens.resize(n_tokens - useEOS); // erase EOS/SEP
+        } else {
+            tokens.clear();
+        }
     };
 
     // tokenize the texts
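The second fix is the crash guard above: if llama_tokenize returns no tokens, the old code read tokens[n_tokens - 1] out of bounds and then called tokens.resize(n_tokens - useEOS), i.e. resize(-1), which wraps to an enormous size_t. A self-contained sketch of the same guard pattern, using a hypothetical stand-in tokenizer rather than the real llama_tokenize:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

using Token = int32_t;

// Hypothetical stand-in for llama_tokenize: writes up to cap tokens and returns
// how many were produced (0 when the input yields no tokens).
static int32_t toy_tokenize(const std::string &text, Token *out, size_t cap) {
    size_t n = std::min(text.size(), cap); // pretend: one token per character
    for (size_t i = 0; i < n; ++i)
        out[i] = static_cast<Token>(text[i]);
    return static_cast<int32_t>(n);
}

static void tokenize(const std::string &text, std::vector<Token> &tokens, bool useEOS) {
    tokens.resize(text.length() + 4);
    int32_t n_tokens = toy_tokenize(text, tokens.data(), tokens.size());
    if (n_tokens) {
        tokens.resize(n_tokens - useEOS); // erase EOS/SEP, as in the patched code
    } else {
        // Without this branch, resize(n_tokens - useEOS) would be resize(size_t(-1)).
        tokens.clear();
    }
}

int main() {
    std::vector<Token> tokens;
    tokenize("", tokens, /*useEOS=*/true); // input that produces no tokens no longer crashes
    assert(tokens.empty());
    return 0;
}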
@@ -786,9 +790,14 @@ void LLamaModel::embedInternal(
         tokenize(prefix + ':', prefixTokens, true);
     }
 
+    // n_ctx_train: max sequence length of model (RoPE scaling not implemented)
+    const uint32_t n_ctx_train = llama_n_ctx_train(d_ptr->model);
+    // n_batch (equals n_ctx): max tokens per call to llama_decode (one or more sequences)
     const uint32_t n_batch = llama_n_batch(d_ptr->ctx);
-    const uint32_t max_len = n_batch - (prefixTokens.size() + useEOS); // minus BOS/CLS and EOS/SEP
-    if (chunkOverlap >= max_len) {
+
+    // effective sequence length minus prefix and SEP token
+    const uint32_t max_len = std::min(n_ctx_train, n_batch) - (prefixTokens.size() + useEOS);
+    if (max_len <= chunkOverlap) {
         throw std::logic_error("max chunk length of " + std::to_string(max_len) + " is smaller than overlap of " +
                                std::to_string(chunkOverlap) + " tokens");
     }
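To put numbers on the new limit: max_len is now bounded by both the training context and the batch size, and the overlap check rejects configurations a chunk could never advance past. The figures below are illustrative assumptions, and the chunk count uses a simplified model of the overlap arithmetic rather than the exact loop in embedInternal:

#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
    // Illustrative numbers only (assumptions, not values from the gpt4all code).
    const uint32_t n_ctx_train   = 2048; // model's training context
    const uint32_t n_batch       = 8192; // requested n_ctx, which now equals n_batch for embeddings
    const uint32_t prefix_tokens = 2;    // e.g. BOS/CLS plus a tokenized "search_document:" prefix
    const uint32_t useEOS        = 1;    // one EOS/SEP token appended per chunk
    const uint32_t chunkOverlap  = 8;    // consecutive chunks share this many tokens

    // Old formula: n_batch - (prefix + EOS) = 8189, which can exceed the training context.
    // New formula clamps to what the model was trained on:
    const uint32_t max_len = std::min(n_ctx_train, n_batch) - (prefix_tokens + useEOS);
    assert(max_len == 2045);
    assert(max_len > chunkOverlap); // otherwise the code now throws std::logic_error

    // With an 8-token overlap, each chunk after the first advances by max_len - chunkOverlap
    // tokens, so a 10,000-token input needs 1 + ceil((10000 - 2045) / 2037) = 5 chunks.
    const uint32_t text_tokens = 10000;
    const uint32_t stride = max_len - chunkOverlap;
    const uint32_t extra  = (text_tokens - max_len + stride - 1) / stride;
    assert(1 + extra == 5);
    return 0;
}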
gpt4all-bindings/python/setup.py

@@ -68,7 +68,7 @@ def get_long_description():
 
 setup(
     name=package_name,
-    version="2.4.0",
+    version="2.4.1",
     description="Python bindings for GPT4All",
     long_description=get_long_description(),
     long_description_content_type="text/markdown",