def _write_all(fd: int, data: bytes) -> None:
    """Write every byte of ``data`` to ``fd``, retrying after short writes.

    ``os.write`` returns the number of bytes actually written, which may be fewer
    than requested; a single unchecked call can therefore silently drop data.

    Args:
        fd: Destination file descriptor, positioned where the data should land.
        data: The bytes to write.

    Raises:
        OSError: If the underlying write fails or makes no progress.
    """
    view = memoryview(data)
    total = len(view)
    done = 0
    while done < total:
        written = os.write(fd, view[done:])
        if written <= 0:
            msg = f"write made no progress after {done} of {total} bytes"
            raise OSError(msg)
        done += written


def _kernel_copy(src_fd: int, dst_fd: int, dst_offset: int, size: int) -> None:
    """Copy ``size`` bytes from ``src_fd[0:size]`` into ``dst_fd[dst_offset:dst_offset+size]``.

    Strategies are tried in order, each one restarting the copy from byte 0 of the
    source (re-copying a prefix is harmless because the same bytes land at the same
    destination offsets):

    1. ``os.copy_file_range`` with explicit source/destination offsets.
    2. ``os.sendfile``, which writes at the destination's *current* file position,
       so ``dst_fd`` is seeked first. Callers copying in parallel must therefore
       give each worker its own destination fd.
    3. A chunked ``os.read`` / ``os.write`` loop.

    Args:
        src_fd: Source file descriptor; read from offset 0.
        dst_fd: Destination file descriptor, open for writing.
        dst_offset: Byte offset in the destination where the copy starts.
        size: Number of bytes to copy. ``0`` is a no-op.

    Raises:
        OSError: If the final read/write fallback hits end-of-file before ``size``
            bytes were copied, or if a write fails.
    """
    if size == 0:
        return

    if hasattr(os, "copy_file_range"):
        try:
            sent = 0
            while sent < size:
                n = os.copy_file_range(
                    src_fd,
                    dst_fd,
                    size - sent,
                    offset_src=sent,
                    offset_dst=dst_offset + sent,
                )
                if n == 0:
                    break
                sent += n
            if sent == size:
                return
            # Short copy (a 0 return before ``size`` bytes) - fall through.
        except OSError:
            # Deliberate best-effort: any failure here just means "try the next
            # strategy". A persistent error resurfaces from the final fallback.
            pass

    if hasattr(os, "sendfile"):
        try:
            os.lseek(dst_fd, dst_offset, os.SEEK_SET)
            sent = 0
            while sent < size:
                n = os.sendfile(dst_fd, src_fd, sent, size - sent)
                if n == 0:
                    break
                sent += n
            if sent == size:
                return
            # Short copy - fall through to the chunked read/write path.
        except OSError:
            # Same best-effort policy as above.
            pass

    os.lseek(src_fd, 0, os.SEEK_SET)
    os.lseek(dst_fd, dst_offset, os.SEEK_SET)
    chunk_size = 64 * 1024 * 1024
    remaining = size
    while remaining > 0:
        buf = os.read(src_fd, min(chunk_size, remaining))
        if not buf:
            break
        # os.write may accept fewer bytes than offered; loop until the whole
        # chunk is on disk so ``remaining`` reflects what was really written.
        _write_all(dst_fd, buf)
        remaining -= len(buf)
    if remaining > 0:
        msg = f"premature EOF: copied {size - remaining} of {size} bytes"
        raise OSError(msg)


def _copy_bin_at_offset(src_path: str, dst_path: str, dst_offset: int, size: int) -> None:
    """Copy a full input ``.bin`` file into ``[dst_offset, dst_offset+size)`` of the output.

    Each call opens its own source and destination file objects, so concurrent
    calls from different threads do not share any file-position state.

    Args:
        src_path: Path of the input file to read from offset 0.
        dst_path: Path of the existing output file (opened ``r+b``, never truncated).
        dst_offset: Byte offset in the output where this input belongs.
        size: Number of bytes to copy.
    """
    with open(src_path, "rb") as src, open(dst_path, "r+b") as dst:
        _kernel_copy(src.fileno(), dst.fileno(), dst_offset, size)
