# Trackio tracker backend for roll/utils/tracking.py.
# Depends on the third-party `trackio` package, which is listed in
# requirements_common.txt alongside wandb and swanlab.


class TrackioTracker(BaseTracker):
    """Tracker backend that forwards metrics to a Trackio run."""

    def __init__(self, config: dict, **kwargs):
        """Start a Trackio run via ``trackio.init``.

        Args:
            config: Experiment configuration. Stored on the tracker and also
                passed to ``trackio.init`` as the run config.
            **kwargs: Optional settings looked up by name: ``project``,
                ``name``, ``group``, ``space_id``, ``dataset_id`` and ``tags``
                (each defaults to ``None``), plus ``auto_log_gpu`` (defaults
                to ``True``) and ``gpu_log_interval`` (defaults to ``2``).
                Any other keys are ignored.
        """
        self.config = config

        # NOTE(review): `project` falls back to None when the caller does not
        # supply it -- confirm trackio.init accepts a missing project name.
        project = kwargs.pop("project", None)
        name = kwargs.pop("name", None)
        group = kwargs.pop("group", None)
        space_id = kwargs.pop("space_id", None)
        dataset_id = kwargs.pop("dataset_id", None)
        tags = kwargs.pop("tags", None)

        auto_log_gpu = kwargs.pop("auto_log_gpu", True)
        gpu_log_interval = kwargs.pop("gpu_log_interval", 2)

        # Imported here rather than at module level, so trackio only has to be
        # installed when this tracker is actually instantiated.
        import trackio

        if space_id:
            logger.info(f"[Trackio] Using HF Space: {space_id}")
        if dataset_id:
            logger.info(f"[Trackio] Syncing to dataset: {dataset_id}")

        self.run = trackio.init(
            project=project,
            name=name,
            group=group,
            config=config,
            space_id=space_id,
            dataset_id=dataset_id,
            tags=tags,
            auto_log_gpu=auto_log_gpu,
            gpu_log_interval=gpu_log_interval,
        )

    @strip_at_tag_in_log
    def log(self, values: dict, step: Optional[int], **kwargs):
        """Log a dict of metrics to the run.

        Args:
            values: Metric name -> value mapping; passed through unmodified.
            step: Step to associate with the metrics, or ``None`` to let
                trackio pick the step itself.
            **kwargs: Accepted for interface compatibility; unused.
        """
        # The step goes through the dedicated ``step`` argument of the run's
        # ``log`` method. "step" is a reserved metric key in trackio, so it
        # must not be smuggled into the metrics dict itself.
        self.run.log(values, step=step)

    def log_system(self, values: dict):
        """Forward a dict of system metrics to the run's ``log_system``."""
        self.run.log_system(values)

    def finish(self):
        """Finish the underlying Trackio run."""
        self.run.finish()


tracker_registry["tensorboard"] = TensorBoardTracker
tracker_registry["wandb"] = WandbTracker
tracker_registry["stdout"] = StdoutTracker
tracker_registry["swanlab"] = SwanlabTracker
tracker_registry["trackio"] = TrackioTracker