Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
296 changes: 230 additions & 66 deletions nemo_curator/utils/merge_file_prefixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@
"""

import argparse
import gc
import os
import shutil
import struct
from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor
from types import TracebackType

import numpy as np
Expand Down Expand Up @@ -76,6 +75,79 @@
return sequence_lengths, document_indices, dtype


def _kernel_copy(src_fd: int, dst_fd: int, dst_offset: int, size: int) -> None:

Check failure on line 78 in nemo_curator/utils/merge_file_prefixes.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (PLR0912)

nemo_curator/utils/merge_file_prefixes.py:78:5: PLR0912 Too many branches (14 > 12)

Check failure on line 78 in nemo_curator/utils/merge_file_prefixes.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (C901)

nemo_curator/utils/merge_file_prefixes.py:78:5: C901 `_kernel_copy` is too complex (15 > 10)

Check failure on line 78 in nemo_curator/utils/merge_file_prefixes.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (PLR0912)

nemo_curator/utils/merge_file_prefixes.py:78:5: PLR0912 Too many branches (14 > 12)

Check failure on line 78 in nemo_curator/utils/merge_file_prefixes.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (C901)

nemo_curator/utils/merge_file_prefixes.py:78:5: C901 `_kernel_copy` is too complex (15 > 10)
"""Copy ``size`` bytes from ``src_fd[0:size]`` into ``dst_fd[dst_offset:dst_offset+size]``
using the fastest available primitive.

Prefers ``os.copy_file_range`` (Linux 4.5+): explicit src/dst offsets, GIL released,
in-kernel copy that becomes a reflink on filesystems that support it (XFS, Btrfs).
Falls back to ``os.sendfile`` (writes at dst's *current* file position — callers using
sendfile in parallel must give each worker its own dst fd) and finally to a chunked
read/write loop.
"""
if size == 0:
return

if hasattr(os, "copy_file_range"):
try:
sent = 0
while sent < size:
n = os.copy_file_range(
src_fd,
dst_fd,
size - sent,
offset_src=sent,
offset_dst=dst_offset + sent,
)
if n == 0:
break
sent += n
if sent == size:
return
# Short copy (some backends return 0 prematurely) - fall through
except OSError:
pass

if hasattr(os, "sendfile"):
try:
os.lseek(dst_fd, dst_offset, os.SEEK_SET)
sent = 0
while sent < size:
n = os.sendfile(dst_fd, src_fd, sent, size - sent)
if n == 0:
break
sent += n
if sent == size:
return
# Short copy - fall through to the chunked read/write path
except OSError:
pass

os.lseek(src_fd, 0, os.SEEK_SET)
os.lseek(dst_fd, dst_offset, os.SEEK_SET)
chunk_size = 64 * 1024 * 1024
remaining = size
while remaining > 0:
buf = os.read(src_fd, min(chunk_size, remaining))
if not buf:
break
os.write(dst_fd, buf)
remaining -= len(buf)
Comment on lines +91 to +135
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Silent data truncation on unexpected zero-byte return. All three copy branches treat n == 0 as "done" but never verify that sent == size before returning. If copy_file_range or sendfile returns 0 prematurely (e.g., the source file shrinks between the os.path.getsize call and the copy, or a FUSE/overlay filesystem emits an unexpected early EOF), the output .bin is silently left with zeroed regions rather than the actual token data. The index will still point to those positions, so the merged dataset will silently contain garbage tokens without any error.

Suggested change
if hasattr(os, "copy_file_range"):
try:
sent = 0
while sent < size:
n = os.copy_file_range(
src_fd,
dst_fd,
size - sent,
offset_src=sent,
offset_dst=dst_offset + sent,
)
if n == 0:
break
sent += n
return
except OSError:
pass
if hasattr(os, "sendfile"):
try:
os.lseek(dst_fd, dst_offset, os.SEEK_SET)
sent = 0
while sent < size:
n = os.sendfile(dst_fd, src_fd, sent, size - sent)
if n == 0:
break
sent += n
return
except OSError:
pass
os.lseek(src_fd, 0, os.SEEK_SET)
os.lseek(dst_fd, dst_offset, os.SEEK_SET)
chunk_size = 64 * 1024 * 1024
remaining = size
while remaining > 0:
buf = os.read(src_fd, min(chunk_size, remaining))
if not buf:
break
os.write(dst_fd, buf)
remaining -= len(buf)
if hasattr(os, "copy_file_range"):
try:
sent = 0
while sent < size:
n = os.copy_file_range(
src_fd,
dst_fd,
size - sent,
offset_src=sent,
offset_dst=dst_offset + sent,
)
if n == 0:
break
sent += n
if sent == size:
return
except OSError:
pass
if hasattr(os, "sendfile"):
try:
os.lseek(dst_fd, dst_offset, os.SEEK_SET)
sent = 0
while sent < size:
n = os.sendfile(dst_fd, src_fd, sent, size - sent)
if n == 0:
break
sent += n
if sent == size:
return
except OSError:
pass
os.lseek(src_fd, 0, os.SEEK_SET)
os.lseek(dst_fd, dst_offset, os.SEEK_SET)
chunk_size = 64 * 1024 * 1024
remaining = size
while remaining > 0:
buf = os.read(src_fd, min(chunk_size, remaining))
if not buf:
break
os.write(dst_fd, buf)
remaining -= len(buf)
if remaining != 0:
raise OSError(f"Incomplete copy: {size - remaining} of {size} bytes written")

if remaining > 0:
msg = f"premature EOF: copied {size - remaining} of {size} bytes"
raise OSError(msg)


def _copy_bin_at_offset(src_path: str, dst_path: str, dst_offset: int, size: int) -> None:
"""Copy a full input ``.bin`` file into ``[dst_offset, dst_offset+size)`` of the output.