@@ -150,30 +222,35 @@ def write( self.idx_writer.write(struct.pack(" list[int]: + def _sequence_pointers(self, sequence_lengths: Iterable[int | np.integer]) -> np.ndarray: """Build the sequence pointers per the sequence lengths and dtype size Args: sequence_lengths (List[int]): The length of each sequence Returns: - List[int]: The pointer to the beginning of each sequence + np.ndarray: int64 array of byte pointers to the start of each sequence in the bin file """ - itemsize = np.int64(4 if self.dtype == np.int32 else 2) - curr_ptr = np.int64(0) - list_ptr = [] - for length in sequence_lengths: - list_ptr.append(curr_ptr.item()) - curr_ptr += length * itemsize - return list_ptr + itemsize = 4 if self.dtype == np.int32 else 2 + n = len(sequence_lengths) + pointers = np.empty(n, dtype=np.int64) + if n == 0: + return pointers + pointers[0] = 0 + if n > 1: + np.cumsum( + np.asarray(sequence_lengths[:-1], dtype=np.int64) * np.int64(itemsize), + out=pointers[1:], + ) + return pointers class IndexedDatasetBuilder: @@ -192,8 +269,11 @@ def __init__(self, bin_path: str, dtype: type[np.number]) -> None: self.data_file = open(bin_path, "wb") # noqa: SIM115 self.dtype = dtype - self.sequence_lengths = [] - self.document_indices = [0] + # Accumulate per-input arrays and concatenate once at finalize. Extending a Python + # list with numpy scalars and re-boxing them at the end was a major hot spot. 
# --- IndexedDatasetBuilder methods -------------------------------------------------


def add_index(self, path_prefix: str) -> None:
    """Append one whole IndexedDataset (``.idx`` + ``.bin``) to the one being built.

    Args:
        path_prefix (str): Common prefix of the input's index (.idx) and data (.bin) files
    """
    seq_lens, doc_idx, dtype = extract_index_contents(path_prefix + ".idx")
    assert dtype == self.dtype  # noqa: S101

    # Document indices of this input are relative to its own first sequence, so shift
    # them by the number of sequences merged so far. The leading entry is dropped
    # because the builder's document chunks already start with a single 0.
    base = self._cumulative_seq
    self._seq_chunks.append(seq_lens)
    self._doc_chunks.append(np.int64(base) + doc_idx[1:].astype(np.int64, copy=False))
    self._cumulative_seq = base + len(seq_lens)

    bin_path = path_prefix + ".bin"
    nbytes = os.path.getsize(bin_path)
    if not nbytes:
        return

    # Push out anything buffered in Python before copying at the fd level.
    self.data_file.flush()
    start = self.data_file.tell()
    with open(bin_path, "rb") as src:
        _kernel_copy(src.fileno(), self.data_file.fileno(), start, nbytes)
    # Re-align the file object's position with the bytes just appended so any
    # later buffered operation continues from the true end of the data.
    self.data_file.seek(start + nbytes)


def finalize(self, idx_path: str) -> None:
    """Close the data file and write the merged index (.idx) file.

    Args:
        idx_path (str): The path to the index file
    """
    self.data_file.close()

    if self._seq_chunks:
        merged_lengths = np.concatenate(self._seq_chunks)
    else:
        merged_lengths = np.empty(0, dtype=np.int32)
    merged_docs = np.concatenate(self._doc_chunks)

    with _IndexWriter(idx_path, self.dtype) as writer:
        writer.write(merged_lengths, merged_docs)


# --- module-level helper -----------------------------------------------------------


def _discover_prefixes(input_dir: str) -> list[str]:
    """Return the sorted prefixes that have both a ``.bin`` and a ``.idx`` file in ``input_dir``.

    Args:
        input_dir (str): Directory to scan (not recursive).

    Returns:
        list[str]: Sorted file-name prefixes (without extension).

    Raises:
        AssertionError: If a ``.bin``/``.idx`` file is found without its counterpart.
        ValueError: If the directory contains no prefix pair at all.
    """
    found: set[str] = set()
    for entry in os.listdir(input_dir):
        prefix, suffix = os.path.splitext(entry)

        # Ignore unrelated files and prefixes already recorded via their twin.
        if suffix not in {".bin", ".idx"} or prefix in found:
            continue

        if not os.path.isfile(os.path.join(input_dir, entry)):
            continue

        ext_pair = ".idx" if suffix == ".bin" else ".bin"
        assert os.path.isfile(os.path.join(input_dir, prefix + ext_pair)), (  # noqa: S101
            f"ERROR: {ext_pair} file not provided for {os.path.join(input_dir, prefix)}"
        )

        found.add(prefix)

    if found:
        return sorted(found)

    msg = f"ERROR: No valid file prefix pairs found in {input_dir}"
    raise ValueError(msg)
def merge_file_prefixes(input_dir: str, output_prefix: str, workers: int = 1) -> None:
    """Merge every .bin/.idx prefix pair in ``input_dir`` into one pair at ``output_prefix``.

    Inputs are merged in sorted prefix order. With ``workers <= 1`` the merge runs
    serially through ``IndexedDatasetBuilder``. Otherwise it runs in three phases:
    read all index files in a thread pool, pre-size the output ``.bin`` and copy each
    input into its precomputed byte range in a thread pool, then write the merged index.

    Args:
        input_dir: Directory containing the .bin/.idx pairs to merge.
        output_prefix: Output path prefix; produces ``<prefix>.bin`` and ``<prefix>.idx``.
        workers: Number of threads for index reads and bin copies. Default 1 is serial.
    """
    prefixes = _discover_prefixes(input_dir)
    out_bin = output_prefix + ".bin"
    out_idx = output_prefix + ".idx"

    if workers <= 1:
        # Serial path: the builder is created lazily so its dtype can be taken
        # from the first input's index.
        merger: IndexedDatasetBuilder | None = None
        for stem in prefixes:
            if merger is None:
                dtype = extract_index_contents(os.path.join(input_dir, stem + ".idx"))[2]
                merger = IndexedDatasetBuilder(out_bin, dtype=dtype)
            merger.add_index(os.path.join(input_dir, stem))
        merger.finalize(out_idx)
        return

    sources = [os.path.join(input_dir, stem) for stem in prefixes]

    # Phase 1: load all index files concurrently; results come back in input order.
    def _load_index(source: str) -> tuple:
        return extract_index_contents(source + ".idx")

    with ThreadPoolExecutor(max_workers=workers) as pool:
        loaded = list(pool.map(_load_index, sources))

    dtype = loaded[0][2]
    for _, _, d in loaded[1:]:
        assert d == dtype, f"dtype mismatch across input files: {d} vs {dtype}"  # noqa: S101

    # Stitch the per-input indices together. Each input's document indices are shifted
    # by the number of sequences that precede it; its leading entry is dropped because
    # the merged array already starts with a single 0.
    length_parts = [entry[0] for entry in loaded]
    doc_parts: list[np.ndarray] = [np.array([0], dtype=np.int64)]
    sequences_so_far = 0
    for seq, doc, _ in loaded:
        doc_parts.append(np.int64(sequences_so_far) + doc[1:].astype(np.int64, copy=False))
        sequences_so_far += len(seq)
    if length_parts:
        sequence_lengths = np.concatenate(length_parts)
    else:
        sequence_lengths = np.empty(0, dtype=np.int32)
    document_indices = np.concatenate(doc_parts)

    # Phase 2: every input's destination offset is the running total of the sizes
    # before it, so the output can be pre-sized and filled in any order.
    bin_paths = [source + ".bin" for source in sources]
    bin_sizes = [os.path.getsize(path) for path in bin_paths]
    offsets: list[int] = []
    cursor = 0
    for nbytes in bin_sizes:
        offsets.append(cursor)
        cursor += nbytes
    total_size = cursor

    # Create (or truncate) the output, then grow it to its final size up front.
    with open(out_bin, "wb") as f:
        if total_size:
            os.ftruncate(f.fileno(), total_size)

    if total_size:

        def _copy_one(job: tuple[str, int, int]) -> None:
            src_path, dst_offset, nbytes = job
            _copy_bin_at_offset(src_path, out_bin, dst_offset, nbytes)

        with ThreadPoolExecutor(max_workers=workers) as pool:
            # Drain the iterator so an exception raised in any worker propagates here.
            for _ in pool.map(_copy_one, zip(bin_paths, offsets, bin_sizes, strict=True)):
                pass

    # Phase 3: write the merged idx.
    with _IndexWriter(out_idx, dtype=dtype) as writer:
        writer.write(sequence_lengths, document_indices)
+ ), + ) + args = parser.parse_args() assert os.path.isdir(args.input_dir), f"ERROR: {args.input_dir} is not a directory or does not exist" # noqa: S101 @@ -259,42 +459,6 @@ def get_args() -> argparse.Namespace: return args -def merge_file_prefixes(input_dir: str, output_prefix: str) -> None: - prefixes = set() - for basename in os.listdir(input_dir): - prefix, ext = os.path.splitext(basename) - - if ext not in {".bin", ".idx"}: - continue - - if prefix in prefixes: - continue - - if not os.path.isfile(os.path.join(input_dir, basename)): - continue - - ext_pair = ".bin" if ext == ".idx" else ".idx" - assert os.path.isfile(os.path.join(input_dir, prefix + ext_pair)), ( # noqa: S101 - f"ERROR: {ext_pair} file not provided for {os.path.join(input_dir, prefix)}" - ) - - prefixes.add(prefix) - - if not prefixes: - msg = f"ERROR: No valid file prefix pairs found in {input_dir}" - raise ValueError(msg) - - builder = None - for prefix in sorted(prefixes): - if builder is None: - _, _, dtype = extract_index_contents(os.path.join(input_dir, prefix + ".idx")) - builder = IndexedDatasetBuilder(output_prefix + ".bin", dtype=dtype) - - builder.add_index(os.path.join(input_dir, prefix)) - - builder.finalize(output_prefix + ".idx") - - if __name__ == "__main__": args = get_args() - merge_file_prefixes(args.input_dir, args.output_prefix) + merge_file_prefixes(args.input_dir, args.output_prefix, workers=args.workers)