Skip to content

Commit 1f5fa59

Browse files
PAenugulacopybara-github
authored andcommitted
Add ElasticIterDatasetIterator to handle scaling up and down between checkpoints.
* Allows users to keep their pipelines elastic and restore from a checkpoint with a variable number of shards.
* Combines the Dataset and Iterator classes in one to allow changing the sharding configuration.
* Adds a dedicated checkpoint handler for saving to and restoring from Orbax.

PiperOrigin-RevId: 864910611
1 parent d47a7f0 commit 1f5fa59

9 files changed

Lines changed: 691 additions & 29 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ changes. Best viewed [here](https://google-grain.readthedocs.io/en/latest/change
1212
and advance a `grain.DatasetIterator` to the given produced element index.
1313
* Switches to multithreading instead of multiprocessing in
1414
`IterDataset.mp_prefetch` when free-threaded Python is detected.
15+
* Add `ElasticIterDatasetIterator` for scaling up and down the number of shards between checkpoints.
1516

1617
* Breaking changes:
1718
* Custom implementations of `RandomAccessDataSource` should accept `int`

grain/_src/python/checkpoint/BUILD

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,34 @@ py_library(
1717
srcs = ["handler.py"],
1818
srcs_version = "PY3",
1919
deps = [
20+
":elastic_checkpoint",
2021
"//grain/_src/core:sharding",
2122
"//grain/_src/python:data_loader",
2223
"//grain/_src/python/dataset",
24+
"//grain/_src/python/dataset:elastic_iterator",
25+
"@pypi//etils:pkg",
26+
],
27+
)
28+
29+
py_library(
30+
name = "elastic_checkpoint",
31+
srcs = ["elastic_checkpoint.py"],
32+
srcs_version = "PY3",
33+
deps = [
34+
"//grain/_src/python/dataset:elastic_iterator",
35+
"@pypi//etils:pkg",
36+
],
37+
)
38+
39+
py_test(
40+
name = "elastic_checkpoint_test",
41+
srcs = ["elastic_checkpoint_test.py"],
42+
srcs_version = "PY3",
43+
deps = [
44+
":elastic_checkpoint",
45+
"//grain/_src/core:sharding",
46+
"//grain/_src/python/dataset:elastic_iterator",
47+
"@abseil-py//absl/testing:absltest",
2348
"@pypi//etils:pkg",
2449
],
2550
)
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
"""This module provides checkpointing logic for ElasticIterDatasetIterator."""
2+
3+
import dataclasses
4+
import json
5+
from typing import Any, Optional, Sequence
6+
7+
from etils import epath
8+
from grain._src.python.dataset import elastic_iterator
9+
10+
11+
def _find_shard_file(
    directory: epath.Path,
    shard_index: int,
    total_num_shards: int,
) -> epath.Path:
  """Returns the single checkpoint file holding the state of one shard.

  Scans `directory` for an entry whose name ends with
  `shard_state_{shard_index}-of-{total_num_shards}.json`.

  Args:
    directory: Checkpoint directory to scan.
    shard_index: Index of the shard whose state file is wanted.
    total_num_shards: Total number of shards encoded in the file name.

  Returns:
    The path of the one matching file.

  Raises:
    ValueError: If no file, or more than one file, matches.
  """
  suffix = f"shard_state_{shard_index}-of-{total_num_shards}.json"
  matches = [
      path for path in directory.iterdir() if path.name.endswith(suffix)
  ]
  if not matches:
    # Include the total shard count: a checkpoint written with a different
    # number of shards is otherwise indistinguishable from a missing file.
    raise ValueError(
        f"No shard state files found in {directory} for shard {shard_index}"
        f" of {total_num_shards}"
    )
  if len(matches) > 1:
    raise ValueError(
        f"Multiple shard state files found in {directory} for shard"
        f" {shard_index} of {total_num_shards}:"
        f" {sorted(path.name for path in matches)}"
    )
  return matches[0]
30+
31+
32+
def save_elastic_iterator(
    directory: epath.Path,
    item: elastic_iterator.ElasticIterDatasetIterator,
):
  """Saves the given iterator to the checkpoint in `directory`.

  Writes one JSON file per entry of the `ds_iterator_states` mapping returned
  by `item.get_state()`. Each file is named
  `shard_state_{idx}-of-{total_num_shards}.json` and holds that entry plus a
  `total_num_shards` key.

  Args:
    directory: Checkpoint directory the shard state files are written to.
    item: Iterator whose state is saved. Its state is read, never modified.
  """
  state = item.get_state()
  total_num_shards = state["total_num_shards"]
  for idx, host_iterator_state in state["ds_iterator_states"].items():
    # Build a new dict instead of assigning into `host_iterator_state`:
    # `get_state()` may hand out dicts the iterator still references, and
    # saving a checkpoint must not alter them. Key order (and therefore the
    # serialized JSON) is the same as with in-place assignment.
    shard_state = {
        **host_iterator_state,
        "total_num_shards": total_num_shards,
    }
    filename = directory / f"shard_state_{idx}-of-{total_num_shards}.json"
    filename.write_text(json.dumps(shard_state, indent=4))
45+
46+
47+
def restore_elastic_iterator(
    directory: epath.Path,
    item: elastic_iterator.ElasticIterDatasetIterator,
):
  """Restores the given iterator from the checkpoint in `directory`.

  Visits the shard indices `shard_index`, `shard_index + shard_count`, ...
  that are below `item.total_num_shards`. For each one it loads the shard's
  JSON state file and passes the parsed state to
  `item.update_shard_iterator_state`.

  Args:
    directory: Checkpoint directory containing the shard state files.
    item: Iterator to restore; it is updated in place.

  Raises:
    ValueError: If a shard state file is missing or ambiguous, or if
      `item.shard_options.shard_count` is zero.
  """
  total_num_shards = item.total_num_shards
  shard_options = item.shard_options
  # `range` replaces manual index stepping and rejects a zero stride up front
  # instead of looping forever on the same shard.
  owned_shards = range(
      shard_options.shard_index, total_num_shards, shard_options.shard_count
  )
  for shard_index in owned_shards:
    shard_file = _find_shard_file(directory, shard_index, total_num_shards)
    shard_state = json.loads(shard_file.read_text())
    item.update_shard_iterator_state(shard_index, shard_state)
61+
62+
63+
class ElasticCheckpointHandler:
  """Orbax CheckpointHandler for `ElasticIterDatasetIterator`s.

  `save` and `restore` accept either one iterator or a sequence of iterators,
  passed directly as `item` or wrapped as `args.item`.
  """

  @staticmethod
  def _as_iterators(
      item: Optional[
          elastic_iterator.ElasticIterDatasetIterator
          | Sequence[elastic_iterator.ElasticIterDatasetIterator]
      ],
      args: Any,
  ) -> Sequence[elastic_iterator.ElasticIterDatasetIterator]:
    """Returns the iterators to process as a sequence.

    Args:
      item: A single iterator, a sequence of iterators, or `None`.
      args: Object exposing the iterator(s) as `.item`; only read when `item`
        is `None`.

    Returns:
      `item` (or `args.item`) unchanged if it is already a sequence, otherwise
      a one-element list wrapping the single iterator.
    """
    # Compare against `None` explicitly: truth-testing would evaluate the
    # iterator/sequence itself, so an empty or falsy `item` would silently be
    # replaced by `args.item`.
    if item is None:
      item = args.item
    if isinstance(item, elastic_iterator.ElasticIterDatasetIterator):
      return [item]
    return item

  def save(
      self,
      directory: epath.Path,
      item: Optional[
          elastic_iterator.ElasticIterDatasetIterator
          | Sequence[elastic_iterator.ElasticIterDatasetIterator]
      ] = None,
      args: Any = None,
  ):
    """Saves the given iterator(s) to the checkpoint in `directory`.

    Args:
      directory: Checkpoint directory the shard state files are written to.
      item: A single iterator or a sequence of iterators. When `None`,
        `args.item` is used instead.
      args: Object exposing the iterator(s) as `.item`.
    """
    # NOTE(review): shard file names depend only on the shard index and the
    # total shard count, so several iterators saved into the same directory
    # write to the same names -- confirm sequences hold disjoint shards.
    for iterator in self._as_iterators(item, args):
      save_elastic_iterator(directory, iterator)

  def restore(
      self,
      directory: epath.Path,
      item: Optional[
          elastic_iterator.ElasticIterDatasetIterator
          | Sequence[elastic_iterator.ElasticIterDatasetIterator]
      ] = None,
      args: Any = None,
  ) -> Any:
    """Restores the given iterator(s) from the checkpoint in `directory`.

    Args:
      directory: Checkpoint directory containing the shard state files.
      item: A single iterator or a sequence of iterators, restored in place.
        When `None`, `args.item` is used instead.
      args: Object exposing the iterator(s) as `.item`.

    Returns:
      The restored iterators as a sequence. A single iterator passed in comes
      back wrapped in a one-element list.
    """
    iterators = self._as_iterators(item, args)
    for iterator in iterators:
      restore_elastic_iterator(directory, iterator)
    return iterators

  # Required by interface but not supported by PyGrain checkpoints.
  def structure(self, directory: epath.Path) -> Any:
    del directory
    return None

  # Required by interface.

  def metadata(self, directory: epath.Path) -> Optional[Any]:
    del directory
    return None

  def finalize(self, directory: epath.Path):
    pass

  def close(self):
    pass

  @classmethod
  def typestr(cls):
    """Returns the fully qualified name of this handler class."""
    return f"{cls.__module__}.{cls.__qualname__}"
119+
120+
121+
try:
  # Register the handler to be used with the new checkpointing API if Orbax is
  # present.
  import orbax.checkpoint as ocp  # pylint:disable=g-import-not-at-top # pytype:disable=import-error

  # Args type registered with `ElasticCheckpointHandler` for saving.
  @ocp.args.register_with_handler(ElasticCheckpointHandler, for_save=True)  # pytype:disable=wrong-arg-types
  @dataclasses.dataclass
  class ElasticCheckpointSave(ocp.args.CheckpointArgs):
    # Iterator(s) to save; read by `ElasticCheckpointHandler.save` as
    # `args.item`.
    item: Any

  # Args type registered with `ElasticCheckpointHandler` for restoring.
  @ocp.args.register_with_handler(ElasticCheckpointHandler, for_restore=True)  # pytype:disable=wrong-arg-types
  @dataclasses.dataclass
  class ElasticCheckpointRestore(ocp.args.CheckpointArgs):
    # Iterator(s) to restore; read by `ElasticCheckpointHandler.restore` as
    # `args.item`.
    item: Any

except (ImportError, TypeError, AttributeError):
  # Registration is best-effort: when the import or the registration fails
  # with one of these errors, the two args classes are simply not defined.
  pass
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
"""Tests for elastic checkpoint."""
2+
3+
import json
4+
5+
from etils import epath
6+
from grain._src.core import sharding
7+
from grain._src.python.checkpoint import elastic_checkpoint
8+
from grain._src.python.dataset import elastic_iterator
9+
10+
from absl.testing import absltest
11+
12+
13+
class MockElasticIterDatasetIterator(
    elastic_iterator.ElasticIterDatasetIterator
):
  """Test double that records restored shard states instead of applying them.

  The base-class `__init__` is not called; only the attributes set below
  exist on an instance.
  """

  def __init__(self, shard_options, total_num_shards, states=None):
    # Filled by `update_shard_iterator_state`, keyed by shard index.
    self.updated_states = {}
    # Per-shard states reported by `get_state`.
    self._states = {} if states is None else states
    self._total_num_shards = total_num_shards
    self._shard_options = shard_options

  def get_state(self):
    """Returns the canned per-shard states and the total shard count."""
    snapshot = {
        "ds_iterator_states": self._states,
        "total_num_shards": self._total_num_shards,
    }
    return snapshot

  def update_shard_iterator_state(self, shard_index, state):
    """Records `state` for `shard_index` so tests can inspect it."""
    self.updated_states[shard_index] = state
31+
32+
33+
class ElasticCheckpointTest(absltest.TestCase):
  """Checks the on-disk round trip of elastic iterator checkpoints."""

  def test_save_and_restore_elastic_iterator(self):
    checkpoint_dir = epath.Path(self.create_tempdir().full_path)
    single_process = sharding.ShardOptions(shard_index=0, shard_count=1)
    source = MockElasticIterDatasetIterator(
        shard_options=single_process,
        total_num_shards=2,
        states={
            0: {"val": 0},
            1: {"val": 1},
        },
    )
    elastic_checkpoint.save_elastic_iterator(checkpoint_dir, source)

    # Every shard gets its own file holding its state plus the shard count.
    expected_files = (
        ("shard_state_0-of-2.json", 0),
        ("shard_state_1-of-2.json", 1),
    )
    for file_name, val in expected_files:
      shard_file = checkpoint_dir / file_name
      self.assertTrue(shard_file.exists())
      self.assertEqual(
          shard_file.read_text(),
          json.dumps({"val": val, "total_num_shards": 2}, indent=4),
      )

    # With a single process, both shards are restored by the same iterator.
    target = MockElasticIterDatasetIterator(
        shard_options=single_process, total_num_shards=2
    )
    elastic_checkpoint.restore_elastic_iterator(checkpoint_dir, target)
    self.assertEqual(
        target.updated_states,
        {
            0: {"val": 0, "total_num_shards": 2},
            1: {"val": 1, "total_num_shards": 2},
        },
    )

  def test_restore_elastic_iterator_with_multiple_processes(self):
    checkpoint_dir = epath.Path(self.create_tempdir().full_path)
    # Process 0
    process_0 = sharding.ShardOptions(shard_index=0, shard_count=2)
    source = MockElasticIterDatasetIterator(
        shard_options=process_0,
        total_num_shards=3,
        states={
            0: {"val": 0},
            1: {"val": 1},
            2: {"val": 2},
        },
    )
    # In reality save_elastic_iterator will be called in each process, but
    # get_state() should return all states, so we only need to call it once
    # to create checkpoint files.
    elastic_checkpoint.save_elastic_iterator(checkpoint_dir, source)

    # Check files are written
    for file_name in (
        "shard_state_0-of-3.json",
        "shard_state_1-of-3.json",
        "shard_state_2-of-3.json",
    ):
      self.assertTrue((checkpoint_dir / file_name).exists())

    # Restore for process 0, responsible for shards 0 and 2.
    restored_0 = MockElasticIterDatasetIterator(
        shard_options=process_0, total_num_shards=3
    )
    elastic_checkpoint.restore_elastic_iterator(checkpoint_dir, restored_0)
    self.assertEqual(
        restored_0.updated_states,
        {
            0: {"val": 0, "total_num_shards": 3},
            2: {"val": 2, "total_num_shards": 3},
        },
    )

    # Restore for process 1, responsible for shard 1.
    process_1 = sharding.ShardOptions(shard_index=1, shard_count=2)
    restored_1 = MockElasticIterDatasetIterator(
        shard_options=process_1, total_num_shards=3
    )
    elastic_checkpoint.restore_elastic_iterator(checkpoint_dir, restored_1)
    self.assertEqual(
        restored_1.updated_states,
        {
            1: {"val": 1, "total_num_shards": 3},
        },
    )
119+
120+
121+
if __name__ == "__main__":
122+
absltest.main()

grain/_src/python/checkpoint/handler.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,17 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""This module provides a PyGrain CheckpointHandler for integration with Orbax."""
15+
1516
import dataclasses
1617
import json
1718
from typing import Any, Optional, TypeVar
1819

1920
from etils import epath
2021
from grain._src.core import sharding
2122
from grain._src.python import data_loader
23+
from grain._src.python.checkpoint import elastic_checkpoint
2224
from grain._src.python.dataset import dataset
25+
from grain._src.python.dataset import elastic_iterator
2326

2427
IteratorType = TypeVar(
2528
"IteratorType", data_loader.DataLoaderIterator, dataset.DatasetIterator
@@ -41,6 +44,9 @@ def save(
4144
"""Saves the given iterator to the checkpoint in `directory`."""
4245
item = item or args.item # pytype:disable=attribute-error
4346
if isinstance(item, dataset.DatasetIterator):
47+
if isinstance(item, elastic_iterator.ElasticIterDatasetIterator):
48+
elastic_checkpoint.save_elastic_iterator(directory, item)
49+
return
4450
state = json.dumps(item.get_state(), indent=4)
4551
else:
4652
state = item.get_state().decode()
@@ -56,6 +62,9 @@ def restore(
5662
) -> IteratorType:
5763
"""Restores the given iterator from the checkpoint in `directory`."""
5864
item = item or args.item # pytype:disable=attribute-error
65+
if isinstance(item, elastic_iterator.ElasticIterDatasetIterator):
66+
elastic_checkpoint.restore_elastic_iterator(directory, item)
67+
return item
5968
process_index, process_count = sharding.get_process_index_and_count()
6069
filename = directory / f"process_{process_index}-of-{process_count}.json"
6170
if not filename.exists():
@@ -105,6 +114,5 @@ class CheckpointSave(ocp.args.CheckpointArgs):
105114
class CheckpointRestore(ocp.args.CheckpointArgs):
106115
item: Any
107116

108-
109117
except (ImportError, TypeError, AttributeError):
110118
pass

grain/_src/python/dataset/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,10 @@ py_test(
181181
":elastic_iterator",
182182
"//grain/_src/core:sharding",
183183
"//grain/_src/python:options",
184+
"//grain/_src/python/checkpoint:elastic_checkpoint",
185+
"//grain/_src/python/checkpoint:handler",
184186
"//grain/_src/python/testing:experimental",
187+
"//third_party/py/orbax/checkpoint",
185188
"@abseil-py//absl/testing:absltest",
186189
"@abseil-py//absl/testing:parameterized",
187190
"@pypi//numpy:pkg",

0 commit comments

Comments
 (0)