/*
 coding=utf-8
 Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

/* Helper methods for fast index mapping builds */

#include <math.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>

#include <algorithm>
#include <iostream>
#include <limits>
#include <random>
#include <stdexcept>
#include <vector>

namespace py = pybind11;
using namespace std;

const int32_t LONG_SENTENCE_LEN = 512;

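// Documents containing any sentence longer than LONG_SENTENCE_LEN tokens
// are skipped entirely by the mapping builders below (they are counted as
// `long_sent_docs` in the verbose output).
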
void build_blending_indices(py::array_t<uint8_t>& dataset_index,
                            py::array_t<int64_t>& dataset_sample_index,
                            const py::array_t<double>& weights,
                            const int32_t num_datasets, const int64_t size,
                            const bool verbose) {
  /* Given multiple datasets and a weighting array, build samples
   such that it follows those weights.*/

  if (verbose) {
    std::cout << "> building indices for blendable datasets ..." << std::endl;
  }

  // Get the pointer access without the checks.
  auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
  auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
  auto weights_ptr = weights.unchecked<1>();

  // Initialize buffer for number of samples used for each dataset.
  std::vector<int64_t> current_samples(num_datasets, 0);

  // For each sample:
  for (int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {
    // Determine where the max error in sampling is happening.
    auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
    int64_t max_error_index = 0;
    double max_error = weights_ptr[0] * sample_idx_double -
                       static_cast<double>(current_samples[0]);
    for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) {
      double error = weights_ptr[dataset_idx] * sample_idx_double -
                     static_cast<double>(current_samples[dataset_idx]);
      if (error > max_error) {
        max_error = error;
        max_error_index = dataset_idx;
      }
    }

    // Populate the indices.
    dataset_index_ptr[sample_idx] = static_cast<uint8_t>(max_error_index);
    dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];

    // Update the total samples.
    current_samples[max_error_index] += 1;
  }

  // print info
  if (verbose) {
    std::cout << " > sample ratios:" << std::endl;
    for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {
      auto ratio = static_cast<double>(current_samples[dataset_idx]) /
                   static_cast<double>(size);
      std::cout << "   dataset " << dataset_idx
                << ", input: " << weights_ptr[dataset_idx]
                << ", achieved: " << ratio << std::endl;
    }
  }
}

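// Illustrative example (not part of the upstream code): with
// weights = {0.75, 0.25} and size = 4, the greedy rule above always picks
// the dataset whose achieved count lags furthest behind its target
// weight * sample_idx:
//   sample 0: errors {0.75, 0.25}   -> dataset 0, current_samples {1, 0}
//   sample 1: errors {-0.25, 0.25}  -> dataset 1, current_samples {1, 1}
//   sample 2: errors {0.50, -0.50}  -> dataset 0, current_samples {2, 1}
//   sample 3: errors {0.25, -0.25}  -> dataset 0, current_samples {3, 1}
// yielding achieved ratios {0.75, 0.25}, matching the requested weights.
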
py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
                           const py::array_t<int32_t>& doc_idx_,
                           const int32_t seq_length, const int32_t num_epochs,
                           const int64_t tokens_per_epoch) {
  /* Sample index (sample_idx) is used for a GPT-2-like dataset for which
     the documents are flattened and the samples are built from this
     1-D flattened array. It is a 2D array with sizes
     [number-of-samples + 1, 2] where [..., 0] contains the index into
     `doc_idx` and [..., 1] is the starting offset in that document.*/

  // Consistency checks.
  assert(seq_length > 1);
  assert(num_epochs > 0);
  assert(tokens_per_epoch > 1);

  // Remove bound checks.
  auto sizes = sizes_.unchecked<1>();
  auto doc_idx = doc_idx_.unchecked<1>();

  // Mapping and its length (1D).
  int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length;
  int32_t* sample_idx = new int32_t[2 * (num_samples + 1)];

  cout << "    using:" << endl << std::flush;
  cout << "     number of documents:       " << doc_idx_.shape(0) / num_epochs
       << endl
       << std::flush;
  cout << "     number of epochs:          " << num_epochs << endl
       << std::flush;
  cout << "     sequence length:           " << seq_length << endl
       << std::flush;
  cout << "     total number of samples:   " << num_samples << endl
       << std::flush;

  // Index into sample_idx.
  int64_t sample_index = 0;
  // Index into doc_idx.
  int64_t doc_idx_index = 0;
  // Beginning offset for each document.
  int32_t doc_offset = 0;
  // Start with first document and no offset.
  sample_idx[2 * sample_index] = doc_idx_index;
  sample_idx[2 * sample_index + 1] = doc_offset;
  ++sample_index;

  while (sample_index <= num_samples) {
    // Start with a fresh sequence.
    int32_t remaining_seq_length = seq_length + 1;
    while (remaining_seq_length != 0) {
      // Get the document length.
      auto doc_id = doc_idx[doc_idx_index];
      auto doc_length = sizes[doc_id] - doc_offset;
      // And add it to the current sequence.
      remaining_seq_length -= doc_length;
      // If we have more than a full sequence, adjust offset and set
      // remaining length to zero so we return from the while loop.
      // Note that -1 here is for the same reason we have -1 in
      // `_num_epochs` calculations.
      if (remaining_seq_length <= 0) {
        doc_offset += (remaining_seq_length + doc_length - 1);
        remaining_seq_length = 0;
      } else {
        // Otherwise, start from the beginning of the next document.
        ++doc_idx_index;
        doc_offset = 0;
      }
    }
    // Record the sequence.
    sample_idx[2 * sample_index] = doc_idx_index;
    sample_idx[2 * sample_index + 1] = doc_offset;
    ++sample_index;
  }

  // Method to deallocate memory.
  py::capsule free_when_done(sample_idx, [](void* mem_) {
    int32_t* mem = reinterpret_cast<int32_t*>(mem_);
    delete[] mem;
  });

  // Return the numpy array.
  const auto byte_size = sizeof(int32_t);
  return py::array(std::vector<int64_t>{num_samples + 1, 2},  // shape
                   {2 * byte_size, byte_size},  // C-style contiguous strides
                   sample_idx,                  // the data pointer
                   free_when_done);             // numpy array references
}

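// Illustrative example (not part of the upstream code): with
// tokens_per_epoch = 10, num_epochs = 2, and seq_length = 4,
// num_samples = (2 * 10 - 1) / 4 = 4. Sample i starts at token
// i * seq_length of the flattened stream and consumes seq_length + 1
// tokens (the extra token supplies the label for the last position,
// which is why the -1 terms appear above); row i of sample_idx records
// the document index and in-document offset where that span begins, and
// the final extra row marks where the last sample ends.
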
inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
                                     const int32_t max_length,
                                     std::mt19937& rand32_gen) {
  /* Training sample length. */
  if (short_seq_ratio == 0) {
    return max_length;
  }
  const auto random_number = rand32_gen();
  if ((random_number % short_seq_ratio) == 0) {
    return 2 + random_number % (max_length - 1);
  }
  return max_length;
}

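// Net effect: with probability roughly 1 / short_seq_ratio the target
// length is drawn (approximately uniformly) from [2, max_length];
// otherwise the full max_length is used.
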
template <typename DocIdx>
py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
                             const py::array_t<int32_t>& sizes_,
                             const int32_t num_epochs,
                             const uint64_t max_num_samples,
                             const int32_t max_seq_length,
                             const double short_seq_prob, const int32_t seed,
                             const bool verbose, const int32_t min_num_sent) {
  /* Build a mapping of (start-index, end-index, sequence-length) where
     start and end index are the indices of the sentences in the sample
     and sequence-length is the target sequence length.
  */

  // Consistency checks.
  assert(num_epochs > 0);
  assert(max_seq_length > 1);
  assert(short_seq_prob >= 0.0);
  assert(short_seq_prob <= 1.0);
  assert(seed > 0);

  // Remove bound checks.
  auto docs = docs_.unchecked<1>();
  auto sizes = sizes_.unchecked<1>();

  // For efficiency, convert probability to ratio. Note: rand() generates int.
  int32_t short_seq_ratio = 0;
  if (short_seq_prob > 0) {
    short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
  }

  if (verbose) {
    const auto sent_start_index = docs[0];
    const auto sent_end_index = docs[docs_.shape(0) - 1];
    const auto num_sentences = sent_end_index - sent_start_index;
    cout << "    using:" << endl << std::flush;
    cout << "     number of documents:            " << docs_.shape(0) - 1
         << endl
         << std::flush;
    cout << "     sentences range:                [" << sent_start_index << ", "
         << sent_end_index << ")" << endl
         << std::flush;
    cout << "     total number of sentences:      " << num_sentences << endl
         << std::flush;
    cout << "     number of epochs:               " << num_epochs << endl
         << std::flush;
    cout << "     maximum number of samples:      " << max_num_samples << endl
         << std::flush;
    cout << "     maximum sequence length:        " << max_seq_length << endl
         << std::flush;
    cout << "     short sequence probability:     " << short_seq_prob << endl
         << std::flush;
    cout << "     short sequence ratio (1/prob):  " << short_seq_ratio << endl
         << std::flush;
    cout << "     seed:                           " << seed << endl
         << std::flush;
  }

  // Mapping and its length (1D).
  int64_t num_samples = -1;
  DocIdx* maps = NULL;

  // Perform two iterations, in the first iteration get the size
  // and allocate memory and in the second iteration populate the map.
  bool second = false;
  for (int32_t iteration = 0; iteration < 2; ++iteration) {
    // Set the seed so both iterations produce the same results.
    std::mt19937 rand32_gen(seed);

    // Set the flag on second iteration.
    second = (iteration == 1);

    // Counters:
    uint64_t empty_docs = 0;
    uint64_t one_sent_docs = 0;
    uint64_t long_sent_docs = 0;

    // Current map index.
    uint64_t map_index = 0;

    // For each epoch:
    for (int32_t epoch = 0; epoch < num_epochs; ++epoch) {
      if (map_index >= max_num_samples) {
        if (verbose && (!second)) {
          cout << "    reached " << max_num_samples << " samples after "
               << epoch << " epochs ..." << endl
               << std::flush;
        }
        break;
      }
      // For each document:
      for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc) {
        // Document sentences are in [sent_index_first, sent_index_last)
        const auto sent_index_first = docs[doc];
        const auto sent_index_last = docs[doc + 1];

        // At the beginning of the document previous index is the
        // start index.
        auto prev_start_index = sent_index_first;

        // Remaining sentences.
        auto num_remain_sent = sent_index_last - sent_index_first;

        // Some bookkeeping
        if ((epoch == 0) && (!second)) {
          if (num_remain_sent == 0) {
            ++empty_docs;
          }
          if (num_remain_sent == 1) {
            ++one_sent_docs;
          }
        }

        // Detect documents with long sentences.
        bool contains_long_sentence = false;
        if (num_remain_sent > 1) {
          for (auto sent_index = sent_index_first; sent_index < sent_index_last;
               ++sent_index) {
            if (sizes[sent_index] > LONG_SENTENCE_LEN) {
              if ((epoch == 0) && (!second)) {
                ++long_sent_docs;
              }
              contains_long_sentence = true;
              break;
            }
          }
        }

        // If we have enough sentences and no long sentences.
        if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
          // Set values.
          auto seq_len = int32_t{0};
          auto num_sent = int32_t{0};
          auto target_seq_len = get_target_sample_len(
              short_seq_ratio, max_seq_length, rand32_gen);

          // Loop through sentences.
          for (auto sent_index = sent_index_first; sent_index < sent_index_last;
               ++sent_index) {
            // Add the size and number of sentences.
            seq_len += sizes[sent_index];
            ++num_sent;
            --num_remain_sent;

            // If we have reached the target length,
            // and more than one sentence is left in the document,
            // and we have at least the minimum number of sentences;
            // or if we have reached the end of the document.
            if (((seq_len >= target_seq_len) && (num_remain_sent > 1) &&
                 (num_sent >= min_num_sent)) ||
                (num_remain_sent == 0)) {
              // Check for overflow.
              if ((3 * map_index + 2) > std::numeric_limits<int64_t>::max()) {
                cout << "number of samples exceeded maximum "
                     << "allowed by type int64: "
                     << std::numeric_limits<int64_t>::max() << endl;
                throw std::overflow_error("Number of samples");
              }

              // Populate the map.
              if (second) {
                const auto map_index_0 = 3 * map_index;
                maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
                maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
                maps[map_index_0 + 2] = static_cast<DocIdx>(target_seq_len);
              }

              // Update indices / counters.
              ++map_index;
              prev_start_index = sent_index + 1;
              target_seq_len = get_target_sample_len(
                  short_seq_ratio, max_seq_length, rand32_gen);
              seq_len = 0;
              num_sent = 0;
            }

          }  // for (auto sent_index=sent_index_first; ...
        }  // if (num_remain_sent >= min_num_sent) ...
      }  // for (int doc=0; doc < num_docs; ++doc) {
    }  // for (int epoch=0; epoch < num_epochs; ++epoch) {

    if (!second) {
      if (verbose) {
        cout << "   number of empty documents: " << empty_docs << endl
             << std::flush;
        cout << "   number of documents with one sentence: " << one_sent_docs
             << endl
             << std::flush;
        cout << "   number of documents with long sentences: " << long_sent_docs
             << endl
             << std::flush;
        cout << "   will create mapping for " << map_index << " samples" << endl
             << std::flush;
      }
      assert(maps == NULL);
      assert(num_samples < 0);
      maps = new DocIdx[3 * map_index];
      num_samples = static_cast<int64_t>(map_index);
    }

  }  // for (int iteration=0; iteration < 2; ++iteration) {

  // Shuffle.
  // We need a 64 bit random number generator as we might have more
  // than 2 billion samples.
  std::mt19937_64 rand64_gen(seed + 1);
  for (auto i = (num_samples - 1); i > 0; --i) {
    const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
    const auto i0 = 3 * i;
    const auto j0 = 3 * j;
    // Swap values.
    swap(maps[i0], maps[j0]);
    swap(maps[i0 + 1], maps[j0 + 1]);
    swap(maps[i0 + 2], maps[j0 + 2]);
  }

  // Method to deallocate memory.
  py::capsule free_when_done(maps, [](void* mem_) {
    DocIdx* mem = reinterpret_cast<DocIdx*>(mem_);
    delete[] mem;
  });

  // Return the numpy array.
  const auto byte_size = sizeof(DocIdx);
  return py::array(std::vector<int64_t>{num_samples, 3},  // shape
                   {3 * byte_size, byte_size},  // C-style contiguous strides
                   maps,                        // the data pointer
                   free_when_done);             // numpy array references
}

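// Illustrative example (not part of the upstream code): take a document
// whose sentence sizes are {5, 8, 4, 10}, with min_num_sent = 2, and
// assume both target_seq_len draws give 12. The packing loop above
// accumulates sentences 0-1 (seq_len = 13 >= 12, two sentences remain)
// and emits the sample (start = 0, end = 2, target = 12); it then
// accumulates sentences 2-3 and, on reaching the end of the document,
// emits (start = 2, end = 4, target = 12) even though seq_len = 14
// overshoots the target; downstream consumers typically truncate the
// assembled sample to the recorded target length.
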
py::array build_mapping(const py::array_t<int64_t>& docs_,
                        const py::array_t<int>& sizes_, const int num_epochs,
                        const uint64_t max_num_samples,
                        const int max_seq_length, const double short_seq_prob,
                        const int seed, const bool verbose,
                        const int32_t min_num_sent) {
  if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
    if (verbose) {
      cout << "    using uint64 for data mapping..." << endl << std::flush;
    }
    return build_mapping_impl<uint64_t>(
        docs_, sizes_, num_epochs, max_num_samples, max_seq_length,
        short_seq_prob, seed, verbose, min_num_sent);
  } else {
    if (verbose) {
      cout << "    using uint32 for data mapping..." << endl << std::flush;
    }
    return build_mapping_impl<uint32_t>(
        docs_, sizes_, num_epochs, max_num_samples, max_seq_length,
        short_seq_prob, seed, verbose, min_num_sent);
  }
}

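// The dtype dispatch above halves the memory footprint of the returned
// mapping whenever every sentence index fits in 32 bits (i.e. `sizes_`
// has fewer than 2^32 entries); otherwise 64-bit indices are used.
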
template <typename DocIdx>
py::array build_blocks_mapping_impl(
    const py::array_t<int64_t>& docs_, const py::array_t<int32_t>& sizes_,
    const py::array_t<int32_t>& titles_sizes_, const int32_t num_epochs,
    const uint64_t max_num_samples, const int32_t max_seq_length,
    const int32_t seed, const bool verbose, const bool use_one_sent_blocks) {
  /* Build a mapping of (start-index, end-index, document-index, block-id)
     where start and end index are the indices of the sentences in the
     sample, document-index is the document the block comes from, and
     block-id is a unique id for the block.
  */

  // Consistency checks.
  assert(num_epochs > 0);
  assert(max_seq_length > 1);
  assert(seed > 0);

  // Remove bound checks.
  auto docs = docs_.unchecked<1>();
  auto sizes = sizes_.unchecked<1>();
  auto titles_sizes = titles_sizes_.unchecked<1>();

  if (verbose) {
    const auto sent_start_index = docs[0];
    const auto sent_end_index = docs[docs_.shape(0) - 1];
    const auto num_sentences = sent_end_index - sent_start_index;
    cout << "    using:" << endl << std::flush;
    cout << "     number of documents:            " << docs_.shape(0) - 1
         << endl
         << std::flush;
    cout << "     sentences range:                [" << sent_start_index << ", "
         << sent_end_index << ")" << endl
         << std::flush;
    cout << "     total number of sentences:      " << num_sentences << endl
         << std::flush;
    cout << "     number of epochs:               " << num_epochs << endl
         << std::flush;
    cout << "     maximum number of samples:      " << max_num_samples << endl
         << std::flush;
    cout << "     maximum sequence length:        " << max_seq_length << endl
         << std::flush;
    cout << "     seed:                           " << seed << endl
         << std::flush;
  }

  // Mapping and its length (1D).
  int64_t num_samples = -1;
  DocIdx* maps = NULL;

  // Acceptable number of sentences per block.
  int min_num_sent = 2;
  if (use_one_sent_blocks) {
    min_num_sent = 1;
  }

  // Perform two iterations, in the first iteration get the size
  // and allocate memory and in the second iteration populate the map.
  bool second = false;
  for (int32_t iteration = 0; iteration < 2; ++iteration) {
    // Set the flag on second iteration.
    second = (iteration == 1);

    // Current map index.
    uint64_t map_index = 0;

    uint64_t empty_docs = 0;
    uint64_t one_sent_docs = 0;
    uint64_t long_sent_docs = 0;
    // For each epoch:
    for (int32_t epoch = 0; epoch < num_epochs; ++epoch) {
      // assign every block a unique id
      int32_t block_id = 0;

      if (map_index >= max_num_samples) {
        if (verbose && (!second)) {
          cout << "    reached " << max_num_samples << " samples after "
               << epoch << " epochs ..." << endl
               << std::flush;
        }
        break;
      }
      // For each document:
      for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc) {
        // Document sentences are in [sent_index_first, sent_index_last)
        const auto sent_index_first = docs[doc];
        const auto sent_index_last = docs[doc + 1];
        const auto target_seq_len = max_seq_length - titles_sizes[doc];

        // At the beginning of the document previous index is the
        // start index.
        auto prev_start_index = sent_index_first;

        // Remaining sentences.
        auto num_remain_sent = sent_index_last - sent_index_first;

        // Some bookkeeping
        if ((epoch == 0) && (!second)) {
          if (num_remain_sent == 0) {
            ++empty_docs;
          }
          if (num_remain_sent == 1) {
            ++one_sent_docs;
          }
        }
        // Detect documents with long sentences.
        bool contains_long_sentence = false;
        if (num_remain_sent >= min_num_sent) {
          for (auto sent_index = sent_index_first; sent_index < sent_index_last;
               ++sent_index) {
            if (sizes[sent_index] > LONG_SENTENCE_LEN) {
              if ((epoch == 0) && (!second)) {
                ++long_sent_docs;
              }
              contains_long_sentence = true;
              break;
            }
          }
        }
        // If we have enough sentences and no long sentences.
        if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
          // Set values.
          auto seq_len = int32_t{0};
          auto num_sent = int32_t{0};

          // Loop through sentences.
          for (auto sent_index = sent_index_first; sent_index < sent_index_last;
               ++sent_index) {
            // Add the size and number of sentences.
            seq_len += sizes[sent_index];
            ++num_sent;
            --num_remain_sent;

            // If we have reached the target length,
            // and an acceptable number of sentences is left,
            // and we have at least the minimum number of sentences;
            // or if we have reached the end of the document.
            if (((seq_len >= target_seq_len) &&
                 (num_remain_sent >= min_num_sent) &&
                 (num_sent >= min_num_sent)) ||
                (num_remain_sent == 0)) {
              // Populate the map.
              if (second) {
                const auto map_index_0 = 4 * map_index;
                // Each sample has 4 items: the starting sentence index, ending
                // sentence index, the index of the document from which the
                // block comes (used for fetching titles) and the unique id of
                // the block (used for creating block indexes)

                maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
                maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
                maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
                maps[map_index_0 + 3] = static_cast<DocIdx>(block_id);
              }

              // Update indices / counters.
              ++map_index;
              ++block_id;
              prev_start_index = sent_index + 1;
              seq_len = 0;
              num_sent = 0;
            }
          }  // for (auto sent_index=sent_index_first; ...
        }  // if (num_remain_sent >= min_num_sent) ...
      }  // for (int doc=0; doc < num_docs; ++doc) {
    }  // for (int epoch=0; epoch < num_epochs; ++epoch) {

    if (!second) {
      if (verbose) {
        cout << "   number of empty documents: " << empty_docs << endl
             << std::flush;
        cout << "   number of documents with one sentence: " << one_sent_docs
             << endl
             << std::flush;
        cout << "   number of documents with long sentences: " << long_sent_docs
             << endl
             << std::flush;
        cout << "   will create mapping for " << map_index << " samples" << endl
             << std::flush;
      }
      assert(maps == NULL);
      assert(num_samples < 0);
      maps = new DocIdx[4 * map_index];
      num_samples = static_cast<int64_t>(map_index);
    }

  }  // for (int iteration=0; iteration < 2; ++iteration) {

  // Shuffle.
  // We need a 64 bit random number generator as we might have more
  // than 2 billion samples.
  std::mt19937_64 rand64_gen(seed + 1);
  for (auto i = (num_samples - 1); i > 0; --i) {
    const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
    const auto i0 = 4 * i;
    const auto j0 = 4 * j;
    // Swap values.
    swap(maps[i0], maps[j0]);
    swap(maps[i0 + 1], maps[j0 + 1]);
    swap(maps[i0 + 2], maps[j0 + 2]);
    swap(maps[i0 + 3], maps[j0 + 3]);
  }

  // Method to deallocate memory.
  py::capsule free_when_done(maps, [](void* mem_) {
    DocIdx* mem = reinterpret_cast<DocIdx*>(mem_);
    delete[] mem;
  });

  // Return the numpy array.
  const auto byte_size = sizeof(DocIdx);
  return py::array(std::vector<int64_t>{num_samples, 4},  // shape
                   {4 * byte_size, byte_size},  // C-style contiguous strides
                   maps,                        // the data pointer
                   free_when_done);             // numpy array references
}

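// Note: unlike build_mapping_impl, the target length here is
// max_seq_length - titles_sizes[doc], reserving space so that the
// document's title and the block together still fit within
// max_seq_length; the document index stored in each sample is what
// lets the title be fetched later.
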
py::array build_blocks_mapping(
    const py::array_t<int64_t>& docs_, const py::array_t<int>& sizes_,
    const py::array_t<int>& titles_sizes_, const int num_epochs,
    const uint64_t max_num_samples, const int max_seq_length, const int seed,
    const bool verbose, const bool use_one_sent_blocks) {
  if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
    if (verbose) {
      cout << "    using uint64 for data mapping..." << endl << std::flush;
    }
    return build_blocks_mapping_impl<uint64_t>(
        docs_, sizes_, titles_sizes_, num_epochs, max_num_samples,
        max_seq_length, seed, verbose, use_one_sent_blocks);
  } else {
    if (verbose) {
      cout << "    using uint32 for data mapping..." << endl << std::flush;
    }
    return build_blocks_mapping_impl<uint32_t>(
        docs_, sizes_, titles_sizes_, num_epochs, max_num_samples,
        max_seq_length, seed, verbose, use_one_sent_blocks);
  }
}

PYBIND11_MODULE(helpers, m) {
  m.def("build_mapping", &build_mapping);
  m.def("build_blocks_mapping", &build_blocks_mapping);
  m.def("build_sample_idx", &build_sample_idx);
  m.def("build_blending_indices", &build_blending_indices);
}

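// Example use from Python (a sketch; assumes this file has been compiled
// into an importable extension module named `helpers`, as declared in
// PYBIND11_MODULE above):
//
//   import numpy as np
//   import helpers
//
//   size = 8
//   weights = np.array([0.75, 0.25], dtype=np.float64)
//   dataset_index = np.zeros(size, dtype=np.uint8)
//   dataset_sample_index = np.zeros(size, dtype=np.int64)
//   helpers.build_blending_indices(dataset_index, dataset_sample_index,
//                                  weights, 2, size, True)
//   # dataset_index[i] says which dataset sample i is drawn from;
//   # dataset_sample_index[i] is the sample's index within that dataset.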