Each call opens its own src and dst fds, so this is safe to invoke concurrently from
threads — no shared file-position state.
"""
with open(src_path, "rb") as src, open(dst_path, "r+b") as dst:
_kernel_copy(src.fileno(), dst.fileno(), dst_offset, size)


class _IndexWriter:
"""Simplified version of the _IndexWriter class from the Megatron-LM library.

Expand Down Expand Up @@ -150,30 +222,35 @@
self.idx_writer.write(struct.pack("<Q", document_count))

# the number of tokens per sequence
self.idx_writer.write(np.array(sequence_lengths, dtype=np.int32).tobytes(order="C"))
self.idx_writer.write(np.ascontiguousarray(sequence_lengths, dtype=np.int32).tobytes(order="C"))

# the byte offsets for all sequences
self.idx_writer.write(np.array(sequence_pointers, dtype=np.int64).tobytes(order="C"))
self.idx_writer.write(sequence_pointers.tobytes(order="C"))

# the sequence indices marking the end of each document
self.idx_writer.write(np.array(document_indices, dtype=np.int64).tobytes(order="C"))
self.idx_writer.write(np.ascontiguousarray(document_indices, dtype=np.int64).tobytes(order="C"))

def _sequence_pointers(self, sequence_lengths: Iterable[int | np.integer]) -> list[int]:
def _sequence_pointers(self, sequence_lengths: Iterable[int | np.integer]) -> np.ndarray:
"""Build the sequence pointers per the sequence lengths and dtype size

Args:
sequence_lengths (List[int]): The length of each sequence

Returns:
List[int]: The pointer to the beginning of each sequence
np.ndarray: int64 array of byte pointers to the start of each sequence in the bin file
"""
itemsize = np.int64(4 if self.dtype == np.int32 else 2)
curr_ptr = np.int64(0)
list_ptr = []
for length in sequence_lengths:
list_ptr.append(curr_ptr.item())
curr_ptr += length * itemsize
return list_ptr
itemsize = 4 if self.dtype == np.int32 else 2
n = len(sequence_lengths)
pointers = np.empty(n, dtype=np.int64)
if n == 0:
return pointers
pointers[0] = 0
if n > 1:
np.cumsum(
np.asarray(sequence_lengths[:-1], dtype=np.int64) * np.int64(itemsize),
out=pointers[1:],
)
return pointers


class IndexedDatasetBuilder:
Expand All @@ -192,30 +269,36 @@
self.data_file = open(bin_path, "wb") # noqa: SIM115
self.dtype = dtype

