diff --git a/src/zeroband/config.py b/src/zeroband/config.py index aa230d63..905643c6 100644 --- a/src/zeroband/config.py +++ b/src/zeroband/config.py @@ -20,7 +20,7 @@ class DataConfig(BaseConfig): num_workers: int = 1 max_train_samples: int | None = None max_eval_samples: int | None = None - dataset_ratio: str | None = None + dataset_ratio: str = "100" data_rank: int | None = None data_world_size: int | None = None reverse_data_files: bool = False diff --git a/src/zeroband/train.py b/src/zeroband/train.py index ea1837bc..fa42c3fd 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -632,6 +632,7 @@ def train(logger: Logger, config: Config, mpi_config: Optional[MPIConfig], devic num_inner_steps = config.diloco.inner_steps if config.diloco is not None else 1 + logger.info(f"Attempting to connect PCCL to {config.pccl.ccoip_host}") # initialize PCCL communicator = Communicator(config.pccl.ccoip_host, mpi_config.mpi_rank if mpi_config is not None else 0) communicator.connect(n_attempts=15)