diff --git a/src/zeroband/models/llama/__init__.py b/src/zeroband/models/llama/__init__.py index ea061c52..df6c245b 100644 --- a/src/zeroband/models/llama/__init__.py +++ b/src/zeroband/models/llama/__init__.py @@ -6,7 +6,6 @@ # # Llama 2 is licensed under the LLAMA 2 Community License, # Copyright (c) Meta Platforms, Inc. All Rights Reserved. -import torch from zeroband.config import Config from zeroband.models.llama.model import ModelArgs, Transformer @@ -95,8 +94,6 @@ def make_model( config: Config, vocab_size: int, - dtype: torch.dtype, - device: torch.device ) -> tuple[Transformer, ModelArgs]: """ Constructs a model instance according to the supplied configuration and target vocab size @@ -114,4 +111,4 @@ def make_model( model_config.max_seq_len = config.data.seq_length model_config.attn_fn = config.hardware.attn_fn - return Transformer(model_config, dtype=dtype, device=device), model_config + return Transformer(model_config), model_config diff --git a/src/zeroband/models/llama/model.py b/src/zeroband/models/llama/model.py index ea782551..e73c3fbf 100644 --- a/src/zeroband/models/llama/model.py +++ b/src/zeroband/models/llama/model.py @@ -170,7 +170,7 @@ def seqlens_to_docs_tensor(seqlens: list[torch.Tensor]) -> torch.Tensor: return torch.stack([torch.repeat_interleave(torch.arange(len(seq), device=seq.device), seq) for seq in seqlens]) -def create_block_mask_from_seqlens(seqlens: list[torch.Tensor], dtype: torch.dtype, device: torch.device) -> BlockMask: +def create_block_mask_from_seqlens(seqlens: list[torch.Tensor]) -> BlockMask: """Creates a block mask from a list of sequence lengths. Example: @@ -183,7 +183,7 @@ def create_block_mask_from_seqlens(seqlens: list[torch.Tensor], dtype: torch.dty [0 0 1 1 0] # Second token of doc 1 can see both tokens of doc 1 [0 0 0 0 1]] # Token of doc 2 can only see itself """ - docs = seqlens_to_docs_tensor(seqlens).to(dtype=dtype, device=device) + docs = seqlens_to_docs_tensor(seqlens).to("cuda") batch_size, max_seq_len = docs.shape def document_causal_mask(b, h, q_idx, kv_idx): @@ -197,7 +197,7 @@ def document_causal_mask(b, h, q_idx, kv_idx): None, max_seq_len, max_seq_len, - device=device.type, + device="cuda", _compile=True, BLOCK_SIZE=max_seq_len if max_seq_len < _DEFAULT_SPARSE_BLOCK_SIZE else _DEFAULT_SPARSE_BLOCK_SIZE, ) @@ -222,17 +222,17 @@ class Attention(nn.Module): """ - def __init__(self, model_args: ModelArgs, dtype: torch.dtype, device: torch.device): + def __init__(self, model_args: ModelArgs): super().__init__() self.n_heads = model_args.n_heads self.n_kv_heads = model_args.n_heads if model_args.n_kv_heads is None else model_args.n_kv_heads self.n_rep = self.n_heads // self.n_kv_heads self.head_dim = model_args.dim // model_args.n_heads - self.wq = nn.Linear(model_args.dim, model_args.n_heads * self.head_dim, bias=False, dtype=dtype, device=device) - self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False, dtype=dtype, device=device) - self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False, dtype=dtype, device=device) - self.wo = nn.Linear(model_args.n_heads * self.head_dim, model_args.dim, bias=False, dtype=dtype, device=device) + self.wq = nn.Linear(model_args.dim, model_args.n_heads * self.head_dim, bias=False) + self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear(model_args.n_heads * self.head_dim, model_args.dim, bias=False) self.attn_fn = model_args.attn_fn @@ -342,8 +342,6 @@ def __init__( hidden_dim: int, multiple_of: int, ffn_dim_multiplier: Optional[float], - dtype: torch.dtype, - device: torch.device ): super().__init__() hidden_dim = int(2 * hidden_dim / 3) @@ -352,9 +350,9 @@ def __init__( hidden_dim = int(ffn_dim_multiplier * hidden_dim) hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - self.w1 = nn.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device) - self.w2 = nn.Linear(hidden_dim, dim, bias=False, dtype=dtype, device=device) - self.w3 = nn.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) def forward(self, x: torch.Tensor, flop_counter: FlopCounter = FlopCounter()): flop_counter.track_linear(self.w1, x) @@ -400,23 +398,22 @@ class TransformerBlock(nn.Module): """ - def __init__(self, layer_id: int, model_args: ModelArgs, dtype: torch.dtype, device: torch.device): + def __init__(self, layer_id: int, model_args: ModelArgs): super().__init__() self.n_heads = model_args.n_heads self.dim = model_args.dim - self.attention = Attention(model_args, dtype=dtype, device=device) + self.attention = Attention(model_args) self.feed_forward = FeedForward( dim=model_args.dim, hidden_dim=4 * model_args.dim, multiple_of=model_args.multiple_of, ffn_dim_multiplier=model_args.ffn_dim_multiplier, - dtype=dtype, device=device ) self.layer_id = layer_id self.num_layers = model_args.n_layers - self.attention_norm = build_norm(model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps, dtype=dtype, device=device) - self.ffn_norm = build_norm(model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps, dtype=dtype, device=device) + self.attention_norm = build_norm(model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps) + self.ffn_norm = build_norm(model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps) if model_args.depth_init: self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5 @@ -484,7 +481,7 @@ class Transformer(nn.Module): """ - def __init__(self, model_args: ModelArgs, dtype: torch.dtype, device: torch.device): + def __init__(self, model_args: ModelArgs): super().__init__() self.model_args = model_args self.vocab_size = model_args.vocab_size @@ -499,15 +496,15 @@ def __init__(self, model_args: ModelArgs, dtype: torch.dtype, device: torch.devi # a seed checkpoint rather than calling init_weights, we need freqs_cis to be # initialized by the checkpoint, or we need to add a separate initializer for # just the non-persistent buffers that is called after loading checkpoints. - self.register_buffer("freqs_cis", self._precompute_freqs_cis(dtype=dtype, device=device), persistent=True) + self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True) self.layers = torch.nn.ModuleDict() for layer_id in range(model_args.n_layers): - self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args, dtype=dtype, device=device) + self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) - self.norm = build_norm(model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps, dtype=dtype, device=device) + self.norm = build_norm(model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps) - self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False, dtype=dtype, device=device) + self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False) self.init_weights() def init_weights(self): @@ -522,6 +519,8 @@ def init_weights(self): ``init_weights``. We only call it in the constructor of this ``Transformer`` root module to avoid reinitializing tensors. """ + with torch.device(self.freqs_cis.device): + self.freqs_cis = self._precompute_freqs_cis() if self.tok_embeddings is not None: nn.init.normal_(self.tok_embeddings.weight) for layer in self.layers.values(): @@ -540,14 +539,14 @@ def init_weights(self): b=cutoff_factor * final_out_std, ) - def _precompute_freqs_cis(self, dtype: torch.dtype, device: torch.device) -> torch.Tensor: + def _precompute_freqs_cis(self) -> torch.Tensor: return precompute_freqs_cis( self.model_args.dim // self.model_args.n_heads, # Need to compute until at least the max token limit for generation # (use 2x max sequence length to be safe) self.model_args.max_seq_len * 2, - self.model_args.rope_theta - ).to(dtype=dtype, device=device) + self.model_args.rope_theta, + ) def forward(self, tokens: torch.Tensor, block_mask: BlockMask | None = None, flop_counter: FlopCounter = FlopCounter()): """ @@ -576,6 +575,20 @@ def forward(self, tokens: torch.Tensor, block_mask: BlockMask | None = None, flo return output + @classmethod + def from_model_args(cls, model_args: ModelArgs) -> "Transformer": + """ + Initialize a Transformer model from a ModelArgs object. + + Args: + model_args (ModelArgs): Model configuration arguments. + + Returns: + Transformer: Transformer model. + + """ + return cls(model_args) + def count_parameters(self, exclude_embedding: bool = False) -> int: """ Counts the number of parameters. diff --git a/src/zeroband/models/norms.py b/src/zeroband/models/norms.py index 650296e4..f1febcf8 100644 --- a/src/zeroband/models/norms.py +++ b/src/zeroband/models/norms.py @@ -21,7 +21,7 @@ from torch.distributed.tensor.experimental import local_map -def build_norm(norm_type: str, dim: int, eps: float, dtype: torch.dtype, device: torch.device): +def build_norm(norm_type: str, dim: int, eps: float = 1e-6): """ Builds the specified normalization layer based on the norm_type. @@ -29,9 +29,7 @@ def build_norm(norm_type: str, dim: int, eps: float, dtype: torch.dtype, device: norm_type (str): The type of normalization layer to build. Supported types: layernorm, np_layernorm, rmsnorm, fused_rmsnorm dim (int): The dimension of the normalization layer. - eps (float, optional): The epsilon value for numerical stability. - dtype: The data type to use for the parameter tensor - device: The device to place the layer on + eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6. Returns: The built normalization layer. @@ -42,13 +40,13 @@ def build_norm(norm_type: str, dim: int, eps: float, dtype: torch.dtype, device: norm_type = norm_type.lower() # Normalize to lowercase if norm_type == "layernorm": - return nn.LayerNorm(dim, eps=eps, bias=False, dtype=dtype, device=device) + return nn.LayerNorm(dim, eps=eps, bias=False) elif norm_type == "np_layernorm": - return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False, dtype=dtype, device=device) + return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False) elif norm_type == "rmsnorm": - return RMSNorm(dim, eps=eps, dtype=dtype, device=device) + return RMSNorm(dim, eps=eps) elif norm_type == "fused_rmsnorm": - return FusedRMSNorm(dim, eps=eps, dtype=dtype, device=device) + return FusedRMSNorm(dim, eps=eps) else: raise NotImplementedError(f"Unknown norm_type: '{norm_type}'") @@ -59,13 +57,11 @@ class FusedRMSNorm(nn.Module): def __init__( self, dim: int, - eps: float, - dtype: torch.dtype, - device: torch.device + eps: float = 1e-6, ): super().__init__() self.eps = eps - self.weight = nn.Parameter(torch.ones(dim, dtype=dtype, device=device)) + self.weight = nn.Parameter(torch.ones(dim)) self.fused_rms_norm_fn = fused_rms_norm_fn def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -86,9 +82,7 @@ class RMSNorm(nn.Module): Args: dim (int): The dimension of the input tensor. - eps (float, optional): A small value added to the denominator for numerical stability. - dtype: The data type to use - device: The torch device to place the layer on + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. Attributes: eps (float): A small value added to the denominator for numerical stability. @@ -96,10 +90,10 @@ class RMSNorm(nn.Module): """ - def __init__(self, dim: int, eps: float, dtype: torch.dtype, device: torch.device): + def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.eps = eps - self.weight = nn.Parameter(torch.ones(dim, dtype=dtype, device=device)) + self.weight = nn.Parameter(torch.ones(dim)) def _norm(self, x: torch.Tensor): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 2db6311c..fa42c3fd 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -124,7 +124,6 @@ def run_inner_steps( num_param_scalars = model.count_parameters() for _inner_step in range(num_inner_steps): - #torch.cuda.memory._record_memory_history(max_entries=100000) train_profiler.start_session("inner_step") flop_counter = FlopCounter() @@ -234,9 +233,6 @@ def run_inner_steps( memory_profiler.step() train_profiler.end_session() - #torch.cuda.memory._dump_snapshot('snapshot.pickle') - #torch.cuda.memory._record_memory_history(enabled=None) - def compute_crc32(tensor: torch.Tensor) -> int: tensor_cpu = tensor.detach().cpu() @@ -522,7 +518,7 @@ def make_shared_state(outer_parameters: Dict[str, torch.nn.Parameter], return shared_state -def train(logger: Logger, config: Config, mpi_config: Optional[MPIConfig], dtype: torch.dtype, device: torch.device): +def train(logger: Logger, config: Config, mpi_config: Optional[MPIConfig], device: torch.device): grad_accum_steps = calc_gradient_accumulation_steps( config.train.batch_size, config.hardware.micro_batch_size, mpi_config ) @@ -546,8 +542,6 @@ def train(logger: Logger, config: Config, mpi_config: Optional[MPIConfig], dtype model, model_config = make_model( config, vocab_size=tokenizer_info.vocab_size, - dtype=dtype, - device=device, ) num_param_scalars = model.count_parameters() logger.info(f"Number of parameters: {num_param_scalars}") @@ -711,16 +705,16 @@ def train(logger: Logger, config: Config, mpi_config: Optional[MPIConfig], dtype continue local_world_size = communicator.get_attribute(Attribute.LOCAL_WORLD_SIZE) - #if local_world_size < 2: - # logger.info("Waiting for more workers to join...") - # time.sleep(1) - # continue + if local_world_size < 2: + logger.info("Waiting for more workers to join...") + time.sleep(1) + continue if topology_updated: logger.info("Optimizing Topology...") while True: try: - # communicator.optimize_topology() # may raise an error if it fails + communicator.optimize_topology() # may raise an error if it fails break except PCCLError as e: print(f"[Peer] OptimizeTopology failed => {e}. Retrying...") @@ -835,8 +829,7 @@ def main(): device = torch.device(f'cuda:{torch.cuda.current_device()}') logger.info(f"Using device: {torch.cuda.get_device_name(device)}") - dtype = torch.bfloat16 # TODO: MAKE CONFIGURABLE - train(logger, config, mpi_config, dtype, device) + train(logger, config, mpi_config, device) if __name__ == "__main__":