self.sequence_lengths = []
self.document_indices = [0]
# Accumulate per-input arrays and concatenate once at finalize. Extending a Python
# list with numpy scalars and re-boxing them at the end was a major hot spot.
self._seq_chunks: list[np.ndarray] = []
self._doc_chunks: list[np.ndarray] = [np.array([0], dtype=np.int64)]
self._cumulative_seq = 0

def add_index(self, path_prefix: str) -> None:
"""Add an entire IndexedDataset to the dataset

Args:
path_prefix (str): The index (.idx) and data (.bin) prefix
"""
# Concatenate index
sequence_lengths, document_indices, dtype = extract_index_contents(path_prefix + ".idx")
assert dtype == self.dtype # noqa: S101

offset = len(self.sequence_lengths)
self.sequence_lengths.extend(sequence_lengths)
self.document_indices.extend((offset + document_indices)[1:])

# Free up memory to make space for new indices
del sequence_lengths, document_indices
gc.collect()

# Concatenate data
with open(path_prefix + ".bin", "rb") as f:
shutil.copyfileobj(f, self.data_file)
offset = self._cumulative_seq
self._seq_chunks.append(sequence_lengths)
self._doc_chunks.append(np.int64(offset) + document_indices[1:].astype(np.int64, copy=False))
self._cumulative_seq += len(sequence_lengths)

src_path = path_prefix + ".bin"
size = os.path.getsize(src_path)
if size:
self.data_file.flush()
dst_offset = self.data_file.tell()
with open(src_path, "rb") as src:
_kernel_copy(src.fileno(), self.data_file.fileno(), dst_offset, size)
# Sync Python's logical position with the kernel-level file size so any
# subsequent buffered ops see the right offset.
self.data_file.seek(dst_offset + size)

def finalize(self, idx_path: str) -> None:
"""Clean up and write the index (.idx) file
Expand All @@ -224,8 +307,113 @@
idx_path (str): The path to the index file
"""
self.data_file.close()
sequence_lengths = (
np.concatenate(self._seq_chunks) if self._seq_chunks else np.empty(0, dtype=np.int32)
)
document_indices = np.concatenate(self._doc_chunks)
with _IndexWriter(idx_path, self.dtype) as writer:
writer.write(self.sequence_lengths, self.document_indices)
writer.write(sequence_lengths, document_indices)


def _discover_prefixes(input_dir: str) -> list[str]:
prefixes: set[str] = set()
for basename in os.listdir(input_dir):
prefix, ext = os.path.splitext(basename)

if ext not in {".bin", ".idx"}:
continue

if prefix in prefixes:
continue

if not os.path.isfile(os.path.join(input_dir, basename)):
continue

ext_pair = ".bin" if ext == ".idx" else ".idx"
assert os.path.isfile(os.path.join(input_dir, prefix + ext_pair)), ( # noqa: S101
f"ERROR: {ext_pair} file not provided for {os.path.join(input_dir, prefix)}"
)

prefixes.add(prefix)

if not prefixes:
msg = f"ERROR: No valid file prefix pairs found in {input_dir}"
raise ValueError(msg)

return sorted(prefixes)


def merge_file_prefixes(input_dir: str, output_prefix: str, workers: int = 1) -> None:
"""Merge all .bin/.idx prefix pairs in ``input_dir`` into a single pair at ``output_prefix``.

Args:
input_dir: Directory containing the .bin/.idx pairs to merge.
output_prefix: Output path prefix; produces ``<output_prefix>.bin`` and ``<output_prefix>.idx``.
workers: Threads for parallel index reads and bin copies. Default 1 is serial. Bin copies
use kernel zero-copy (``copy_file_range``), so ``workers > 1`` parallelizes those across
threads — useful on parallel/multi-stream storage. A single local SSD typically saturates
at 2-4 workers.
"""
prefixes = _discover_prefixes(input_dir)

if workers <= 1:
builder: IndexedDatasetBuilder | None = None
for prefix in prefixes:
if builder is None:
_, _, dtype = extract_index_contents(os.path.join(input_dir, prefix + ".idx"))
builder = IndexedDatasetBuilder(output_prefix + ".bin", dtype=dtype)
builder.add_index(os.path.join(input_dir, prefix))
builder.finalize(output_prefix + ".idx")
return

