Introduction
ArrayRecord is a new file format developed by Google to “achieve a new frontier of I/O efficiency” [1]. It has been positioned [2] as the successor to TFRecord [3] for storing and feeding data in large-scale machine learning pipelines. It is designed to accommodate three primary access patterns: sequential, batch, and random access. It solves a significant issue with the TFRecord format (the lack of truly random data access) while still providing high I/O performance.
TensorFlow (via TensorFlow Datasets [4]) and JAX (via Grain) have first-class support for ArrayRecord. You can also use ArrayRecord with PyTorch. It might be tempting to use TensorFlow Datasets to serve ArrayRecord data to PyTorch, but this extra dependency and complexity can be completely avoided. You only need the `array-record` Python package to achieve good performance. I'll show you how in this blog post.
In short, there are three key insights you need to build an efficient ArrayRecord pipeline:
- Use `group_size = 1` for training data and `group_size >> 1` for validation and test data.
- Use the PyTorch `Dataset`'s `__getitems__` method with ArrayRecord's batch operations to obtain a fully randomized training data stream with low I/O and OS overhead.
- Use PyTorch's `IterableDataset` to serve validation and test data using ArrayRecord's sequential operations for maximum I/O throughput.
My motivation for writing this post is the lack of resources for this specific use case (ArrayRecord + PyTorch). I figured out a way to do it with decent performance, so I'd like to share my experience. I'm not an expert in systems programming or I/O performance optimization, so I'd appreciate your feedback if I've gotten something wrong.
Quick Tour of ArrayRecord
(Figure: overview of the ArrayRecord file format. Source: [2])
- File structure: An ArrayRecord file is structured as a sequence of chunks that may be individually compressed. The index chunk at the end of the file allows random access to chunks without reading the entire file.
- Write-time configurations:
  - Group size (`group_size`): This determines the number of records stored in a chunk. We read one chunk at a time. Therefore, the group size needs to be 1 when you need fully random access. Otherwise, we'd have to read `group_size` records to get one record, wasting time reading the remaining `group_size - 1` records.
  - Compression: ArrayRecord supports multiple compression algorithms. The larger the chunk size, the better the compression ratio will be. That's why a large `group_size` is recommended if you only need sequential data access (e.g., for the validation and test datasets).
- Access patterns:
  - Random access: Reading non-contiguous records one at a time. Requires `group_size = 1` for best performance.
  - Sequential access: Reading contiguous records iteratively. A large `group_size` is recommended.
  - Batch access: Reading multiple records in a single function call. This is the recommended way to use ArrayRecord. It can provide performance improvements even for non-contiguous records (random reads), thanks to the underlying C++ thread pool. A minimal sketch of all three patterns follows this list.
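To make these patterns concrete, here is a minimal sketch of what each one looks like with the `array_record` Python APIs used throughout this post (the filename `data.array_record` is hypothetical):

```python
from array_record.python import array_record_data_source
from array_record.python.array_record_module import ArrayRecordReader

# Random and batch access: the data source maps global indices to records.
source = array_record_data_source.ArrayRecordDataSource(["data.array_record"])
single = source[42]                          # one random read
batch = source.__getitems__([7, 3, 11, 29])  # batched random read (C++ thread pool)

# Sequential access: the reader streams records in storage order.
reader = ArrayRecordReader("data.array_record")
records = reader.read_all()
reader.close()
```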
Please read the ArrayRecord documentation [5, 6] for more details.
Comparing ArrayRecord with TFRecord
| Feature | TFRecord | ArrayRecord |
|---|---|---|
| Underlying format | A sequence of binary `tf.train.Example` records (protocol buffers) | Riegeli with chunked records and an end-of-file index |
| Random access | No. Accessing record N requires scanning from the beginning of the file. | Yes. Direct index-based lookup in constant time. |
| Global shuffling | Difficult. Achieved via approximations (shuffling filenames + in-memory buffer shuffling). | Native. A sampler can generate randomized indices on the fly and fetch records in any order. |
| Parallel I/O | Achieved by sharding data into many small files and reading them in parallel. | Native. Multiple processes can read different chunks of the same file simultaneously. |
| Compression | gzip, snappy | zstd, brotli, snappy |
| PyTorch integration | Requires wrapping `tf.data.TFRecordDataset` in an iterable or using a custom loader. | Works as a standard map-style `Dataset`. |
| TensorFlow integration | A first-class citizen in `tf.data`. | Supported via `tfds.data_source()` when the dataset is stored as ArrayRecord. |
Please read this Google for Developers blog post [2] for more details.
Case Study: Image Classification
We’ll use a hypothetical scenario for training and validating an image classification model to provide a case study on using ArrayRecord with PyTorch. In this scenario, we are dealing with a mid-sized image classification dataset with 50k training images and 10k validation images. We ignore the test dataset here for simplicity (its treatment is exactly the same as the validation set).
Preparing ArrayRecord Files
For image data (e.g., JPEG, PNG), ArrayRecord's documentation [5, 6] recommends using the file's original binary form for optimal compression. Therefore, we serialize each example by concatenating `label` (an integer), `size_of_jpeg_bytes` (an integer), and `jpeg_bytes`. We instruct the `ArrayRecordWriter` instance to use uncompressed chunks because the image bytes are already compressed.
You can tune the number of examples that go into an ArrayRecord file, and the group size for the validation dataset. We use 5000 examples per ArrayRecord file (shard) for both training and validation. The group size for validation may not matter as much here because we are using uncompressed chunks. We use a group size of 500 without any tuning. This produces 10 ArrayRecord files for the training set and 2 files for the validation set.
Below are the core functions for writing ArrayRecord files. I’ll leave it to you to implement the data preprocessing code that fits your use case.
Note: Remember to use a group size of 1 for the training data.
```python
import struct
from pathlib import Path

from array_record.python.array_record_module import ArrayRecordWriter


def encode_record(label: int, jpeg_bytes: bytes) -> bytes:
    """Encode a label and JPEG payload into a length-prefixed binary record.

    Args:
        label: Integer class label for the image.
        jpeg_bytes: Raw JPEG image bytes.

    Returns:
        A binary blob suitable for ArrayRecord writing.
    """
    return struct.pack("<I", label) + struct.pack("<I", len(jpeg_bytes)) + jpeg_bytes


def _write_array_records(
    buffer: list[tuple[int, bytes]],
    shard_index: int,
    output_path: Path,
    group_size: int = 1,
) -> None:
    """Write a buffered batch of records into a single ArrayRecord shard.

    Args:
        buffer: Collection of ``(label, jpeg_bytes)`` tuples to write.
        shard_index: Zero-based shard index used in the output filename.
        output_path: Directory where the shard should be written.
        group_size: Number of records per chunk. Use ``1`` for
            random-access training data and a larger value (e.g. ``500``) for
            sequential validation data.
    """
    filepath = output_path / "{:03d}-{}.array_record".format(shard_index, len(buffer))
    writer = ArrayRecordWriter(
        str(filepath),
        options=f"group_size:{group_size},uncompressed",
    )
    for label, jpeg_bytes in buffer:
        writer.write(encode_record(label, jpeg_bytes))
    writer.close()
```
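A minimal driver sketch for the training set, assuming a hypothetical `iter_examples()` generator that yields `(label, jpeg_bytes)` tuples. It buffers 5,000 examples per shard and writes each shard with `group_size=1`:

```python
from pathlib import Path

SHARD_SIZE = 5000  # examples per ArrayRecord file, as discussed above

buffer: list[tuple[int, bytes]] = []
shard_index = 0
for label, jpeg_bytes in iter_examples():  # hypothetical: yields (label, jpeg_bytes)
    buffer.append((label, jpeg_bytes))
    if len(buffer) == SHARD_SIZE:
        _write_array_records(buffer, shard_index, Path("train"), group_size=1)
        buffer = []
        shard_index += 1
if buffer:
    # Flush the final, possibly smaller shard.
    _write_array_records(buffer, shard_index, Path("train"), group_size=1)
```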
Decoding Records
Before we can start implementing the training and validation dataset classes and data loader instances, we need a function for decoding a serialized record. Below is a sample implementation. Note that we need to segment the binary data manually; that's why `jpeg_len` is included (it probably isn't strictly required, but it helps us validate the data). The `label` and `jpeg_len` integers take four bytes each (they are 32-bit unsigned integers); the rest is the JPEG bytes. The JPEG bytes are decoded using OpenCV and converted into an RGB-formatted NumPy array.
```python
import struct

import cv2
import numpy as np


def decode_record(data: bytes) -> tuple[int, np.ndarray]:
    """Decode a length-prefixed binary record into label and image array.

    Args:
        data: Raw bytes from an ArrayRecord shard.

    Returns:
        A tuple of ``(label, image)`` where image is a NumPy array in
        HWC / RGB format suitable for ``albumentations`` transforms.
    """
    label = struct.unpack("<I", data[:4])[0]
    jpeg_len = struct.unpack("<I", data[4:8])[0]
    jpeg_bytes = data[8 : 8 + jpeg_len]
    arr = np.frombuffer(jpeg_bytes, dtype=np.uint8)
    image = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    if image is None:
        raise ValueError("Failed to decode JPEG bytes")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return label, image
```
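A quick sanity check that the encode/decode pair round-trips correctly (assuming some JPEG file `sample.jpg` is available):

```python
from pathlib import Path

jpeg_bytes = Path("sample.jpg").read_bytes()
record = encode_record(7, jpeg_bytes)
label, image = decode_record(record)
assert label == 7
assert image.ndim == 3  # HWC, RGB
```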
Training Dataset and DataLoader
This is a map-style PyTorch dataset class that supports image augmentations using the Albumentations library [7]. A few things to pay attention to:
- The `__getitem__` method is implemented for compatibility reasons. However, it should not be used in most cases. That's why I made it emit a `UserWarning` every time it is called. If the caller really needs to use it, they can manually suppress the warning. This is a design choice I made; you can remove the warning without any functional change if you disagree.
- The `__getitems__` method simply calls the method with the same name on the `ArrayRecordDataSource` instance. The `array_record` library handles the concurrent I/O for us.
- The `_process_record` method used by `__getitems__` is single-threaded, so it may become the true bottleneck. However, we usually use multiple workers in the DataLoader, so other workers can continue reading records while one worker decodes the bytes that have been read.
- This setup may be tuned to achieve even higher throughput. For example, it may help to use only a couple of workers in the DataLoader and spawn multiple processes in each worker to decode and augment images in parallel. It really depends on your data shapes and your hardware. Do not treat this implementation as a gold standard.
```python
import warnings
from typing import final, override

import torch
import albumentations as A
from array_record.python import array_record_data_source


@final
class ArrayRecordImageDataset(
    torch.utils.data.Dataset[tuple[torch.Tensor, torch.Tensor]]
):
    """Map-style PyTorch Dataset over ArrayRecord shards."""

    def __init__(
        self,
        file_paths: list[str],
        transforms: A.Compose | None = None,
    ):
        """Initialize the dataset.

        Args:
            file_paths: List of ArrayRecord shard paths.
            transforms: Optional albumentations compose pipeline.
        """
        self.data_source = array_record_data_source.ArrayRecordDataSource(file_paths)
        self.transforms = transforms

    def __len__(self) -> int:
        return len(self.data_source)

    def _process_record(self, record_bytes: bytes) -> tuple[torch.Tensor, torch.Tensor]:
        """Decode and transform a single raw record.

        Args:
            record_bytes: Raw bytes from an ArrayRecord shard.

        Returns:
            A tuple of ``(image_tensor, label_tensor)``.
        """
        label, image = decode_record(record_bytes)
        if self.transforms is not None:
            image = self.transforms(image=image)["image"]
        image = image.transpose(2, 0, 1)  # HWC -> CHW
        return torch.from_numpy(image), torch.tensor(label)

    @override
    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        warnings.warn(
            (
                "Single-index __getitem__ is slow for batched loading. "
                "Use a DataLoader with a non-None batch_size to trigger "
                "__getitems__ for parallel batched I/O."
            ),
            UserWarning,
            stacklevel=2,
        )
        return self._process_record(self.data_source[idx])

    def __getitems__(self, indices: list[int]) -> list[tuple[torch.Tensor, torch.Tensor]]:
        """Fetch multiple records in a single batched I/O call.

        PyTorch's DataLoader calls this instead of looping over
        :meth:`__getitem__` when ``batch_size`` is not None. By delegating to
        :meth:`ArrayRecordDataSource.__getitems__`, we enable parallel reads
        across multiple ArrayRecord shards.

        Args:
            indices: List of dataset indices to fetch.

        Returns:
            A list of ``(image_tensor, label_tensor)`` tuples in the same
            order as ``indices``.
        """
        records = self.data_source.__getitems__(indices)
        return [self._process_record(record) for record in records]
```
You can pass the `ArrayRecordImageDataset` instance to the `DataLoader` class like any other `Dataset` instance:
```python
dataset = ArrayRecordImageDataset(
    train_files,
    transforms=train_transforms,
)
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
    drop_last=True,
)
```
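With `shuffle=True`, the sampler feeds randomized index batches into `__getitems__`, so each batch becomes a single parallel ArrayRecord read. A quick smoke test of the resulting loader (shapes depend on your transforms):

```python
images, labels = next(iter(data_loader))
# images: (batch_size, 3, H, W) tensor; labels: (batch_size,) tensor
print(images.shape, labels.shape)
```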
Validation Dataset and DataLoader
As mentioned before, an `IterableDataset` is a better fit for the validation pipeline because it reads the data sequentially. Below is a sample implementation of such a class that reads ArrayRecord files in order.
Pay attention to these points:
- We distribute the workload across multiple workers (if used) at the file level: `files_to_read = self.file_paths[worker_info.id :: worker_info.num_workers]`. This means we should tune the shard size (the number of records in each file) so that the number of shards/files is a multiple of the number of workers. This ensures the workload is evenly distributed for the best performance. (A tiny sketch of this assignment follows this list.)
- The class supports two modes of reading via the `high_memory_mode` flag. When `high_memory_mode` is True, it uses the `ArrayRecordReader.read_all()` method to read all records in the file at once; when it is False, it uses the `ArrayRecordReader.read()` method to read the records in batches. The former should theoretically be faster. However, the ArrayRecord documentation shows some contradictory empirical results [6], so always benchmark and tune to find the optimal configuration for your specific use case.
- The default `reader_options` value of `"readahead_buffer_size:16M,max_parallelism:8"` is not tuned. It may make sense to use a different set of options when using `ArrayRecordReader.read_all()`.
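To see the round-robin file assignment from the first bullet in isolation, here is a tiny sketch with hypothetical filenames (two shards and two workers give exactly one file per worker):

```python
file_paths = ["valid/000-5000.array_record", "valid/001-5000.array_record"]
num_workers = 2
for worker_id in range(num_workers):
    print(worker_id, file_paths[worker_id::num_workers])
# 0 ['valid/000-5000.array_record']
# 1 ['valid/001-5000.array_record']
```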
```python
from typing import final, override

import torch
import albumentations as A
from array_record.python.array_record_module import ArrayRecordReader

_DEFAULT_BATCH_READ_SIZE = 500


@final
class SequentialArrayRecordDataset(
    torch.utils.data.IterableDataset[tuple[torch.Tensor, torch.Tensor]]
):
    """Iterable PyTorch Dataset for sequential ArrayRecord reading.

    Use this for validation data written with ``group_size > 1``. It bypasses
    :class:`~array_record_data_source.ArrayRecordDataSource` (which is
    optimized for random access and requires ``group_size:1``) and reads
    directly with :class:`~array_record_module.ArrayRecordReader` instead.
    """

    def __init__(
        self,
        file_paths: list[str],
        transforms: A.Compose | None = None,
        # TODO: figure out the optimal options for this use case
        reader_options: str = "readahead_buffer_size:16M,max_parallelism:8",
        high_memory_mode: bool = True,
    ):
        """Initialize the dataset.

        Args:
            file_paths: List of ArrayRecord shard paths.
            transforms: Optional albumentations compose pipeline.
            reader_options: Options passed to ``ArrayRecordReader``. Tune this
                for sequential throughput on your storage backend.
            high_memory_mode: If True, load all records into memory at once
                via ``read_all``; if False, read records in batches of
                ``_DEFAULT_BATCH_READ_SIZE`` sequentially.
        """
        self.file_paths: list[str] = [str(p) for p in file_paths]
        self.transforms: A.Compose | None = transforms
        self.reader_options: str = reader_options
        self.high_memory_mode: bool = high_memory_mode
        # Pre-compute record counts so __len__ works for progress bars.
        self._lengths: list[int] = []
        for fp in self.file_paths:
            reader = ArrayRecordReader(fp)
            self._lengths.append(reader.num_records())
            reader.close()
        self._total: int = sum(self._lengths)

    def __len__(self) -> int:
        return self._total

    def _process_record(self, record_bytes: bytes) -> tuple[torch.Tensor, torch.Tensor]:
        """Decode and transform a single raw record.

        Args:
            record_bytes: Raw bytes from an ArrayRecord shard.

        Returns:
            A tuple of ``(image_tensor, label_tensor)``.
        """
        label, image = decode_record(record_bytes)
        if self.transforms is not None:
            image = self.transforms(image=image)["image"]
        image = image.transpose(2, 0, 1)  # HWC -> CHW
        return torch.from_numpy(image), torch.tensor(label)

    @override
    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process loading
            files_to_read = self.file_paths
        else:
            # Multi-worker: shard by file round-robin
            files_to_read = self.file_paths[worker_info.id :: worker_info.num_workers]
        for fp in files_to_read:
            reader = ArrayRecordReader(fp, options=self.reader_options)
            try:
                if self.high_memory_mode:
                    for record in reader.read_all():
                        yield self._process_record(record)
                else:
                    # Note: reader.read() on a single record often hangs the process
                    # for unknown reasons. Use batch reads for better performance, but
                    # not read_all(), to reduce the memory footprint. The group size
                    # should be a divisor of _DEFAULT_BATCH_READ_SIZE for the best
                    # performance.
                    for idx in range(0, reader.num_records(), _DEFAULT_BATCH_READ_SIZE):
                        records = reader.read(
                            idx, min(reader.num_records(), idx + _DEFAULT_BATCH_READ_SIZE)
                        )
                        for record in records:
                            yield self._process_record(record)
            finally:
                reader.close()
```
Similar to the map-style `Dataset` instances, the `IterableDataset` instance can be passed directly to `DataLoader`. One caveat is that you need to set the number of workers carefully: using too many workers can turn high-throughput sequential access into what is essentially low-throughput random access. Experiments show that setting `num_workers` to 2 works best on my local setup (NVMe SSD + a single GPU). Your mileage may vary.
```python
dataset = SequentialArrayRecordDataset(
    valid_files,
    transforms=valid_transforms,
    high_memory_mode=high_memory_mode,
)
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    # Use a fixed number of workers to manually tune the sequential data read operations.
    # Too many workers create random access patterns; a single worker creates a gap between files.
    num_workers=2,
    pin_memory=True,
    drop_last=False,
)
```
Bonus: Random Access Performance of ArrayRecord vs. Raw Files
For sequential access, reading a few ArrayRecord files is undoubtedly faster than reading hundreds of thousands of small image files. However, this is not as obvious in random access scenarios.
In this section, I'll provide some references that offer empirical evidence and theoretical explanations for the claim that ArrayRecord provides better random access than a large collection of small files. Please note that, based on my research, it's unclear whether ArrayRecord always performs better. However, we can say that it performs at least on par with reading many small files and should perform much better in some scenarios.
Per-file Overhead
The article “I built a 2x faster lexer, then discovered I/O was the real bottleneck” [8] shows that reading 104,000 individual files is 42.85 times slower than reading them in tar.gz archives. The archives are 6.68 times smaller, so the actual speedup in I/O is about 6 times.
This mainly reinforces our already established belief that reading a small number of large files is much faster than reading a large number of small files of the same total size. Nonetheless, the article demonstrates that reading small files triggers more syscalls than reading large files, which contributes to the overall overhead.
This is the article’s explanation of the I/O time gap (which is only part of the story, as its Addendum section points out):
Additionally, reading 1,351 files sequentially is far more cache-friendly than reading 104,000 files scattered across the filesystem. The OS can prefetch effectively, the SSD can batch operations, and the page cache stays warm.
Remote Storage Per-file Latency
On a network filesystem, every filesystem operation requires a network round-trip. Unlike object stores, where the issue is per-object GET latency, NFS adds round-trip overhead to each operation individually: `stat()`, `open()`, `read()`, and `close()`. The “Amazon EFS performance tips” article [9] indicates that:
File open, close, and metadata operations generally cannot be made asynchronously or through a pipeline. When reading or writing small files, the two additional round trips are significant. Each round trip (file open, file close) can take as much time as reading or writing megabytes of bulk data.
The “MinIO — The Small Files Problem” article [10] also points out that:
Querying many small files incurs overhead to read metadata, conduct a non-contiguous disk seek, open the file, close the file and repeat. The overhead is only a few milliseconds per file, but when you’re querying thousands, millions or even billions of files, those milliseconds add up.
Back-of-the-envelope for the example 50k-image training dataset on NFS:
| Metric | Raw files (50K) | ArrayRecord (10 files) |
|---|---|---|
| File opens per epoch | 50,000 | 10 |
| Total round-trips (stat + open + read + close) | 200,000 | 40 |
| Time @ 1ms per round-trip | 200 seconds | 0.040 seconds |
| Time @ 5ms per round-trip | 1,000 seconds | 0.200 seconds |
Why ArrayRecord doesn't suffer the same way: once the ArrayRecord files are opened (10 `open()` calls total), reading a record at a random offset is a single `read()` operation (one round-trip), because the internal index maps each record ID to its byte offset directly. There's no per-record `stat()`/`open()`/`close()` sequence.
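The arithmetic behind the table, as a back-of-the-envelope sketch (assuming four round-trips per raw-file read and counting only the per-file overhead on the ArrayRecord side):

```python
ROUND_TRIPS_PER_FILE = 4  # stat + open + read + close

for rtt_ms in (1, 5):
    raw_seconds = 50_000 * ROUND_TRIPS_PER_FILE * rtt_ms / 1000
    ar_seconds = 10 * ROUND_TRIPS_PER_FILE * rtt_ms / 1000
    print(f"{rtt_ms} ms RTT: raw files {raw_seconds:.0f} s, ArrayRecord {ar_seconds:.3f} s")
```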
Random Access Benchmark
The official Performance Guide [6] provides a random-access benchmark that indicates a 40- to 100-fold speedup from individual read operations to batch read operations, depending on the compression algorithm used.
The speedup can be attributed to the internal C++ thread pool employed by the `array_record` package. Even if we tried to match it in Python using multi-threading (restricted by the GIL) or multi-processing (higher overhead), the result would still likely be much slower than the high-performance, low-overhead C++ implementation.
| Method | Throughput (QPS) |
|---|---|
| Individual `read()` | ~5,000 |
| Batch `read(indices)` | ~200,000 – ~500,000 |
This is not a direct benchmark comparing random reads of ArrayRecord files against random reads of small image files, but it provides a strong reason to believe that the former will be faster than the latter.
Conclusion
Ultimately, the actual speedup largely depends on the characteristics of your data (e.g., the size of each record and its compression ratio) and your hardware.
The main advantage of using ArrayRecord for machine learning is its robustness against hardware changes. You can expect the same pipeline that works well locally will also work in a cloud environment. There’s no need for environment-specific optimization other than tuning some hyperparameters (e.g., group size, number of workers).
Another advantage of ArrayRecord is its portability. There's no more need to bundle thousands or even millions of small files into archives; with ArrayRecord, you just need to share a handful of files. The label data is also embedded, so there's no need to ship label files separately and no risk of discrepancy between the label files and the data files.
I’m pretty satisfied with this setup for using ArrayRecord with PyTorch and plan to apply it to different tasks in the near future. I’d appreciate it if you could share your experience with ArrayRecord in the comments section or via social media.
References
1. (GitHub) google/array_record
2. (Google for Developers) Building High-Performance Data Pipelines with Grain and ArrayRecord
3. (TensorFlow) TFRecord and tf.train.Example
4. (TensorFlow Datasets) TFDS for Jax and PyTorch
5. ArrayRecord: Core Concepts
6. ArrayRecord: Performance Guide
7. Albumentations: fast and flexible image augmentations
8. I built a 2x faster lexer, then discovered I/O was the real bottleneck
9. Amazon EFS performance tips
10. MinIO — The Small Files Problem
AI Use Disclosure
I used AI tools to conduct research, produce Markdown tables, and revise my writing (primarily for grammar and lexical correctness) in this post. However, I wrote most of the content myself; it was not generated with prompts.