paths = [os.path.join(input_dir, p) for p in prefixes]

# Phase 1: read all idx files concurrently. Each thread independently mmaps its file,
# so no shared state.
with ThreadPoolExecutor(max_workers=workers) as ex:
idx_results = list(ex.map(lambda p: extract_index_contents(p + ".idx"), paths))

dtype = idx_results[0][2]
for _, _, d in idx_results[1:]:
assert d == dtype, f"dtype mismatch across input files: {d} vs {dtype}" # noqa: S101

seq_chunks = [r[0] for r in idx_results]
doc_chunks: list[np.ndarray] = [np.array([0], dtype=np.int64)]
cumulative_seq = 0
for seq, doc, _ in idx_results:
doc_chunks.append(np.int64(cumulative_seq) + doc[1:].astype(np.int64, copy=False))
cumulative_seq += len(seq)
sequence_lengths = np.concatenate(seq_chunks) if seq_chunks else np.empty(0, dtype=np.int32)
document_indices = np.concatenate(doc_chunks)

# Phase 2: pre-size the output bin and copy each input slice into its known offset
# in parallel.
bin_paths = [p + ".bin" for p in paths]
bin_sizes = [os.path.getsize(p) for p in bin_paths]
offsets: list[int] = []
running = 0
for s in bin_sizes:
offsets.append(running)
running += s
total_size = running

out_bin = output_prefix + ".bin"
with open(out_bin, "wb") as f:
if total_size:
os.ftruncate(f.fileno(), total_size)

if total_size:
with ThreadPoolExecutor(max_workers=workers) as ex:
list(
ex.map(
lambda args: _copy_bin_at_offset(args[0], out_bin, args[1], args[2]),
zip(bin_paths, offsets, bin_sizes, strict=True),
)
)

# Phase 3: write the merged idx.
with _IndexWriter(output_prefix + ".idx", dtype=dtype) as writer:
writer.write(sequence_lengths, document_indices)


def get_args() -> argparse.Namespace:
Expand All @@ -247,6 +435,18 @@
help="Path to merged output file prefix",
)

group = parser.add_argument_group(title="parallelism")
group.add_argument(
"--workers",
type=int,
default=1,
help=(
"Threads for parallel index read and bin copy. Default 1 (serial). On a single "
"local SSD, 2-4 typically saturates I/O; on parallel/multi-stream storage "
"(Lustre, multi-mount NFS), more pays off."
),
)

args = parser.parse_args()

assert os.path.isdir(args.input_dir), f"ERROR: {args.input_dir} is not a directory or does not exist" # noqa: S101
Comment thread
sarahyurick marked this conversation as resolved.
Expand All @@ -259,42 +459,6 @@
return args


def merge_file_prefixes(input_dir: str, output_prefix: str) -> None:
prefixes = set()
for basename in os.listdir(input_dir):
prefix, ext = os.path.splitext(basename)

if ext not in {".bin", ".idx"}:
continue

if prefix in prefixes:
continue

if not os.path.isfile(os.path.join(input_dir, basename)):
continue

ext_pair = ".bin" if ext == ".idx" else ".idx"
assert os.path.isfile(os.path.join(input_dir, prefix + ext_pair)), ( # noqa: S101
f"ERROR: {ext_pair} file not provided for {os.path.join(input_dir, prefix)}"
)

prefixes.add(prefix)

if not prefixes:
msg = f"ERROR: No valid file prefix pairs found in {input_dir}"
raise ValueError(msg)

builder = None
for prefix in sorted(prefixes):
if builder is None:
_, _, dtype = extract_index_contents(os.path.join(input_dir, prefix + ".idx"))
builder = IndexedDatasetBuilder(output_prefix + ".bin", dtype=dtype)

builder.add_index(os.path.join(input_dir, prefix))

builder.finalize(output_prefix + ".idx")


if __name__ == "__main__":
args = get_args()
merge_file_prefixes(args.input_dir, args.output_prefix)
merge_file_prefixes(args.input_dir, args.output_prefix, workers=args.workers)
Loading