From 8d05958bd616c0fa42b49c8d929733c1fabacd07 Mon Sep 17 00:00:00 2001 From: skishore Date: Fri, 27 Mar 2026 18:06:13 +0000 Subject: [PATCH 1/2] fix lint errors --- .github/ISSUE_TEMPLATE/bug_report.md | 2 +- .github/workflows/rocm-ci.yml | 6 +- .gitignore | 2 +- Makefile | 4 +- README.md | 64 ++-- apex/RNN/RNNBackend.py | 58 ++-- apex/RNN/cells.py | 14 +- apex/RNN/models.py | 2 +- apex/amp/_process_optimizer.py | 2 +- apex/amp/frontend.py | 2 +- apex/amp/lists/functional_overrides.py | 2 +- apex/contrib/bottleneck/bottleneck.py | 6 +- apex/contrib/conv_bias_relu/__init__.py | 2 +- apex/contrib/csrc/bottleneck/bottleneck.cpp | 4 +- .../csrc/conv_bias_relu/conv_bias_relu.cpp | 16 +- .../conv_bias_relu/conv_bias_relu_rocm.cpp | 12 +- apex/contrib/csrc/fmha/fmha_api.cpp | 12 +- apex/contrib/csrc/fmha/src/fmha.h | 10 +- apex/contrib/csrc/fmha/src/fmha/gemm.h | 4 +- apex/contrib/csrc/fmha/src/fmha/gmem_tile.h | 10 +- .../csrc/fmha/src/fmha/kernel_traits.h | 4 +- apex/contrib/csrc/fmha/src/fmha/mask.h | 4 +- apex/contrib/csrc/fmha/src/fmha/smem_tile.h | 50 +-- apex/contrib/csrc/fmha/src/fmha/softmax.h | 10 +- apex/contrib/csrc/fmha/src/fmha/utils.h | 20 +- .../src/fmha_dgrad_fp16_128_64_kernel.sm80.cu | 4 +- .../src/fmha_dgrad_fp16_256_64_kernel.sm80.cu | 4 +- .../src/fmha_dgrad_fp16_384_64_kernel.sm80.cu | 4 +- .../src/fmha_dgrad_fp16_512_64_kernel.sm80.cu | 6 +- .../fmha/src/fmha_dgrad_kernel_1xN_reload.h | 4 +- .../src/fmha_dgrad_kernel_1xN_reload_nl.h | 26 +- .../src/fmha_fprop_fp16_128_64_kernel.sm80.cu | 22 +- .../src/fmha_fprop_fp16_256_64_kernel.sm80.cu | 22 +- .../src/fmha_fprop_fp16_384_64_kernel.sm80.cu | 22 +- .../src/fmha_fprop_fp16_512_64_kernel.sm80.cu | 24 +- .../csrc/fmha/src/fmha_fprop_kernel_1xN.h | 20 +- apex/contrib/csrc/fmha/src/fmha_kernel.h | 14 +- .../csrc/fmha/src/fmha_noloop_reduce.cu | 6 +- apex/contrib/csrc/fmha/src/fmha_utils.h | 4 +- apex/contrib/csrc/groupbn/batch_norm.cu | 8 +- apex/contrib/csrc/groupbn/batch_norm.h | 2 +- .../csrc/groupbn/batch_norm_add_relu.cu | 8 +- .../csrc/index_mul_2d/index_mul_2d_cuda.cpp | 8 +- .../index_mul_2d/index_mul_2d_cuda_kernel.cu | 110 +++--- apex/contrib/csrc/layer_norm/ln.h | 2 +- apex/contrib/csrc/layer_norm/ln_api.cpp | 14 +- .../csrc/layer_norm/ln_bwd_kernels.cuh | 4 +- .../layer_norm/ln_bwd_semi_cuda_kernel.cu | 14 +- .../csrc/layer_norm/ln_fwd_cuda_kernel.cu | 14 +- .../csrc/layer_norm/ln_fwd_kernels.cuh | 2 +- .../csrc/layer_norm/ln_kernel_traits.h | 26 +- apex/contrib/csrc/layer_norm/ln_utils.cuh | 24 +- .../encdec_multihead_attn_cuda.cu | 310 ++++++++--------- .../encdec_multihead_attn_norm_add_cuda.cu | 318 +++++++++--------- .../multihead_attn_frontend.cpp | 4 +- ..._multihead_attn_bias_additive_mask_cuda.cu | 264 +++++++-------- .../self_multihead_attn_bias_cuda.cu | 240 ++++++------- .../self_multihead_attn_cuda.cu | 260 +++++++------- .../self_multihead_attn_norm_add_cuda.cu | 256 +++++++------- apex/contrib/csrc/multihead_attn/softmax.cuh | 2 +- .../multihead_attn/strided_batched_gemm.cuh | 34 +- apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh | 6 +- .../csrc/optimizers/fused_adam_cuda_kernel.cu | 4 +- .../csrc/optimizers/fused_lamb_cuda_kernel.cu | 2 +- .../multi_tensor_distopt_lamb_kernel.cu | 2 +- .../csrc/peer_memory/peer_memory_cuda.cu | 10 +- .../csrc/peer_memory/peer_memory_cuda.cuh | 2 +- .../csrc/transducer/transducer_joint.cpp | 10 +- .../transducer/transducer_joint_kernel.cu | 190 +++++------ .../csrc/transducer/transducer_loss.cpp | 10 +- .../csrc/transducer/transducer_loss_kernel.cu | 210 ++++++------ .../perf_test_multihead_attn.py | 32 +- apex/contrib/fmha/fmha.py | 6 +- apex/contrib/groupbn/batch_norm.py | 4 +- apex/contrib/index_mul_2d/index_mul_2d.py | 12 +- apex/contrib/multihead_attn/README.md | 2 +- .../optimizers/distributed_fused_adam.py | 4 +- apex/contrib/optimizers/fp16_optimizer.py | 2 +- apex/contrib/optimizers/fused_sgd.py | 16 +- .../peer_memory/peer_halo_exchanger_1d.py | 2 +- apex/contrib/sparsity/README.md | 8 +- apex/contrib/sparsity/asp.py | 14 +- .../permutation_search_kernels.cu | 4 +- .../permutation_utilities.py | 8 +- apex/contrib/sparsity/sparse_masklib.py | 6 +- .../conv_bias_relu/test_conv_bias_relu.py | 12 +- apex/contrib/test/fmha/test_fmha.py | 32 +- .../test/groupbn/test_groupbn_channel_last.py | 4 +- .../test/index_mul_2d/test_index_mul_2d.py | 12 +- .../test_encdec_multihead_attn.py | 74 ++-- .../test_encdec_multihead_attn_norm_add.py | 54 +-- .../test_fast_self_multihead_attn_bias.py | 56 +-- .../multihead_attn/test_mha_fused_softmax.py | 12 +- .../test_self_multihead_attn.py | 2 +- .../test_self_multihead_attn_norm_add.py | 48 +-- .../optimizers/test_distributed_fused_lamb.py | 8 +- apex/contrib/test/test_label_smoothing.py | 8 +- .../test/transducer/test_transducer_joint.py | 44 +-- .../test/transducer/test_transducer_loss.py | 38 +-- .../contrib/test/transducer/transducer_ref.py | 24 +- apex/contrib/transducer/transducer.py | 74 ++-- apex/fp16_utils/README.md | 2 +- apex/fp16_utils/fp16_optimizer.py | 96 +++--- apex/fp16_utils/fp16util.py | 4 +- apex/fp16_utils/loss_scaler.py | 24 +- apex/fused_dense/fused_dense.py | 14 +- apex/optimizers/fused_lars.py | 4 +- apex/optimizers/fused_mixed_precision_lamb.py | 12 +- apex/parallel/LARC.py | 6 +- apex/transformer/functional/fused_rope.py | 28 +- compatibility/fused_layer_norm_cuda.py | 8 +- compatibility/mlp_cuda.py | 8 +- csrc/fused_dense_cuda.cu | 86 ++--- csrc/layer_norm_cuda_kernel.cu | 14 +- csrc/megatron/fused_bias_swiglu_cuda.cu | 18 +- ...d_weight_gradient_dense_16bit_prec_cuda.cu | 8 +- .../fused_weight_gradient_dense_cuda.cu | 2 +- csrc/megatron/generic_scaled_masked_softmax.h | 58 ++-- .../generic_scaled_masked_softmax_cuda.cu | 16 +- csrc/megatron/scaled_masked_softmax.h | 66 ++-- csrc/megatron/scaled_masked_softmax_cpu.cpp | 16 +- csrc/megatron/scaled_masked_softmax_cuda.cu | 16 +- csrc/megatron/scaled_softmax_cpu.cpp | 18 +- csrc/megatron/scaled_softmax_cuda.cu | 16 +- .../scaled_upper_triang_masked_softmax.h | 66 ++-- ...scaled_upper_triang_masked_softmax_cpu.cpp | 16 +- ...scaled_upper_triang_masked_softmax_cuda.cu | 20 +- csrc/mlp_cuda.cu | 2 +- csrc/multi_tensor_l2norm_kernel.cu | 6 +- csrc/multi_tensor_lars.cu | 6 +- csrc/syncbn.cpp | 6 +- csrc/welford.cu | 2 +- docs/source/amp.rst | 10 +- docs/source/fp16_utils.rst | 8 +- docs/source/index.rst | 2 +- examples/dcgan/main_amp.py | 2 +- examples/imagenet/README.md | 2 +- examples/simple/distributed/README.md | 2 +- op_builder/__init__.py | 2 +- op_builder/all_ops.py | 2 +- op_builder/amp_C.py | 2 +- op_builder/apex_C.py | 2 +- op_builder/builder.py | 6 +- op_builder/distributed_adam.py | 2 +- op_builder/distributed_lamb.py | 2 +- op_builder/fast_multihead_attn.py | 4 +- op_builder/focal_loss.py | 2 +- op_builder/fused_adam.py | 2 +- op_builder/fused_conv_bias_relu.py | 2 +- op_builder/fused_lamb.py | 2 +- op_builder/mlp.py | 2 +- op_builder/nccl_allocator.py | 2 +- op_builder/nccl_p2p.py | 2 +- op_builder/peer_memory.py | 2 +- op_builder/transducer_joint.py | 8 +- op_builder/transducer_loss.py | 4 +- scripts/jit_module.py | 8 +- setup.py | 10 +- tests/L0/run_amp/test_basic_casts.py | 2 +- tests/L0/run_amp/test_cache.py | 20 +- tests/L0/run_amp/test_checkpointing.py | 6 +- tests/L0/run_amp/test_rnn.py | 2 +- tests/L0/run_fused_dense/test_gelu.py | 2 +- .../test_fused_layer_norm.py | 26 +- tests/L0/run_optimizers/test_adam.py | 12 +- .../L0/run_optimizers/test_fused_novograd.py | 10 +- tests/L0/run_test.py | 2 +- tests/L0/run_transformer/test_layers.py | 2 +- tests/L1/common/main_amp.py | 10 +- .../DDP/ddp_race_condition_test.py | 2 +- tests/distributed/run_rocm_distributed.sh | 2 +- .../two_gpu_test_different_batch_size.py | 2 +- .../synced_batchnorm/two_gpu_unit_test.py | 2 +- .../distributed/synced_batchnorm/unit_test.sh | 2 +- tests/jit_build/build.sh | 20 +- tests/jit_build/count_failed_unit_tests.py | 4 +- tests/jit_build/run_tests.sh | 2 +- tests/jit_build/scripts/run.sh | 10 +- tests/test_extension_import.py | 22 +- 179 files changed, 2209 insertions(+), 2207 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 63d76737f..a4afce030 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -14,7 +14,7 @@ assignees: '' Please list the *minimal* steps or provide a code snippet for us to be able to reproduce the bug. A helpful guide on on how to craft a minimal bug report http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports. ---> +--> **Expected Behavior** diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index b5aa06faf..874222152 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -3,7 +3,7 @@ name: Apex ROCm CI on: pull_request: types: [opened, synchronize, ready_for_review] - branches: + branches: - master - release/1.8.0 - release/1.9.0 @@ -60,7 +60,7 @@ jobs: # Uses the specified branch on manual runs; defaults to the PR/Push context otherwise ref: ${{ github.event_name == 'workflow_dispatch' && inputs.apex_gitref || '' }} submodules: recursive - + - name: Pull Docker Image run: | docker pull ${{ env.DOCKER_IMAGE }} @@ -123,7 +123,7 @@ jobs: with: name: apex-wheel path: dist/ - + - name: Pull Docker Image run: | docker pull ${{ env.DOCKER_IMAGE }} diff --git a/.gitignore b/.gitignore index da67982aa..99cbed96d 100644 --- a/.gitignore +++ b/.gitignore @@ -147,7 +147,7 @@ dmypy.json cython_debug/ *.hip *_hip.* -*hip* +*hip* #file temporarily created for build process diff --git a/Makefile b/Makefile index 99e44805f..ebd1c4f5e 100644 --- a/Makefile +++ b/Makefile @@ -9,9 +9,9 @@ clean: # This will remove ALL build folders. @test -d apex.egg-info/ && echo "Deleting apex.egg-info folder" || true @test -d apex.egg-info/ && rm -r apex.egg-info/ || true - $(PYTHON) scripts/clean.py # remove the apex extensions installed at torch extensions folder + $(PYTHON) scripts/clean.py # remove the apex extensions installed at torch extensions folder aiter: $(PIP) uninstall -y aiter cd third_party/aiter && $(PIP) install . --no-build-isolation --no-deps - + diff --git a/README.md b/README.md index dfd26c557..de4143cb8 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Introduction -This repository holds ROCm variant of Nvidia's Apex: https://github.com/NVIDIA/apex. +This repository holds ROCm variant of Nvidia's Apex: https://github.com/NVIDIA/apex. The aim of Apex repository is to streamline mixed precision and distributed training in Pytorch. Some of the code here will be included in upstream Pytorch eventually. The intent of Apex is to make up-to-date utilities available to users as quickly as possible. @@ -118,19 +118,19 @@ pip install . --no-build-isolation ### Supported Versions | ``APEX Version`` | ``APEX branch`` | ``Torch Version`` | |------------------|-----------------|-------------------| -| ``1.9.0`` | release/1.9.0 | ``2.9`` | -| ``1.8.0`` | release/1.8.0 | ``2.8`` | -| ``1.7.0`` | release/1.7.0 | ``2.7`` | -| ``1.6.0`` | release/1.6.0 | ``2.6`` | -| ``1.5.0`` | release/1.5.0 | ``2.5`` | -| ``1.4.0`` | release/1.4.0 | ``2.4`` | -| ``1.3.0`` | release/1.3.0 | ``2.3`` | -| ``1.2.0`` | release/1.2.0 | ``2.2`` | +| ``1.9.0`` | release/1.9.0 | ``2.9`` | +| ``1.8.0`` | release/1.8.0 | ``2.8`` | +| ``1.7.0`` | release/1.7.0 | ``2.7`` | +| ``1.6.0`` | release/1.6.0 | ``2.6`` | +| ``1.5.0`` | release/1.5.0 | ``2.5`` | +| ``1.4.0`` | release/1.4.0 | ``2.4`` | +| ``1.3.0`` | release/1.3.0 | ``2.3`` | +| ``1.2.0`` | release/1.2.0 | ``2.2`` | | ``1.1.0`` | release/1.1.0 | ``2.1`` | | ``1.0.0`` | release/1.0.0 | ``2.0`` and older | -The relation between APEX and ROCm PyTorch is maintained in file `related_commits` in [ROCm PyTorch release branches](https://github.com/ROCm/pytorch/branches/all?query=release) in the following format. +The relation between APEX and ROCm PyTorch is maintained in file `related_commits` in [ROCm PyTorch release branches](https://github.com/ROCm/pytorch/branches/all?query=release) in the following format. ``` ubuntu|pytorch|apex|release/1.0.0|06c33eee43f7a22f3ed7d9c3e5be0ddd757dc345|https://github.com/ROCmSoftwarePlatform/apex @@ -178,11 +178,11 @@ The following extensions are supported: | transducer_loss_cuda | APEX_BUILD_TRANSDUCER_LOSS=1 | APEX_BUILD_CUDA_OPS=1 | | xentropy_cuda | APEX_BUILD_XENTROPY=1 | APEX_BUILD_CUDA_OPS=1 | -For example, to build FUSED_DENSE​ you can use the following command: +For example, to build FUSED_DENSE​ you can use the following command: ``` APEX_BUILD_FUSED_DENSE​=1 pip install . --no-build-isolation ``` -This will pre-build and install FUSED_DENSE​ module and rest of the modules are installed to be JIT built and loaded at runtime. +This will pre-build and install FUSED_DENSE​ module and rest of the modules are installed to be JIT built and loaded at runtime. Aiter backend can be built and used for fused rope. To install aiter: ``` @@ -193,10 +193,10 @@ To use aiter in fused rope, you can use the flag ```USE_ROCM_AITER_ROPE_BACKEND= ### To add a new module into jit loader -What is JIT (just-in-time) load? Just-in-time load helps to build the specific modules that are used without needing to build all modules during installation time. This helps to significantly reduce installation time. Without JIT load, it would take roughtly 30 minutes to install apex. With JIT load, it takes less than 1 minute to install apex. +What is JIT (just-in-time) load? Just-in-time load helps to build the specific modules that are used without needing to build all modules during installation time. This helps to significantly reduce installation time. Without JIT load, it would take roughtly 30 minutes to install apex. With JIT load, it takes less than 1 minute to install apex. -A python script is provided to ease the process of adding a new module to JIT load. -For this, the user must create C++/CUDA source code for a new apex module in either csrc or apex/contrib/csrc folder. +A python script is provided to ease the process of adding a new module to JIT load. +For this, the user must create C++/CUDA source code for a new apex module in either csrc or apex/contrib/csrc folder. This script helps to create a builder and a loader for the apex module. The builder creates the .so file for the apex module (during installation or jit load time) and the loader loads the .so file when the module is imported. @@ -204,25 +204,25 @@ To run the script: ``` python scripts/jit_module.py -``` +``` The user should provide the name used to import the module i.e. import fused_bias_swiglu. If the user does not provide the module name, the script will ask for the module name ``` What is the name of the module? -``` +``` -The script is interactive and asks two questions +The script is interactive and asks two questions 1. Is this a CUDA module? (Y/n) -2. Enter the sources (comma separated) Press Enter to skip +2. Enter the sources (comma separated) Press Enter to skip If the user answers yes to cuda module, it builds with CUDAOpBuilder otherwise it builds as a cpu operation with CPUOpBuilder. The default is cuda operation. The user must mention the list of .cpp, .h, .cu files used to compile the module as a comma separated list. This argument is used to define the return value of sources() method in the builder module. -This will be used to also find the list of directories (include_paths() method) i.e. -I flag in g++ compiler. -The user can decide to skip the list of sources and add it manually to the builder file created by the script. +This will be used to also find the list of directories (include_paths() method) i.e. -I flag in g++ compiler. +The user can decide to skip the list of sources and add it manually to the builder file created by the script. -e.g. +e.g. ``` python scripts/jit_module.py fused_bias_swiglu 1. Is this a CUDA module? (Y/n) y @@ -250,7 +250,7 @@ apex/ # repo root ``` -The user must not edit the loader code. +The user must not edit the loader code. The script creates an initial builder code and the users can edit the methods in the module. @@ -262,7 +262,7 @@ The builder module is created in op_builder folder and must override either CPUO | INCLUDE_FLAG | Either APEX_BUILD_CUDA_OPS or APEX_BUILD_CPU_OPS to indicate whether the module will be built for gpu or cpu | | NAME | name of module e.g. fused_bias_swiglu | -| Method | Purpose | Necessary to override | +| Method | Purpose | Necessary to override | |-----------|-----------|-----------| | absolute_name | return the namespace where the module will be installed | Yes | | sources | list of C++/CUDA source files for the module | Yes | @@ -289,11 +289,11 @@ make clean ``` ### Enable hipblasLT on ROCm -hipblasLT is supported only on mi300 (gfx942) only. -python setup.py automatically builds apex with hipblasLT support only if GPU device id is gfx942 -To verify if hipblasLT support is enabled, check the build logs -INFO: IS_HIPBLASLT_SUPPORTED value is True ==> indicates apex is built with hipblasLT support -INFO: IS_HIPBLASLT_SUPPORTED value is False +hipblasLT is supported only on mi300 (gfx942) only. +python setup.py automatically builds apex with hipblasLT support only if GPU device id is gfx942 +To verify if hipblasLT support is enabled, check the build logs +INFO: IS_HIPBLASLT_SUPPORTED value is True ==> indicates apex is built with hipblasLT support +INFO: IS_HIPBLASLT_SUPPORTED value is False ### Linux For performance and full functionality, we recommend installing Apex with @@ -307,7 +307,7 @@ pip install . --no-build-isolation ### [Experimental] Windows `pip install . --no-build-isolation` may work if you were able to build Pytorch from source -on your system. A Python-only build via `pip install --no-build-isolation -v --no-cache-dir .` is more likely to work. +on your system. A Python-only build via `pip install --no-build-isolation -v --no-cache-dir .` is more likely to work. If you installed Pytorch in a Conda environment, make sure to install Apex in that same environment. # Release notes @@ -356,7 +356,7 @@ Unit test related Upgraded extensions - Support unscale_grads in transformer Grad scaler - Support amp function in fused dense, mlp -- Support blas backend flag in fused dense +- Support blas backend flag in fused dense - Support not destroying process group for distributed tests - Upgrade fused adam to support parameters - capturable, master weights, grad scaler - Upgrade distributed fused adam to support bias_correction, adam_w_mode, overlap_param_sync, store_params, store_param_remainders, with_scaled_states, nccl_ub @@ -374,7 +374,7 @@ Added extensions - fused bias swiglu - fused gradient accumulator - fused rope - + Upgraded extensions - Support blaslt backend in fused weight gradient dense module diff --git a/apex/RNN/RNNBackend.py b/apex/RNN/RNNBackend.py index a9382e601..88100456e 100644 --- a/apex/RNN/RNNBackend.py +++ b/apex/RNN/RNNBackend.py @@ -17,10 +17,10 @@ def flatten_list(tens_list): """ if not is_iterable(tens_list): return tens_list - + return torch.cat(tens_list, dim=0).view(len(tens_list), *tens_list[0].size() ) - + #These modules always assumes batch_first class bidirectionalRNN(nn.Module): """ @@ -32,7 +32,7 @@ def __init__(self, inputRNN, num_layers=1, dropout = 0): self.fwd = stackedRNN(inputRNN, num_layers=num_layers, dropout = dropout) self.bckwrd = stackedRNN(inputRNN.new_like(), num_layers=num_layers, dropout = dropout) self.rnns = nn.ModuleList([self.fwd, self.bckwrd]) - + #collect hidden option will return all hidden/cell states from entire RNN def forward(self, input, collect_hidden=False): """ @@ -43,7 +43,7 @@ def forward(self, input, collect_hidden=False): fwd_out, fwd_hiddens = list(self.fwd(input, collect_hidden = collect_hidden)) bckwrd_out, bckwrd_hiddens = list(self.bckwrd(input, reverse=True, collect_hidden = collect_hidden)) - + output = torch.cat( [fwd_out, bckwrd_out], -1 ) hiddens = tuple( torch.cat(hidden, -1) for hidden in zip( fwd_hiddens, bckwrd_hiddens) ) @@ -55,7 +55,7 @@ def reset_parameters(self): """ for rnn in self.rnns: rnn.reset_parameters() - + def init_hidden(self, bsz): """ init_hidden() @@ -69,7 +69,7 @@ def detach_hidden(self): """ for rnn in self.rnns: rnn.detachHidden() - + def reset_hidden(self, bsz): """ reset_hidden() @@ -77,25 +77,25 @@ def reset_hidden(self, bsz): for rnn in self.rnns: rnn.reset_hidden(bsz) - def init_inference(self, bsz): + def init_inference(self, bsz): """ init_inference() """ for rnn in self.rnns: rnn.init_inference(bsz) - + #assumes hidden_state[0] of inputRNN is output hidden state #constructor either takes an RNNCell or list of RNN layers -class stackedRNN(nn.Module): +class stackedRNN(nn.Module): """ stackedRNN """ def __init__(self, inputRNN, num_layers=1, dropout=0): super(stackedRNN, self).__init__() - + self.dropout = dropout - + if isinstance(inputRNN, RNNCell): self.rnns = [inputRNN] for i in range(num_layers-1): @@ -105,9 +105,9 @@ def __init__(self, inputRNN, num_layers=1, dropout=0): self.rnns=inputRNN else: raise RuntimeError() - + self.nLayers = len(self.rnns) - + self.rnns = nn.ModuleList(self.rnns) @@ -135,14 +135,14 @@ def forward(self, input, collect_hidden=False, reverse=False): if layer == 0: prev_out = input[seq] - + outs = self.rnns[layer](prev_out) if collect_hidden: hidden_states[layer].append(outs) elif seq == seq_len-1: hidden_states[layer].append(outs) - + prev_out = outs[0] outputs.append(prev_out) @@ -187,20 +187,20 @@ def forward(self, input, collect_hidden=False, reverse=False): hiddens = list( list( flatten_list(seq) for seq in hidden ) for hidden in hidden_states ) - + #Now in format list( [hidden_states][seq_length] x Tensor([layer][bsz][features]) ) #Remove seq_length dimension if not collect_hidden if not collect_hidden: hidden_states = list( entry[0] for entry in hidden_states) return output, hidden_states - + def reset_parameters(self): """ reset_parameters() """ for rnn in self.rnns: rnn.reset_parameters() - + def init_hidden(self, bsz): """ init_hidden() @@ -214,7 +214,7 @@ def detach_hidden(self): """ for rnn in self.rnns: rnn.detach_hidden() - + def reset_hidden(self, bsz): """ reset_hidden() @@ -222,16 +222,16 @@ def reset_hidden(self, bsz): for rnn in self.rnns: rnn.reset_hidden(bsz) - def init_inference(self, bsz): - """ + def init_inference(self, bsz): + """ init_inference() """ for rnn in self.rnns: rnn.init_inference(bsz) class RNNCell(nn.Module): - """ - RNNCell + """ + RNNCell gate_multiplier is related to the architecture you're working with For LSTM-like it will be 4 and GRU-like will be 3. Always assumes input is NOT batch_first. @@ -265,7 +265,7 @@ def __init__(self, gate_multiplier, input_size, hidden_size, cell, n_hidden_stat if self.bias: self.b_ih = nn.Parameter(torch.empty(self.gate_size)) self.b_hh = nn.Parameter(torch.empty(self.gate_size)) - + #hidden states for forward self.hidden = [ None for states in range(self.n_hidden_states)] @@ -277,7 +277,7 @@ def new_like(self, new_input_size=None): """ if new_input_size is None: new_input_size = self.input_size - + return type(self)(self.gate_multiplier, new_input_size, self.hidden_size, @@ -286,7 +286,7 @@ def new_like(self, new_input_size=None): self.bias, self.output_size) - + #Use xavier where we can (weights), otherwise use uniform (bias) def reset_parameters(self, gain=1): """ @@ -325,8 +325,8 @@ def init_hidden(self, bsz): tens = a_param.data.new(bsz, hidden_size).zero_() self.hidden[i] = Variable(tens, requires_grad=False) - - + + def reset_hidden(self, bsz): """ reset_hidden() @@ -344,7 +344,7 @@ def detach_hidden(self): raise RuntimeError("Must initialize hidden state before you can detach it") for i, _ in enumerate(self.hidden): self.hidden[i] = self.hidden[i].detach() - + def forward(self, input): """ forward() diff --git a/apex/RNN/cells.py b/apex/RNN/cells.py index 09b08581d..2a267b908 100644 --- a/apex/RNN/cells.py +++ b/apex/RNN/cells.py @@ -6,7 +6,7 @@ from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend -import math +import math class mLSTMRNNCell(RNNCell): @@ -36,7 +36,7 @@ def forward(self, input): self.cell(input, hidden_state, self.w_ih, self.w_hh, self.w_mih, self.w_mhh, b_ih=self.b_ih, b_hh=self.b_hh) ) - + if self.output_size != self.hidden_size: self.hidden[0] = F.linear(self.hidden[0], self.w_ho) return tuple(self.hidden) @@ -45,7 +45,7 @@ def forward(self, input): def new_like(self, new_input_size=None): if new_input_size is None: new_input_size = self.input_size - + return type(self)( new_input_size, self.hidden_size, @@ -66,7 +66,7 @@ def mLSTMCell(input, hidden, w_ih, w_hh, w_mih, w_mhh, b_ih=None, b_hh=None): return state(igates, hgates, hidden[1], b_ih, b_hh) hx, cx = hidden - + m = F.linear(input, w_mih) * F.linear(hidden[0], w_mhh) gates = F.linear(input, w_ih, b_ih) + F.linear(m, w_hh, b_hh) @@ -76,9 +76,9 @@ def mLSTMCell(input, hidden, w_ih, w_hh, w_mih, w_mhh, b_ih=None, b_hh=None): forgetgate = F.sigmoid(forgetgate) cellgate = F.tanh(cellgate) outgate = F.sigmoid(outgate) - + cy = (forgetgate * cx) + (ingate * cellgate) hy = outgate * F.tanh(cy) - + return hy, cy - + diff --git a/apex/RNN/models.py b/apex/RNN/models.py index dd7adce04..bd81add9a 100644 --- a/apex/RNN/models.py +++ b/apex/RNN/models.py @@ -43,7 +43,7 @@ def Tanh(input_size, hidden_size, num_layers, bias=True, batch_first=False, drop """ inputRNN = RNNCell(1, input_size, hidden_size, RNNTanhCell, 1, bias, output_size) return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout) - + def mLSTM(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None): """ :class:`mLSTM` diff --git a/apex/amp/_process_optimizer.py b/apex/amp/_process_optimizer.py index 66c4c3fdf..ad71d3f9c 100644 --- a/apex/amp/_process_optimizer.py +++ b/apex/amp/_process_optimizer.py @@ -99,7 +99,7 @@ def post_backward_models_are_masters(scaler, params, stashed_grads, scale_overri for i in range(len(stashed_grads)): stashed_grads[i] = None return - + if scale_override is not None: grads_have_scale, stashed_have_scale, out_scale = scale_override diff --git a/apex/amp/frontend.py b/apex/amp/frontend.py index 5ee96b778..3083d9ca1 100644 --- a/apex/amp/frontend.py +++ b/apex/amp/frontend.py @@ -450,7 +450,7 @@ def load_state_dict(state_dict): len(state_dict), len(_amp_state.loss_scalers))) state_dict = state_dict.copy() - + nb_loss_scalers = len(_amp_state.loss_scalers) unexpected_keys = [] # Initialize idx outside, since unexpected_keys will increase it if enumerate is used diff --git a/apex/amp/lists/functional_overrides.py b/apex/amp/lists/functional_overrides.py index 9ecdf0972..83a3e845c 100644 --- a/apex/amp/lists/functional_overrides.py +++ b/apex/amp/lists/functional_overrides.py @@ -49,7 +49,7 @@ 'log_softmax', 'softmax', 'gelu', - + # Normalization 'layer_norm', 'group_norm', diff --git a/apex/contrib/bottleneck/bottleneck.py b/apex/contrib/bottleneck/bottleneck.py index 8e98fc3c6..edf7bb7fd 100644 --- a/apex/contrib/bottleneck/bottleneck.py +++ b/apex/contrib/bottleneck/bottleneck.py @@ -400,7 +400,7 @@ def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, s torch.cuda.current_stream().wait_stream(stream2) if spatial_group_rank > 0: torch.cuda.current_stream().wait_stream(stream1) - + fast_bottleneck.forward_rest(explicit_nhwc, stride_1x1, args, outputs) # save halos for backward pass if spatial_group_size > 1: @@ -609,7 +609,7 @@ class SpatialBottleneck(torch.nn.Module): # here we put it at 1x1 def __init__(self, in_channels, bottleneck_channels, out_channels, stride=1, groups=1, - dilation=1, norm_func=None, use_cudnn=False, explicit_nhwc=False, + dilation=1, norm_func=None, use_cudnn=False, explicit_nhwc=False, spatial_parallel_args=None): super(SpatialBottleneck, self).__init__() if groups != 1: @@ -704,7 +704,7 @@ def forward(self, x): N,C,H,W = list(x.shape) self.thresholdTop = torch.tensor([1 if spatial_group_rank > 0 else 0], dtype=torch.int32, device='cuda') self.thresholdBottom = torch.tensor([H-2 if spatial_group_rank < spatial_group_size - 1 else H-1], dtype=torch.int32, device='cuda') - + if self.w_scale is None: # calculate scale/bias from registered buffers # TODO: make this better diff --git a/apex/contrib/conv_bias_relu/__init__.py b/apex/contrib/conv_bias_relu/__init__.py index a257106e5..cee6ae79a 100644 --- a/apex/contrib/conv_bias_relu/__init__.py +++ b/apex/contrib/conv_bias_relu/__init__.py @@ -1,2 +1,2 @@ -from .conv_bias_relu import ConvBiasReLU, ConvBias, ConvBiasMaskReLU +from .conv_bias_relu import ConvBiasReLU, ConvBias, ConvBiasMaskReLU diff --git a/apex/contrib/csrc/bottleneck/bottleneck.cpp b/apex/contrib/csrc/bottleneck/bottleneck.cpp index 9a0c3403d..bdd5b31d7 100644 --- a/apex/contrib/csrc/bottleneck/bottleneck.cpp +++ b/apex/contrib/csrc/bottleneck/bottleneck.cpp @@ -3118,7 +3118,7 @@ at::Tensor bottleneck_forward_out2_halo(bool explicit_nhwc, at::Tensor fat_halo_ at::Half* w = inputs[2].data_ptr(); at::Half* z = inputs[5].data_ptr(); at::Half* b = inputs[8].data_ptr(); - + at::Half* y1 = fat_halo_y1.data_ptr(); auto halo_y2 = at::empty(forward_state.outdim4, inputs[0].type(), output_format); @@ -3824,7 +3824,7 @@ void bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, std::vector< // wgrad auto wgrad2 = outputs[2]; at::Half* dw2 = wgrad2.data_ptr(); - + //printf("outdimA1 = (%d,%d,%d,%d)\n",backward_state.outdimA1[0],backward_state.outdimA1[1],backward_state.outdimA1[2],backward_state.outdimA1[3]); run_dconv(backward_state.outdimA1, backward_state.padA1, diff --git a/apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp b/apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp index 66f89ef00..b46812c0d 100644 --- a/apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp +++ b/apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp @@ -773,7 +773,7 @@ run_drelu_dbias(int64_t* dy_dim, .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, dyTensor.describe()); - + generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC); auto rTensor = cudnn_frontend::TensorBuilder() .setDim(4, dy_dim) @@ -783,7 +783,7 @@ run_drelu_dbias(int64_t* dy_dim, .setDataType(dataType) .build(); DEBUG_CUDNN_MSG(log_buf, rTensor.describe()); - + generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC); auto inActGradTensor = cudnn_frontend::TensorBuilder() .setDim(4, dy_dim) @@ -1424,7 +1424,7 @@ std::vector conv_bias_relu_backward(std::vector inputs, w_dim[dim] = inputs[1].size(axis[dim]); y_dim[dim] = inputs[3].size(axis[dim]); } - + int64_t b_dim[] = {1, y_dim[1], 1, 1}; int64_t conv_pad[] = {padding, padding}; @@ -1450,7 +1450,7 @@ std::vector conv_bias_relu_backward(std::vector inputs, // conv wgrad at::Half* x = inputs[0].data_ptr(); auto wgrad = at::empty_like(inputs[1]); - at::Half* dw = wgrad.data_ptr(); + at::Half* dw = wgrad.data_ptr(); run_dconv(x_dim, w_dim, y_dim, @@ -1572,7 +1572,7 @@ std::vector conv_bias_backward(std::vector inputs, int64 w_dim[dim] = inputs[1].size(axis[dim]); y_dim[dim] = inputs[2].size(axis[dim]); } - + int64_t b_dim[] = {1, y_dim[1], 1, 1}; int64_t conv_pad[] = {padding, padding}; @@ -1588,12 +1588,12 @@ std::vector conv_bias_backward(std::vector inputs, int64 run_dbias(y_dim, CUDNN_DATA_HALF, dy, - db); - + db); + // conv wgrad at::Half* x = inputs[0].data_ptr(); auto wgrad = at::empty_like(inputs[1]); - at::Half* dw = wgrad.data_ptr(); + at::Half* dw = wgrad.data_ptr(); run_dconv(x_dim, w_dim, y_dim, diff --git a/apex/contrib/csrc/conv_bias_relu/conv_bias_relu_rocm.cpp b/apex/contrib/csrc/conv_bias_relu/conv_bias_relu_rocm.cpp index 7668053e2..77b732214 100644 --- a/apex/contrib/csrc/conv_bias_relu/conv_bias_relu_rocm.cpp +++ b/apex/contrib/csrc/conv_bias_relu/conv_bias_relu_rocm.cpp @@ -178,7 +178,7 @@ static std::vector conv_bias_forward_dispatch(const at::Tensor& x, } std::string get_cache_key(const at::Tensor& x, const at::Tensor& w, int64_t padding, int64_t stride, bool relu) { - return std::to_string(x.size(0)) + "_" + std::to_string(x.size(1)) + "_" + + return std::to_string(x.size(0)) + "_" + std::to_string(x.size(1)) + "_" + std::to_string(x.size(2)) + "_" + std::to_string(x.size(3)) + "_" + std::to_string(w.size(0)) + "_" + std::to_string(w.size(1)) + "_" + std::to_string(w.size(2)) + "_" + std::to_string(w.size(3)) + "_" + @@ -255,7 +255,7 @@ static std::vector conv_bias_relu_forward_fused(const at::Tensor& x, MIOPEN_CHECK(miopenCreateOpActivationForward(plan, &activ_op, miopenActivationRELU)); }else { - MIOPEN_CHECK(miopenCreateOpActivationForward(plan, &activ_op, miopenActivationCLAMP)); + MIOPEN_CHECK(miopenCreateOpActivationForward(plan, &activ_op, miopenActivationCLAMP)); } // Compile @@ -345,9 +345,9 @@ std::vector conv_bias_relu_backward(std::vector inputs, auto grad_relu = grad_output * (out > 0).to(grad_output.dtype()); int64_t bias_size = weight.size(0); std::vector bias_sizes = {bias_size}; - auto grads = at::convolution_backward(grad_relu, x, weight, + auto grads = at::convolution_backward(grad_relu, x, weight, bias_sizes, - {stride, stride}, {padding, padding}, {1, 1}, + {stride, stride}, {padding, padding}, {1, 1}, false, {0, 0}, 1, {true, true, true}); return {std::get<0>(grads), std::get<1>(grads), std::get<2>(grads)}; @@ -366,10 +366,10 @@ std::vector conv_bias_backward(std::vector inputs, int64 auto grad_output = inputs[2]; int64_t bias_size = weight.size(0); std::vector bias_sizes = {bias_size}; - + auto grads = at::convolution_backward(grad_output, x, weight, bias_sizes, - {stride, stride}, {padding, padding}, {1, 1}, + {stride, stride}, {padding, padding}, {1, 1}, false, {0, 0}, 1, {true, true, true}); return {std::get<0>(grads), std::get<1>(grads), std::get<2>(grads)}; diff --git a/apex/contrib/csrc/fmha/fmha_api.cpp b/apex/contrib/csrc/fmha/fmha_api.cpp index 07865b6b0..dc5d0e8a9 100644 --- a/apex/contrib/csrc/fmha/fmha_api.cpp +++ b/apex/contrib/csrc/fmha/fmha_api.cpp @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -83,7 +83,7 @@ void set_params(Fused_multihead_attention_fprop_params ¶ms, set_alpha(params.scale_dropout, params.rp_dropout, data_type); } -std::vector +std::vector mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i const at::Tensor &cu_seqlens, // b+1 const float p_dropout, @@ -290,7 +290,7 @@ std::vector mha_bwd_nl(const at::Tensor &dout, // total x num TORCH_CHECK(sizes[THREE_DIM] == 3); const int batch_size = cu_seqlens.numel() - 1; - + const int total = sizes[TOTAL_DIM]; const int num_heads = sizes[H_DIM]; const int head_size = sizes[D_DIM]; @@ -307,7 +307,7 @@ std::vector mha_bwd_nl(const at::Tensor &dout, // total x num if( zero_tensors ) { dqkv.zero_(); } - + int num_chunks = 2; if( batch_size == 1 ) { num_chunks = 4; @@ -354,7 +354,7 @@ std::vector mha_bwd_nl(const at::Tensor &dout, // total x num } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.doc() = "Fused Multi-head Self-attention for BERT"; + m.doc() = "Fused Multi-head Self-attention for BERT"; m.def("fwd", &mha_fwd, "Forward pass"); m.def("bwd", &mha_bwd, "Backward pass"); m.def("bwd_nl", &mha_bwd_nl, "Backward pass (small-batch)"); diff --git a/apex/contrib/csrc/fmha/src/fmha.h b/apex/contrib/csrc/fmha/src/fmha.h index d01a91505..799a93d7d 100644 --- a/apex/contrib/csrc/fmha/src/fmha.h +++ b/apex/contrib/csrc/fmha/src/fmha.h @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -104,12 +104,12 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct Launch_params{ Launch_params(cudaDeviceProp * props_, cudaStream_t stream_, bool is_training_, - bool is_nl_) + bool is_nl_) : elts_per_thread(0) , props(props_) , stream(stream_) @@ -147,7 +147,7 @@ void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_param void run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params ¶ms, cudaStream_t stream); void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params ¶ms, cudaStream_t stream); -void run_fmha_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params ¶ms, const bool is_training, const int num_chunks, cudaStream_t stream); +void run_fmha_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params ¶ms, const bool is_training, const int num_chunks, cudaStream_t stream); void run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params ¶ms, const int num_chunks, cudaStream_t stream); diff --git a/apex/contrib/csrc/fmha/src/fmha/gemm.h b/apex/contrib/csrc/fmha/src/fmha/gemm.h index 62529a2c5..722e78444 100644 --- a/apex/contrib/csrc/fmha/src/fmha/gemm.h +++ b/apex/contrib/csrc/fmha/src/fmha/gemm.h @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE diff --git a/apex/contrib/csrc/fmha/src/fmha/gmem_tile.h b/apex/contrib/csrc/fmha/src/fmha/gmem_tile.h index 5c86dd84e..020dde0c6 100644 --- a/apex/contrib/csrc/fmha/src/fmha/gmem_tile.h +++ b/apex/contrib/csrc/fmha/src/fmha/gmem_tile.h @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -280,7 +280,7 @@ struct Gmem_tile_mma_sd { // Ctor. template - inline __device__ Gmem_tile_mma_sd(void *ptr, const Params ¶ms, const int bidb, const int bidh, const int tidx) + inline __device__ Gmem_tile_mma_sd(void *ptr, const Params ¶ms, const int bidb, const int bidh, const int tidx) : ptr_(static_cast(ptr)) { // The block index. @@ -328,7 +328,7 @@ struct Gmem_tile_mma_s : public Base { // Ctor. template< typename Params, typename Block_info > - inline __device__ Gmem_tile_mma_s(const Params ¶ms, const Block_info& binfo, const int tidx) + inline __device__ Gmem_tile_mma_s(const Params ¶ms, const Block_info& binfo, const int tidx) : Base(params.s_ptr, params, binfo.bidb, binfo.bidh, tidx) { } @@ -433,7 +433,7 @@ struct Gmem_tile_dq : public Base { // Ctor. template - inline __device__ Gmem_tile_dq(const Params ¶ms, const BInfo &binfo, int tidx) + inline __device__ Gmem_tile_dq(const Params ¶ms, const BInfo &binfo, int tidx) : Base(params, binfo, tidx) { this->o_ptr_ = reinterpret_cast(params.dqkv_ptr); this->params_o_stride_in_bytes_ = params.qkv_stride_in_bytes; // needed for move diff --git a/apex/contrib/csrc/fmha/src/fmha/kernel_traits.h b/apex/contrib/csrc/fmha/src/fmha/kernel_traits.h index d51b47c53..d59d8c5fb 100644 --- a/apex/contrib/csrc/fmha/src/fmha/kernel_traits.h +++ b/apex/contrib/csrc/fmha/src/fmha/kernel_traits.h @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE diff --git a/apex/contrib/csrc/fmha/src/fmha/mask.h b/apex/contrib/csrc/fmha/src/fmha/mask.h index 020258a02..58ff0223b 100644 --- a/apex/contrib/csrc/fmha/src/fmha/mask.h +++ b/apex/contrib/csrc/fmha/src/fmha/mask.h @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE diff --git a/apex/contrib/csrc/fmha/src/fmha/smem_tile.h b/apex/contrib/csrc/fmha/src/fmha/smem_tile.h index 80879140a..c3188d3e9 100644 --- a/apex/contrib/csrc/fmha/src/fmha/smem_tile.h +++ b/apex/contrib/csrc/fmha/src/fmha/smem_tile.h @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -34,24 +34,24 @@ namespace fmha { //////////////////////////////////////////////////////////////////////////////////////////////////// -template< +template< // The description of the tile computed by this CTA. - typename Cta_tile, + typename Cta_tile, // The number of rows in the 2D shared memory buffer. - int M_, + int M_, // The number of cols. - int N_, + int N_, // The size in bits of each element. - int BITS_PER_ELEMENT_, + int BITS_PER_ELEMENT_, // The number of bytes per STS. int BYTES_PER_STS_ = 16, // The number of buffers. (Used in multistage and double buffer cases.) int BUFFERS_PER_TILE_ = 1, // Do we enable the fast path for LDS.128 and friends. - int ENABLE_LDS_FAST_PATH_ = 0, - // The number of rows that are used for the XOR swizzling to allow fast STS/LDS. + int ENABLE_LDS_FAST_PATH_ = 0, + // The number of rows that are used for the XOR swizzling to allow fast STS/LDS. int ROWS_PER_XOR_PATTERN_ = 8, - // The number of cols that are used for the XOR swizzling to allow fast STS/LDS. + // The number of cols that are used for the XOR swizzling to allow fast STS/LDS. int COLS_PER_XOR_PATTERN_ = 1, // Use or not predicates bool USE_PREDICATES_ = true @@ -65,7 +65,7 @@ struct Smem_tile_without_skews { // The number of elements per STS. enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT }; // To support arbitrary N, we pad some values to a power-of-2. - enum { N_WITH_PADDING = Next_power_of_two::VALUE }; + enum { N_WITH_PADDING = Next_power_of_two::VALUE }; // The number of bytes per row without packing of rows. enum { BYTES_PER_ROW_BEFORE_PACKING = N_WITH_PADDING * BITS_PER_ELEMENT / 8 }; // The number of bytes per row -- we want at least 128B per row. @@ -93,7 +93,7 @@ struct Smem_tile_without_skews { // The size of one buffer in bytes in shared memory. enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * Cta_tile::THREADS_PER_CTA }; - // The number of buffers. + // The number of buffers. enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ }; // The size in bytes of total buffers. enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE }; @@ -103,9 +103,9 @@ struct Smem_tile_without_skews { // Do we enable the LDS.128 fast path? enum { ENABLE_LDS_FAST_PATH = ENABLE_LDS_FAST_PATH_ }; static_assert(ENABLE_LDS_FAST_PATH == 0); - // The number of rows that are used for the XOR swizzling to allow fast STS/LDS. + // The number of rows that are used for the XOR swizzling to allow fast STS/LDS. enum { ROWS_PER_XOR_PATTERN = ROWS_PER_XOR_PATTERN_ }; - // The number of cols that are used for the XOR swizzling to allow fast STS/LDS. + // The number of cols that are used for the XOR swizzling to allow fast STS/LDS. enum { COLS_PER_XOR_PATTERN = COLS_PER_XOR_PATTERN_ * 16 / BYTES_PER_STS }; // Use or not predicates enum { USE_PREDICATES = USE_PREDICATES_ }; @@ -114,7 +114,7 @@ struct Smem_tile_without_skews { using Store_type = typename Uint_from_size_in_bytes::Type; // Ctor. - inline __device__ Smem_tile_without_skews(void *smem, int tidx) + inline __device__ Smem_tile_without_skews(void *smem, int tidx) : smem_(__nvvm_get_smem_pointer(smem)) { // The row written by a thread. See doc/mma_smem_layout.xlsx. @@ -147,7 +147,7 @@ struct Smem_tile_without_skews { // Take the column into account. if( STS_PER_ROW > 1 ) { - offset += col*THREADS_PER_ROW*BYTES_PER_STS; + offset += col*THREADS_PER_ROW*BYTES_PER_STS; } // Apply the XOR pattern if needed. @@ -266,7 +266,7 @@ struct Smem_tile_without_skews { // Store to the tile in shared memory. template< int N > - inline __device__ void store(const Store_type (&data)[N], uint32_t preds, uint64_t = 0) { + inline __device__ void store(const Store_type (&data)[N], uint32_t preds, uint64_t = 0) { this->store(data, preds); } @@ -291,11 +291,11 @@ struct Smem_tile_without_skews { //////////////////////////////////////////////////////////////////////////////////////////////////// -template< +template< // The dimensions of the tile computed by the CTA. - typename Cta_tile, + typename Cta_tile, // The layout of the tile. - typename Layout, + typename Layout, // The size of the STS. int BYTES_PER_STS = 16, // The number of buffers per tile. @@ -497,11 +497,11 @@ struct Smem_tile_a //////////////////////////////////////////////////////////////////////////////////////////////////// -template< +template< // The dimensions of the tile computed by the CTA. - typename Cta_tile, + typename Cta_tile, // The layout of the tile. - typename Layout, + typename Layout, // The size of the STS. int BYTES_PER_STS = 16, // The number of buffers per tile. @@ -1217,7 +1217,7 @@ struct Smem_tile_mma_epilogue : public Base { enum { WARPS_M = Base::WARPS_M }; enum { WARPS_N = Base::WARPS_N }; static_assert((WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1); - + using Acc = fmha::Fragment_accumulator; inline __device__ Smem_tile_mma_epilogue(char *smem, int tidx) : Base(smem, tidx) { @@ -1255,7 +1255,7 @@ struct Smem_tile_mma_epilogue : public Base { uint32_t y = fmha::float2_to_half2(tmp02, tmp03); uint32_t z = fmha::float2_to_half2(tmp10, tmp11); uint32_t w = fmha::float2_to_half2(tmp12, tmp13); - + size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, x); fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, z); diff --git a/apex/contrib/csrc/fmha/src/fmha/softmax.h b/apex/contrib/csrc/fmha/src/fmha/softmax.h index 153f42d57..92fb427cb 100644 --- a/apex/contrib/csrc/fmha/src/fmha/softmax.h +++ b/apex/contrib/csrc/fmha/src/fmha/softmax.h @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -278,7 +278,7 @@ struct Softmax : public Softmax_base { template inline __device__ Softmax(const Params ¶ms, void *smem, int bidb, int tidx) : Base(params, smem, bidb, tidx) - , params_scale_bmm1_(params.scale_bmm1) + , params_scale_bmm1_(params.scale_bmm1) , smem_sum_(static_cast(smem), tidx) , smem_max_(static_cast(smem) + Smem_tile_red::ELTS_PER_TILE, tidx) { } @@ -374,12 +374,12 @@ struct Softmax : public Softmax_base { quad_allreduce(frag, tmp, op); } - __device__ inline void reduce_max(float (&frag)[2 * MMAS_M]){ + __device__ inline void reduce_max(float (&frag)[2 * MMAS_M]){ MaxOp max; reduce_(frag, max, smem_max_); } - __device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]){ + __device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]){ SumOp sum; reduce_(frag, sum, smem_sum_); } diff --git a/apex/contrib/csrc/fmha/src/fmha/utils.h b/apex/contrib/csrc/fmha/src/fmha/utils.h index bedba0eff..d060aa2d9 100644 --- a/apex/contrib/csrc/fmha/src/fmha/utils.h +++ b/apex/contrib/csrc/fmha/src/fmha/utils.h @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -39,7 +39,7 @@ namespace fmha { //////////////////////////////////////////////////////////////////////////////////////////////////// -struct Row {}; +struct Row {}; struct Col {}; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -309,7 +309,7 @@ static inline __device__ uint32_t hrelu2(uint32_t x, uint32_t lb = 0) { "{\n" \ "\t .reg .f16x2 sela;\n" \ "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \ - "\t and.b32 %0, sela, %1;\n" + "\t and.b32 %0, sela, %1;\n" "}\n" : "=r"(res) : "r"(x), "r"(zero)); #endif return res; @@ -402,8 +402,8 @@ static inline __device__ uint32_t hfma2_relu(uint32_t a, uint32_t b, uint32_t c) static inline __device__ uint32_t h0_h0(uint32_t x) { uint32_t y; - asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {lo, lo};}\n" - : "=r"(y) : "r"(x)); + asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {lo, lo};}\n" + : "=r"(y) : "r"(x)); return y; } @@ -423,8 +423,8 @@ static inline __device__ float h0_to_float(uint32_t h2) { static inline __device__ uint32_t h1_h1(uint32_t x) { uint32_t y; - asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {hi, hi};}\n" - : "=r"(y) : "r"(x)); + asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {hi, hi};}\n" + : "=r"(y) : "r"(x)); return y; } @@ -979,9 +979,9 @@ struct Allreduce { template<> struct Allreduce<2> { -template +template static __device__ inline T run(T x, Operator &op) { - x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); return x; } }; diff --git a/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu b/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu index 517a5b758..04d94c81f 100644 --- a/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu +++ b/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE diff --git a/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu b/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu index ac22a1629..aa2fc010a 100644 --- a/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu +++ b/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE diff --git a/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu b/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu index 7081438e9..1837cedab 100644 --- a/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu +++ b/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE diff --git a/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu b/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu index 735006cc2..8e89f60f7 100644 --- a/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu +++ b/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -84,7 +84,7 @@ void run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_pa constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk); auto kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<2>; - + if( num_chunks == 2 ) { kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<2>; }else if( num_chunks == 3 ) { diff --git a/apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload.h b/apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload.h index 3c4b81742..51a0a3031 100644 --- a/apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload.h +++ b/apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload.h @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE diff --git a/apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload_nl.h b/apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload_nl.h index 26776d484..1bb120440 100644 --- a/apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload_nl.h +++ b/apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload_nl.h @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -70,10 +70,10 @@ inline __device__ void compute_dv_1xN_nl(const Params ¶ms) { using Smem_tile_v = typename Kernel_traits::Smem_tile_v; // The global memory tile to store dV. - using Gmem_tile_dv = fmha::Gmem_tile_qkv; // The shared memory tile to swizzle dV. @@ -351,10 +351,10 @@ inline __device__ void compute_dq_dk_1xN_nl(const Params ¶ms) { using Smem_tile_o = typename Kernel_traits::Smem_tile_o; // The global memory tile to store dK. - using Gmem_tile_dk = fmha::Gmem_tile_qkv; // The shared memory tile to swizzle dK. @@ -368,7 +368,7 @@ inline __device__ void compute_dq_dk_1xN_nl(const Params ¶ms) { // The global memory tile to load dP, stored in S using Gmem_tile_s = Gmem_tile_mma_s; // The shared memory tile to transpose dP. - using Smem_tile_st = Smem_tile_mma_transposed; + using Smem_tile_st = Smem_tile_mma_transposed; using Noloop = Noloop_traits; @@ -551,9 +551,9 @@ inline __device__ void compute_dq_dk_1xN_nl(const Params ¶ms) { // Epilogue swizzle for dK Smem_tile_dk smem_dk(&smem_[0], tidx); smem_dk.store(acc_dk); - + __syncthreads(); - + uint4 dk_out[Smem_tile_dk::NUM_LDS]; smem_dk.load(dk_out); Qkv_params dk_params; diff --git a/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu b/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu index 9ebcbc59c..ead6ce980 100644 --- a/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu +++ b/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -31,7 +31,7 @@ using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>; template -__global__ +__global__ void fmha_fprop_fp16_128_64_sm80_kernel(Fused_multihead_attention_fprop_params params, const int num_full_heads, const int num_main_groups, @@ -61,10 +61,10 @@ void run_fmha_fp16_128_64_sm80(Launch_params(total_ctas, heads_total); return; } @@ -72,10 +72,10 @@ void run_fmha_fp16_128_64_sm80(Launch_params>>( launch_params.params, - launch_params.num_full_heads, - launch_params.num_main_groups, - launch_params.heads_last_wave, - launch_params.main_steps, + launch_params.num_full_heads, + launch_params.num_main_groups, + launch_params.heads_last_wave, + launch_params.main_steps, launch_params.rest_steps); FMHA_CHECK_CUDA(cudaPeekAtLastError()); diff --git a/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu b/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu index 448b9ad94..45d0211de 100644 --- a/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu +++ b/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -31,7 +31,7 @@ using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>; template -__global__ +__global__ void fmha_fprop_fp16_256_64_sm80_kernel(Fused_multihead_attention_fprop_params params, const int num_full_heads, const int num_main_groups, @@ -61,10 +61,10 @@ void run_fmha_fp16_256_64_sm80(Launch_params(total_ctas, heads_total); return; } @@ -72,10 +72,10 @@ void run_fmha_fp16_256_64_sm80(Launch_params>>( launch_params.params, - launch_params.num_full_heads, - launch_params.num_main_groups, - launch_params.heads_last_wave, - launch_params.main_steps, + launch_params.num_full_heads, + launch_params.num_main_groups, + launch_params.heads_last_wave, + launch_params.main_steps, launch_params.rest_steps); FMHA_CHECK_CUDA(cudaPeekAtLastError()); diff --git a/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu b/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu index f1f21dc32..a4bbae306 100644 --- a/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu +++ b/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -31,7 +31,7 @@ using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 4, 0x18u>; template -__global__ +__global__ void fmha_fprop_fp16_384_64_sm80_kernel(Fused_multihead_attention_fprop_params params, const int num_full_heads, const int num_main_groups, @@ -61,10 +61,10 @@ void run_fmha_fp16_384_64_sm80(Launch_params(total_ctas, heads_total); return; } @@ -72,10 +72,10 @@ void run_fmha_fp16_384_64_sm80(Launch_params>>( launch_params.params, - launch_params.num_full_heads, - launch_params.num_main_groups, - launch_params.heads_last_wave, - launch_params.main_steps, + launch_params.num_full_heads, + launch_params.num_main_groups, + launch_params.heads_last_wave, + launch_params.main_steps, launch_params.rest_steps); FMHA_CHECK_CUDA(cudaPeekAtLastError()); diff --git a/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu b/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu index e37689e8c..4b6986ca5 100644 --- a/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu +++ b/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -31,7 +31,7 @@ using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x00u>; template -__global__ +__global__ void fmha_fprop_fp16_512_64_sm80_kernel(Fused_multihead_attention_fprop_params params, const int total_heads) { @@ -39,7 +39,7 @@ void fmha_fprop_fp16_512_64_sm80_kernel(Fused_multihead_attention_fprop_params p } template -__global__ +__global__ void fmha_fprop_fp16_512_64_sm80_kernel_nl(Fused_multihead_attention_fprop_params params, const int num_full_heads, const int num_main_groups, @@ -107,10 +107,10 @@ void run_fmha_fp16_512_64_sm80_nl_(Launch_params(total_ctas, heads_total); return; } @@ -118,10 +118,10 @@ void run_fmha_fp16_512_64_sm80_nl_(Launch_params>>( launch_params.params, - launch_params.num_full_heads, - launch_params.num_main_groups, - launch_params.heads_last_wave, - launch_params.main_steps, + launch_params.num_full_heads, + launch_params.num_main_groups, + launch_params.heads_last_wave, + launch_params.main_steps, launch_params.rest_steps); FMHA_CHECK_CUDA(cudaPeekAtLastError()); diff --git a/apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h b/apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h index 5a040cf8f..f9d47c7f6 100644 --- a/apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h +++ b/apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h @@ -1,6 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -51,7 +51,7 @@ struct Gemm_Q_K_base { static constexpr int SMEM_BYTES_SOFTMAX = Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float) * 2; - __device__ inline Gemm_Q_K_base(char * smem_ptr_q, char * smem_ptr_k, const int tidx) + __device__ inline Gemm_Q_K_base(char * smem_ptr_q, char * smem_ptr_k, const int tidx) : smem_q(smem_ptr_q, tidx) , smem_k(smem_ptr_k, tidx) { @@ -87,11 +87,11 @@ struct Gemm_Q_K : public Gemm_Q_K_base { // Q | K / V // | O | SOFTMAX - static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE + static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE + std::max((SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE, Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX); - __device__ inline Gemm_Q_K(char * smem_, const int tidx) + __device__ inline Gemm_Q_K(char * smem_, const int tidx) : Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) { } @@ -146,10 +146,10 @@ struct Gemm_Q_K : public Gemm_Q_K_base { // Q | K/V + O + SOFTMAX static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE - + (SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE + + (SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX; - __device__ inline Gemm_Q_K(char * smem_, const int tidx) + __device__ inline Gemm_Q_K(char * smem_, const int tidx) : Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) { } @@ -258,7 +258,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i Gmem_tile_v gmem_v(params, 2, binfo, tidx); // The base pointer of smem_v; char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V]; - + // Allocate the shared memory tile loader for V. We use the same as K so be careful!!! Smem_tile_v smem_v(smem_v_, tidx); @@ -313,7 +313,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i __syncthreads(); } - // Load the fragments for K. + // Load the fragments for K. gemm_q_k.load_k(); // Create the object to do the softmax. @@ -461,7 +461,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i //////////////////////////////////////////////////////////////////////////////////////////////////// template -inline __device__ void device_1xN(const Params ¶ms, +inline __device__ void device_1xN(const Params ¶ms, const int num_full_heads, const int num_main_groups, const int main_group_size, diff --git a/apex/contrib/csrc/fmha/src/fmha_kernel.h b/apex/contrib/csrc/fmha/src/fmha_kernel.h index 63180b087..3700f9aff 100644 --- a/apex/contrib/csrc/fmha/src/fmha_kernel.h +++ b/apex/contrib/csrc/fmha/src/fmha_kernel.h @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -73,20 +73,20 @@ struct BlockInfoPadded { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct Noloop_traits{ // Interpretation of Cta_tile dims, i.e. Cta_tile_p: enum{ STEP = Cta_tile::M }; enum{ SEQLEN = Cta_tile::N }; template - inline __device__ Noloop_traits(const int bidc, const Block_info& binfo) + inline __device__ Noloop_traits(const int bidc, const Block_info& binfo) : bidc_(bidc) { const int seqlen = binfo.actual_seqlen; const int steps = (seqlen + STEP - 1) / STEP; const int steps_per_chunk = (steps + CHUNKS - 1) / CHUNKS; - const int step_begin = bidc_ * steps_per_chunk; + const int step_begin = bidc_ * steps_per_chunk; const int step_end = min(steps, (bidc_ + 1) * steps_per_chunk); const int actual_steps = max(0, step_end - step_begin); loop_offset_ = step_begin; @@ -94,7 +94,7 @@ struct Noloop_traits{ } - template + template inline __device__ void move_all(Tiles & ... tiles) const { using expand_type = int[]; for( int s = 0; s < loop_offset_; s++ ) { @@ -130,7 +130,7 @@ std::tuple work_dist(const int total_ctas, const constexpr int STEPS_PER_HEAD = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M; const int num_full_heads = heads_total / total_ctas; - const int heads_last_wave = heads_total % total_ctas; + const int heads_last_wave = heads_total % total_ctas; int num_main_groups = 0; int main_steps = 0; diff --git a/apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu b/apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu index 8e4b9efc3..273707e4f 100644 --- a/apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu +++ b/apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -51,7 +51,7 @@ __global__ __launch_bounds__(THREADS) void fmha_noloop_reduce_kernel(void *__res // The offset in bytes in dQKV to the dKV part for non-interleaved heads enum { OUT_OFFSET_KV_BYTES = HIDDEN_SIZE * sizeof(T) }; - static_assert(BYTES_PER_ROW == HIDDEN_SIZE * 2 * sizeof(T)); + static_assert(BYTES_PER_ROW == HIDDEN_SIZE * 2 * sizeof(T)); // Size in bytes of the input tile enum { BYTES_PER_TILE = CHUNKS * BYTES_PER_ROW }; diff --git a/apex/contrib/csrc/fmha/src/fmha_utils.h b/apex/contrib/csrc/fmha/src/fmha_utils.h index de07cc78e..e2a5d2d37 100644 --- a/apex/contrib/csrc/fmha/src/fmha_utils.h +++ b/apex/contrib/csrc/fmha/src/fmha_utils.h @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE diff --git a/apex/contrib/csrc/groupbn/batch_norm.cu b/apex/contrib/csrc/groupbn/batch_norm.cu index 92eb11fbe..4e4fec47b 100644 --- a/apex/contrib/csrc/groupbn/batch_norm.cu +++ b/apex/contrib/csrc/groupbn/batch_norm.cu @@ -234,9 +234,9 @@ std::vector nhwc_bn_bwd( const float epsilon, const bool fuse_relu, void * my_data, - void * pair_data, - void * pair_data2, - void * pair_data3, + void * pair_data, + void * pair_data2, + void * pair_data3, const int bn_group, const at::Tensor& magic_tensor, const int occupancy, @@ -334,7 +334,7 @@ int nhwc_bn_fwd_occupancy() { int nhwc_bn_bwd_occupancy() { int device_id=-1; cudaGetDevice(&device_id); - + //max occupancy supported by the code is 2 return NhwcBatchNorm::smem_driven_bwd_occupancy(device_id, 2); } diff --git a/apex/contrib/csrc/groupbn/batch_norm.h b/apex/contrib/csrc/groupbn/batch_norm.h index e52751bce..94b80f70f 100644 --- a/apex/contrib/csrc/groupbn/batch_norm.h +++ b/apex/contrib/csrc/groupbn/batch_norm.h @@ -861,7 +861,7 @@ void NhwcBatchNorm::fwd(cudaStream_t stream, bool use_relu, void* my_data, void* _fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, occupancy, coop); } -void NhwcBatchNorm::dgrad(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, +void NhwcBatchNorm::dgrad(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) { bool ptrs_are_set = X_tensor_desc_ != nullptr diff --git a/apex/contrib/csrc/groupbn/batch_norm_add_relu.cu b/apex/contrib/csrc/groupbn/batch_norm_add_relu.cu index d3cc61523..ed79ec850 100644 --- a/apex/contrib/csrc/groupbn/batch_norm_add_relu.cu +++ b/apex/contrib/csrc/groupbn/batch_norm_add_relu.cu @@ -243,9 +243,9 @@ std::vector nhwc_bn_addrelu_bwd( const float momentum, const float epsilon, void * my_data, - void * pair_data, - void * pair_data2, - void * pair_data3, + void * pair_data, + void * pair_data2, + void * pair_data3, const int bn_group, const at::Tensor& magic_tensor, const int occupancy, @@ -338,7 +338,7 @@ std::vector nhwc_bn_addrelu_bwd( int nhwc_bn_addrelu_fwd_occupancy() { int device_id=-1; cudaGetDevice(&device_id); - + //max occupancy supported by the code is 2 return NhwcBatchNormAddRelu::smem_driven_fwd_occupancy(device_id, 2); } diff --git a/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp b/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp index b47c9daa5..c8b045b1a 100644 --- a/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp +++ b/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp @@ -58,7 +58,7 @@ void index_mul_2d_float_forward( at::Tensor &out, const at::Tensor &in1, const at::Tensor &in2, - const at::Tensor &idx1) + const at::Tensor &idx1) { return index_mul_2d_float_foward_cuda(out, in1, in2, idx1); } @@ -69,7 +69,7 @@ void index_mul_2d_float_backward( const at::Tensor &grad_out, const at::Tensor &in1, const at::Tensor &in2, - const at::Tensor &idx1) + const at::Tensor &idx1) { return index_mul_2d_float_backward_cuda(grad_in1, grad_in2, grad_out, in1, in2, idx1); } @@ -92,7 +92,7 @@ void index_mul_2d_half_forward( at::Tensor &out, const at::Tensor &in1, const at::Tensor &in2, - const at::Tensor &idx1) + const at::Tensor &idx1) { return index_mul_2d_half_foward_cuda(out, in1, in2, idx1); } @@ -103,7 +103,7 @@ void index_mul_2d_half_backward( const at::Tensor &grad_out, const at::Tensor &in1, const at::Tensor &in2, - const at::Tensor &idx1) + const at::Tensor &idx1) { return index_mul_2d_half_backward_cuda(grad_in1, grad_in2, grad_out, in1, in2, idx1); } diff --git a/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu b/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu index 4f18da3bf..5fa2591c1 100644 --- a/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu +++ b/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu @@ -9,11 +9,11 @@ __global__ void index_mul_2d_float_dim64( - float *out, - const float *in1, - const float *in2, - const int64_t *idx1, - const int64_t size) + float *out, + const float *in1, + const float *in2, + const int64_t *idx1, + const int64_t size) { const int tidx = threadIdx.x; const int tidy = threadIdx.y; @@ -24,7 +24,7 @@ __global__ void index_mul_2d_float_dim64( if (start_idx < size) { int64_t vec_idx1 = (idx1[start_idx] * fea_dim) / 4 + tidx; int64_t vec_idx2 = (start_idx * fea_dim) / 4 + tidx; - + float4 res, src1, src2; src1 = reinterpret_cast(in1)[vec_idx1]; src2 = reinterpret_cast(in2)[vec_idx2]; @@ -37,12 +37,12 @@ __global__ void index_mul_2d_float_dim64( } __global__ void index_mul_2d_float( - float *out, - const float *in1, - const float *in2, - const int64_t *idx1, + float *out, + const float *in1, + const float *in2, + const int64_t *idx1, const int64_t size, - const int64_t fea_dim) + const int64_t fea_dim) { const int tidx = threadIdx.x; const int tidy = threadIdx.y; @@ -53,7 +53,7 @@ __global__ void index_mul_2d_float( if (start_idx < size) { int64_t vec_idx1 = (idx1[start_idx] * fea_dim); int64_t vec_idx2 = (start_idx * fea_dim); - + for (int i = tidx; i < fea_dim; i += stride) { out[vec_idx2 + i] = in1[vec_idx1 + i] * in2[vec_idx2 + i]; } @@ -61,12 +61,12 @@ __global__ void index_mul_2d_float( } __global__ void index_mul_2d_half( - at::Half *out, - const at::Half *in1, - const at::Half *in2, - const int64_t *idx1, + at::Half *out, + const at::Half *in1, + const at::Half *in2, + const int64_t *idx1, const int64_t size, - const int64_t fea_dim) + const int64_t fea_dim) { const int tidx = threadIdx.x; const int tidy = threadIdx.y; @@ -77,7 +77,7 @@ __global__ void index_mul_2d_half( if (start_idx < size) { int64_t vec_idx1 = (idx1[start_idx] * fea_dim); int64_t vec_idx2 = (start_idx * fea_dim); - + for (int i = tidx; i < fea_dim; i += stride) { out[vec_idx2 + i] = at::Half(static_cast(in1[vec_idx1 + i]) * static_cast(in2[vec_idx2 + i])); } @@ -85,13 +85,13 @@ __global__ void index_mul_2d_half( } __global__ void index_mul_2d_grad_float_dim64( - float *grad_in1, + float *grad_in1, float *grad_in2, - const float *grad_out, + const float *grad_out, const float *in1, const float *in2, - const int64_t *idx1, - const int64_t size) + const int64_t *idx1, + const int64_t size) { const int tidx = threadIdx.x; const int tidy = threadIdx.y; @@ -116,19 +116,19 @@ __global__ void index_mul_2d_grad_float_dim64( dst_grad_in2.y = src_grad_out.y * src_in1.y; dst_grad_in2.z = src_grad_out.z * src_in1.z; dst_grad_in2.w = src_grad_out.w * src_in1.w; - reinterpret_cast(grad_in2)[vec_idx2] = dst_grad_in2; + reinterpret_cast(grad_in2)[vec_idx2] = dst_grad_in2; } } __global__ void index_mul_2d_grad_float( - float *grad_in1, + float *grad_in1, float *grad_in2, - const float *grad_out, + const float *grad_out, const float *in1, const float *in2, - const int64_t *idx1, + const int64_t *idx1, const int64_t size, - const int64_t fea_dim) + const int64_t fea_dim) { const int tidx = threadIdx.x; const int tidy = threadIdx.y; @@ -151,14 +151,14 @@ __global__ void index_mul_2d_grad_float( } __global__ void index_mul_2d_grad_half( - at::Half *grad_in1, + at::Half *grad_in1, at::Half *grad_in2, - const at::Half *grad_out, + const at::Half *grad_out, const at::Half *in1, const at::Half *in2, - const int64_t *idx1, + const int64_t *idx1, const int64_t size, - const int64_t fea_dim) + const int64_t fea_dim) { const int tidx = threadIdx.x; const int tidy = threadIdx.y; @@ -190,7 +190,7 @@ __global__ void index_mul_2d_grad_grad_float_dim64( const float *in1, const float *in2, const int64_t *idx1, - const int64_t size) + const int64_t size) { const int tidx = threadIdx.x; const int tidy = threadIdx.y; @@ -198,7 +198,7 @@ __global__ void index_mul_2d_grad_grad_float_dim64( const int start_idx = bidx * blockDim.y + tidy; constexpr int fea_dim = 64; - if (start_idx < size) { + if (start_idx < size) { int64_t vec_idx1 = (idx1[start_idx] * fea_dim) / 4 + tidx; int64_t vec_idx2 = (start_idx * fea_dim) / 4 + tidx; @@ -238,15 +238,15 @@ __global__ void index_mul_2d_grad_grad_float( const float *in2, const int64_t *idx1, const int64_t size, - const int64_t fea_dim) + const int64_t fea_dim) { const int tidx = threadIdx.x; const int tidy = threadIdx.y; const int bidx = blockIdx.x; const int start_idx = bidx * blockDim.y + tidy; const int stride = blockDim.x; - - if (start_idx < size) { + + if (start_idx < size) { int64_t vec_idx1 = idx1[start_idx] * fea_dim; int64_t vec_idx2 = start_idx * fea_dim; @@ -274,15 +274,15 @@ __global__ void index_mul_2d_grad_grad_half( const at::Half *in2, const int64_t *idx1, const int64_t size, - const int64_t fea_dim) + const int64_t fea_dim) { const int tidx = threadIdx.x; const int tidy = threadIdx.y; const int bidx = blockIdx.x; const int start_idx = bidx * blockDim.y + tidy; const int stride = blockDim.x; - - if (start_idx < size) { + + if (start_idx < size) { int64_t vec_idx1 = idx1[start_idx] * fea_dim; int64_t vec_idx2 = start_idx * fea_dim; @@ -310,7 +310,7 @@ void index_mul_2d_float_foward_cuda(at::Tensor &out, } cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - + if (fea_dim == 64) { const int BLOCK_THREADS_DIMX = 16; const int BLOCK_THREADS_DIMY = 16; @@ -318,7 +318,7 @@ void index_mul_2d_float_foward_cuda(at::Tensor &out, dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1); index_mul_2d_float_dim64<<>>( - out.data_ptr(), in1.data_ptr(), in2.data_ptr(), + out.data_ptr(), in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size); } else { const int BLOCK_THREADS_DIMX = 32; @@ -327,7 +327,7 @@ void index_mul_2d_float_foward_cuda(at::Tensor &out, dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1); index_mul_2d_float<<>>( - out.data_ptr(), in1.data_ptr(), in2.data_ptr(), + out.data_ptr(), in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size, fea_dim); } @@ -355,7 +355,7 @@ void index_mul_2d_float_backward_cuda(at::Tensor &grad_in1, dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1); index_mul_2d_grad_float_dim64<<>>( - grad_in1.data_ptr(), grad_in2.data_ptr(), grad_out.data_ptr(), + grad_in1.data_ptr(), grad_in2.data_ptr(), grad_out.data_ptr(), in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size); AT_CUDA_CHECK(cudaGetLastError()); @@ -366,7 +366,7 @@ void index_mul_2d_float_backward_cuda(at::Tensor &grad_in1, dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1); index_mul_2d_grad_float<<>>( - grad_in1.data_ptr(), grad_in2.data_ptr(), grad_out.data_ptr(), + grad_in1.data_ptr(), grad_in2.data_ptr(), grad_out.data_ptr(), in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size, fea_dim); } } @@ -395,8 +395,8 @@ void index_mul_2d_float_backward_backward_cuda(at::Tensor &grad_grad_out, dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1); index_mul_2d_grad_grad_float_dim64<<>>( - grad_grad_out.data_ptr(), grad_in1.data_ptr(), grad_in2.data_ptr(), - grad_out.data_ptr(), grad_grad_in1.data_ptr(), grad_grad_in2.data_ptr(), + grad_grad_out.data_ptr(), grad_in1.data_ptr(), grad_in2.data_ptr(), + grad_out.data_ptr(), grad_grad_in1.data_ptr(), grad_grad_in2.data_ptr(), in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size); } else { const int BLOCK_THREADS_DIMX = 32; @@ -405,9 +405,9 @@ void index_mul_2d_float_backward_backward_cuda(at::Tensor &grad_grad_out, dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1); index_mul_2d_grad_grad_float<<>>( - grad_grad_out.data_ptr(), grad_in1.data_ptr(), grad_in2.data_ptr(), - grad_out.data_ptr(), grad_grad_in1.data_ptr(), grad_grad_in2.data_ptr(), - in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size, fea_dim); + grad_grad_out.data_ptr(), grad_in1.data_ptr(), grad_in2.data_ptr(), + grad_out.data_ptr(), grad_grad_in1.data_ptr(), grad_grad_in2.data_ptr(), + in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size, fea_dim); } AT_CUDA_CHECK(cudaGetLastError()); @@ -424,14 +424,14 @@ void index_mul_2d_half_foward_cuda(at::Tensor &out, } cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - + const int BLOCK_THREADS_DIMX = 32; const int BLOCK_THREADS_DIMY = 8; const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1); index_mul_2d_half<<>>( - out.data_ptr(), in1.data_ptr(), in2.data_ptr(), + out.data_ptr(), in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size, fea_dim); AT_CUDA_CHECK(cudaGetLastError()); @@ -457,7 +457,7 @@ void index_mul_2d_half_backward_cuda(at::Tensor &grad_in1, dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1); index_mul_2d_grad_half<<>>( - grad_in1.data_ptr(), grad_in2.data_ptr(), grad_out.data_ptr(), + grad_in1.data_ptr(), grad_in2.data_ptr(), grad_out.data_ptr(), in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size, fea_dim); } @@ -484,9 +484,9 @@ void index_mul_2d_half_backward_backward_cuda(at::Tensor &grad_grad_out, dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1); index_mul_2d_grad_grad_half<<>>( - grad_grad_out.data_ptr(), grad_in1.data_ptr(), grad_in2.data_ptr(), - grad_out.data_ptr(), grad_grad_in1.data_ptr(), grad_grad_in2.data_ptr(), - in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size, fea_dim); + grad_grad_out.data_ptr(), grad_in1.data_ptr(), grad_in2.data_ptr(), + grad_out.data_ptr(), grad_grad_in1.data_ptr(), grad_grad_in2.data_ptr(), + in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size, fea_dim); AT_CUDA_CHECK(cudaGetLastError()); } diff --git a/apex/contrib/csrc/layer_norm/ln.h b/apex/contrib/csrc/layer_norm/ln.h index 07392a192..9436dc777 100644 --- a/apex/contrib/csrc/layer_norm/ln.h +++ b/apex/contrib/csrc/layer_norm/ln.h @@ -8,7 +8,7 @@ namespace layer_norm { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct LaunchParams{ size_t workspace_bytes; diff --git a/apex/contrib/csrc/layer_norm/ln_api.cpp b/apex/contrib/csrc/layer_norm/ln_api.cpp index 30e4a5fec..575a96856 100644 --- a/apex/contrib/csrc/layer_norm/ln_api.cpp +++ b/apex/contrib/csrc/layer_norm/ln_api.cpp @@ -7,13 +7,13 @@ Supported Type combinations: -input compute weights output +input compute weights output ======================================= -fp32 fp32 fp32 fp32 -fp16 fp32 fp16 fp16 -bf16 fp32 bf16 bf16 -fp32 fp32 fp16 fp16 -fp32 fp32 bf16 bf16 +fp32 fp32 fp32 fp32 +fp16 fp32 fp16 fp16 +bf16 fp32 bf16 bf16 +fp32 fp32 fp16 fp16 +fp32 fp32 bf16 bf16 Remarks: Output type = Weight type @@ -240,7 +240,7 @@ std::vector ln_bwd(const at::Tensor &dz, // BxSxhidden_size //////////////////////////////////////////////////////////////////////////////////////////////////// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.doc() = "CUDA LayerNorm"; + m.doc() = "CUDA LayerNorm"; m.def("ln_fwd", &ln_fwd, "Run LayerNorm forward kernel"); m.def("ln_bwd", &ln_bwd, "Run LayerNorm backward kernel"); } diff --git a/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh b/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh index 8595f5ed4..2ced0c492 100644 --- a/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh +++ b/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh @@ -3,7 +3,7 @@ namespace layer_norm { template -__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) +__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_kernel(layer_norm::BwdParams params) { enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; @@ -262,7 +262,7 @@ void ln_bwd_finalize_kernel(BwdParams params) memset(&dbeta_local, 0, sizeof(dbeta_local)); memset(&dgamma_local, 0, sizeof(dgamma_local)); - // Load beta and gamma transposed + // Load beta and gamma transposed if(read_row < Kernel_traits::ROWS_PER_CTA){ dbeta_local.load_from(smem_beta, read_idx); dgamma_local.load_from(smem_gamma, read_idx); diff --git a/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu b/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu index 3893d4e0c..daf7f2215 100644 --- a/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu +++ b/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu @@ -11,10 +11,10 @@ template< typename output_t, typename compute_t, typename index_t, - int HIDDEN_SIZE, - int CTAS_PER_ROW, - int WARPS_M, - int WARPS_N, + int HIDDEN_SIZE, + int CTAS_PER_ROW, + int WARPS_M, + int WARPS_N, int BYTES_PER_LDG_MAIN, int BYTES_PER_LDG_FINAL > @@ -42,9 +42,9 @@ void launch_(LaunchParams &launch_params, const bool configure_params launch_params.workspace_bytes = 0; if(Kernel_traits::CTAS_PER_ROW > 1) { launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col - * Kernel_traits::WARPS_M - * Kernel_traits::CTAS_PER_ROW + launch_params.workspace_bytes = launch_params.params.ctas_per_col + * Kernel_traits::WARPS_M + * Kernel_traits::CTAS_PER_ROW * sizeof(typename Kernel_traits::reduce_t) * 2; } diff --git a/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu b/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu index 62ff78ee3..71d657091 100644 --- a/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu +++ b/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu @@ -11,10 +11,10 @@ template< typename output_t, typename compute_t, typename index_t, - int HIDDEN_SIZE, - int CTAS_PER_ROW, - int WARPS_M, - int WARPS_N, + int HIDDEN_SIZE, + int CTAS_PER_ROW, + int WARPS_M, + int WARPS_N, int BYTES_PER_LDG > void launch_(LaunchParams &launch_params, const bool configure_params){ @@ -41,9 +41,9 @@ void launch_(LaunchParams &launch_params, const bool configure_params launch_params.workspace_bytes = 0; if(Kernel_traits::CTAS_PER_ROW > 1) { launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col - * Kernel_traits::WARPS_M - * Kernel_traits::CTAS_PER_ROW + launch_params.workspace_bytes = launch_params.params.ctas_per_col + * Kernel_traits::WARPS_M + * Kernel_traits::CTAS_PER_ROW * sizeof(typename Kernel_traits::Stats::stats_t) * 2; } diff --git a/apex/contrib/csrc/layer_norm/ln_fwd_kernels.cuh b/apex/contrib/csrc/layer_norm/ln_fwd_kernels.cuh index 64e72974f..ccf8ea78b 100644 --- a/apex/contrib/csrc/layer_norm/ln_fwd_kernels.cuh +++ b/apex/contrib/csrc/layer_norm/ln_fwd_kernels.cuh @@ -5,7 +5,7 @@ namespace layer_norm { template -__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) +__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_kernel(FwdParams params) { enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; diff --git a/apex/contrib/csrc/layer_norm/ln_kernel_traits.h b/apex/contrib/csrc/layer_norm/ln_kernel_traits.h index ed745c5ee..e88ecb09e 100644 --- a/apex/contrib/csrc/layer_norm/ln_kernel_traits.h +++ b/apex/contrib/csrc/layer_norm/ln_kernel_traits.h @@ -48,7 +48,7 @@ template< struct Kernel_traits_finalize : public Base { enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP }; static_assert((int) ROWS_PER_CTA <= (int) Base::THREADS_PER_WARP); - // Bytes per global load from the input. + // Bytes per global load from the input. enum { BYTES_PER_LDG = BYTES_PER_LDG_ }; // Number of elements fetched by a global load. enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) }; @@ -64,7 +64,7 @@ struct Kernel_traits_finalize : public Base { enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG }; // Shared memory size to coalsece the CTA result. enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG }; - // Shared memory requirement per CTA. + // Shared memory requirement per CTA. enum { SMEM_BYTES_PER_CTA = 2 * SMEM_BYTES_TRANSPOSE + 2 * SMEM_BYTES_OUTPUT }; // The type of the reducer. @@ -73,7 +73,7 @@ struct Kernel_traits_finalize : public Base { // Condition for the whole CTA to participate in syncthreads. static_assert(COLS % Base::THREADS_PER_WARP == 0); enum { CTAS = COLS / Base::THREADS_PER_WARP }; -}; +}; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -84,18 +84,18 @@ template< typename output_t_, typename compute_t_, typename index_t_, - uint32_t HIDDEN_SIZE_, - uint32_t CTAS_PER_ROW_, - uint32_t WARPS_M_, - uint32_t WARPS_N_, + uint32_t HIDDEN_SIZE_, + uint32_t CTAS_PER_ROW_, + uint32_t WARPS_M_, + uint32_t WARPS_N_, uint32_t BYTES_PER_LDG_ = 16, typename Base = Kernel_traits_base< HIDDEN_SIZE_, - weight_t_, - input_t_, - output_t_, - compute_t_, - index_t_, + weight_t_, + input_t_, + output_t_, + compute_t_, + index_t_, WARPS_M_*WARPS_N_*THREADS_PER_WARP > > @@ -126,7 +126,7 @@ struct Kernel_traits : public Base { static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1); using reduce_t = typename layer_norm::TypeToVec2::Type; - using Reducer = layer_norm::Reducer; + using Reducer = layer_norm::Reducer; enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES }; enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD }; diff --git a/apex/contrib/csrc/layer_norm/ln_utils.cuh b/apex/contrib/csrc/layer_norm/ln_utils.cuh index e18d36de7..7d66c82fb 100644 --- a/apex/contrib/csrc/layer_norm/ln_utils.cuh +++ b/apex/contrib/csrc/layer_norm/ln_utils.cuh @@ -268,7 +268,7 @@ struct Zeros{ } }; -template<> +template<> struct Zeros{ static inline __device__ float2 get() { return make_float2(0.f, 0.f); @@ -380,7 +380,7 @@ struct Reducer : public Reducer { template inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) - : Base(params, bidm, bidn, warp_m, warp_n, lane, smem) + : Base(params, bidm, bidn, warp_m, warp_n, lane, smem) , inter_cta_(params, bidm, bidn) , bidn_(bidn) // CTA id within the group. , w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW) @@ -428,7 +428,7 @@ struct Reducer { enum { THREADS_PER_WARP = 32 }; template - inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) + inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) : warp_n_(warp_n) , lane_(lane) { @@ -454,7 +454,7 @@ struct Reducer { #pragma unroll for( int it = THREADS_PER_WARP / 2; it > 0; it /= 2 ) { data = op(data, warp_shuffle_down(data, it)); - } + } return data; } int warp_n_; @@ -476,8 +476,8 @@ struct Reducer : public Reducer { enum { THREADS_PER_WARP = 32 }; template - inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) - : Base(params, bidm, bidn, warp_m, warp_n, lane, smem) + inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) + : Base(params, bidm, bidn, warp_m, warp_n, lane, smem) , use0_(true) { smem0_ = &static_cast(smem)[warp_m * WARPS_N]; @@ -528,12 +528,12 @@ struct Reducer : public Reducer { }; //////////////////////////////////////////////////////////////////////////////////////////////////// - + template inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a, int num_active){ //Assume at least leftmost is valid and init: step = next_pow2(num_active) / 2 (might get NaN otherwise) int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1); - + #pragma unroll for( int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2 ) { // Exchange @@ -570,7 +570,7 @@ struct Stats { enum { SMEM_BYTES = BlockStats::SMEM_BYTES }; template - inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) + inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) : inter_cta_(params, bidm, bidn) , block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem) , bidn_(bidn) // CTA id within the group. @@ -604,7 +604,7 @@ struct Stats { // Assume CTA group size in N less than 32, such that we can finalize with a single warp. static_assert(CTAS_PER_ROW <= 32); - // Every warp does the final reduction locally. + // Every warp does the final reduction locally. if( lane_ < CTAS_PER_ROW ) { stats_t result = workspace[lane_]; n = ELTS_PER_ROW_PER_CTA; @@ -638,7 +638,7 @@ struct Stats { enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 }; template - inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) + inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem) , use0_(true) { @@ -697,7 +697,7 @@ struct Stats { enum { SMEM_BYTES = 0 }; template - inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) + inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem) { } diff --git a/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu b/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu index 510a291b9..146197afe 100644 --- a/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu @@ -89,64 +89,64 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, // Input Linear Q Fwd TORCH_CUDABLAS_CHECK((hipblasGemmEx(handle, - CUBLAS_OP_T, + CUBLAS_OP_T, CUBLAS_OP_N, - output_lin_q_dim, - batches_q, + output_lin_q_dim, + batches_q, embed_dim, static_cast(&alpha), static_cast(input_weights_q.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, static_cast(inputs_q.data_ptr()), - HIP_R_16F, - embed_dim, + HIP_R_16F, + embed_dim, static_cast(&beta), q_lin_results_ptr, - HIP_R_16F, + HIP_R_16F, output_lin_q_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ ))); - + // Input Linear KV Fwd TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_T, + CUBLAS_OP_T, CUBLAS_OP_N, - output_lin_kv_dim, - batches_kv, + output_lin_kv_dim, + batches_kv, embed_dim, static_cast(&alpha), static_cast(input_weights_kv.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, static_cast(inputs_kv.data_ptr()), - HIP_R_16F, - embed_dim, + HIP_R_16F, + embed_dim, static_cast(&beta), k_lin_results_ptr, - HIP_R_16F, + HIP_R_16F, output_lin_kv_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ )); // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) - gemm_switch_fp32accum( a_layout_t, - b_layout_n, + gemm_switch_fp32accum( a_layout_t, + b_layout_n, k_seq_len, q_seq_len, head_dim, - scale, - static_cast(k_lin_results_ptr), - lead_dim_kv, - batch_stride_kv, + scale, + static_cast(k_lin_results_ptr), + lead_dim_kv, + batch_stride_kv, static_cast(q_lin_results_ptr), - lead_dim_q, - batch_stride_q, - beta, - static_cast(softmax_results_ptr), - k_seq_len, + lead_dim_q, + batch_stride_q, + beta, + static_cast(softmax_results_ptr), + k_seq_len, k_seq_len*q_seq_len, attn_batches ); @@ -183,42 +183,42 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, } // Matmul2 - gemm_switch_fp32accum( a_layout_n, - b_layout_n, - head_dim, - q_seq_len, - k_seq_len, - alpha, - static_cast(v_lin_results_ptr), - lead_dim_kv, - batch_stride_kv, - (is_training) ? static_cast(dropout_results.data_ptr()) : static_cast(softmax_results.data_ptr()) , - k_seq_len, - k_seq_len*q_seq_len, - beta, - static_cast(matmul2_results.data_ptr()), - head_dim*attn_batches, - head_dim, + gemm_switch_fp32accum( a_layout_n, + b_layout_n, + head_dim, + q_seq_len, + k_seq_len, + alpha, + static_cast(v_lin_results_ptr), + lead_dim_kv, + batch_stride_kv, + (is_training) ? static_cast(dropout_results.data_ptr()) : static_cast(softmax_results.data_ptr()) , + k_seq_len, + k_seq_len*q_seq_len, + beta, + static_cast(matmul2_results.data_ptr()), + head_dim*attn_batches, + head_dim, attn_batches ); // Output Linear TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_T, + CUBLAS_OP_T, CUBLAS_OP_N, - embed_dim, - batches_q, + embed_dim, + batches_q, embed_dim, static_cast(&alpha), static_cast(output_weights.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, static_cast(matmul2_results.data_ptr()), - HIP_R_16F, - embed_dim, + HIP_R_16F, + embed_dim, static_cast(&beta), static_cast(outputs.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ @@ -296,8 +296,8 @@ std::vector bwd_cuda( char a_layout_n{'n'}; char a_layout_t{'t'}; char b_layout_n{'n'}; - char b_layout_t{'t'}; - + char b_layout_t{'t'}; + rocblas_int flags = 0; //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); @@ -313,89 +313,89 @@ std::vector bwd_cuda( // Output Linear Dgrad TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_N, CUBLAS_OP_N, - embed_dim, - batches_q, + CUBLAS_OP_N, + embed_dim, + batches_q, embed_dim, static_cast(&alpha), static_cast(output_weights.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, static_cast(output_grads.data_ptr()), - HIP_R_16F, - embed_dim, + HIP_R_16F, + embed_dim, static_cast(&beta), static_cast(output_lin_grads.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ )); - + // Output Linear Wgrad TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_N, + CUBLAS_OP_N, CUBLAS_OP_T, - embed_dim, embed_dim, - batches_q, + embed_dim, + batches_q, static_cast(&alpha), static_cast(matmul2_results.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, static_cast(output_grads.data_ptr()), - HIP_R_16F, - embed_dim, + HIP_R_16F, + embed_dim, static_cast(&beta), static_cast(output_weight_grads.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ )); - + // MatMul2 Dgrad1 - gemm_switch_fp32accum( a_layout_t, - b_layout_n, + gemm_switch_fp32accum( a_layout_t, + b_layout_n, k_seq_len, q_seq_len, head_dim, - alpha, + alpha, static_cast(v_lin_results_ptr), - lead_dim_kv, + lead_dim_kv, batch_stride_kv, static_cast(output_lin_grads.data_ptr()), - head_dim*attn_batches, - head_dim, - beta, + head_dim*attn_batches, + head_dim, + beta, static_cast(matmul2_grads.data_ptr()), - k_seq_len, + k_seq_len, k_seq_len*q_seq_len, attn_batches ); - + // Matmul2 Dgrad2 - gemm_switch_fp32accum( a_layout_n, - b_layout_t, - head_dim, - k_seq_len, - q_seq_len, - alpha, + gemm_switch_fp32accum( a_layout_n, + b_layout_t, + head_dim, + k_seq_len, + q_seq_len, + alpha, static_cast(output_lin_grads.data_ptr()), - head_dim*attn_batches, - head_dim, + head_dim*attn_batches, + head_dim, static_cast(dropout_results.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - v_lin_grads_ptr, - lead_dim_kv, - batch_stride_kv, + k_seq_len, + k_seq_len*q_seq_len, + beta, + v_lin_grads_ptr, + lead_dim_kv, + batch_stride_kv, attn_batches ); - // Apply Dropout Mask and Scale by Dropout Probability + // Apply Dropout Mask and Scale by Dropout Probability apex_masked_scale_cuda( static_cast(matmul2_grads.data_ptr()), static_cast(matmul2_grads.data_ptr()), @@ -413,74 +413,74 @@ std::vector bwd_cuda( assert(softmax_success); // Matmul1 Dgrad1 - gemm_switch_fp32accum( a_layout_n, - b_layout_n, - head_dim, - q_seq_len, - k_seq_len, - scale, - k_lin_results_ptr, - lead_dim_kv, - batch_stride_kv, + gemm_switch_fp32accum( a_layout_n, + b_layout_n, + head_dim, + q_seq_len, + k_seq_len, + scale, + k_lin_results_ptr, + lead_dim_kv, + batch_stride_kv, static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - q_lin_grads_ptr, - lead_dim_q, - batch_stride_q, + k_seq_len, + k_seq_len*q_seq_len, + beta, + q_lin_grads_ptr, + lead_dim_q, + batch_stride_q, attn_batches ); - + // Matmul1 Dgrad2 - gemm_switch_fp32accum( a_layout_n, - b_layout_t, - head_dim, - k_seq_len, - q_seq_len, - scale, - q_lin_results_ptr, - lead_dim_q, - batch_stride_q, + gemm_switch_fp32accum( a_layout_n, + b_layout_t, + head_dim, + k_seq_len, + q_seq_len, + scale, + q_lin_results_ptr, + lead_dim_q, + batch_stride_q, static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - k_lin_grads_ptr, - lead_dim_kv, - batch_stride_kv, + k_seq_len, + k_seq_len*q_seq_len, + beta, + k_lin_grads_ptr, + lead_dim_kv, + batch_stride_kv, attn_batches ); - // Input Linear Q Dgrad + // Input Linear Q Dgrad TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_N, + CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, - batches_q, + batches_q, output_lin_q_dim, static_cast(&alpha), static_cast(input_weights_q.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, static_cast(q_lin_grads_ptr), - HIP_R_16F, - output_lin_q_dim, + HIP_R_16F, + output_lin_q_dim, static_cast(&beta), static_cast(input_q_grads.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ )); - - // Input Linear Q Wgrad + + // Input Linear Q Wgrad TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_N, + CUBLAS_OP_N, CUBLAS_OP_T, - embed_dim, + embed_dim, output_lin_q_dim, - batches_q, + batches_q, static_cast(&alpha), static_cast(inputs_q.data_ptr()), HIP_R_16F, @@ -490,41 +490,41 @@ std::vector bwd_cuda( output_lin_q_dim, static_cast(&beta), static_cast(input_weight_q_grads.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ )); - - // Input Linear KV Dgrad + + // Input Linear KV Dgrad TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_N, + CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, - batches_kv, + batches_kv, output_lin_kv_dim, static_cast(&alpha), static_cast(input_weights_kv.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, static_cast(k_lin_grads_ptr), - HIP_R_16F, - output_lin_kv_dim, + HIP_R_16F, + output_lin_kv_dim, static_cast(&beta), static_cast(input_kv_grads.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ )); - - // Input Linear KV Wgrad + + // Input Linear KV Wgrad TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_N, + CUBLAS_OP_N, CUBLAS_OP_T, - embed_dim, + embed_dim, output_lin_kv_dim, - batches_kv, + batches_kv, static_cast(&alpha), static_cast(inputs_kv.data_ptr()), HIP_R_16F, @@ -534,22 +534,22 @@ std::vector bwd_cuda( output_lin_kv_dim, static_cast(&beta), static_cast(input_weight_kv_grads.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ )); // TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); - return { - input_q_grads, - input_kv_grads, - input_weight_q_grads, - input_weight_kv_grads, + return { + input_q_grads, + input_kv_grads, + input_weight_q_grads, + input_weight_kv_grads, output_weight_grads }; } } // end namespace rocblas_gemmex -} // end namespace encdec +} // end namespace encdec } // end namespace multihead_attn diff --git a/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu b/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu index 56da36dcd..1c92b8a9b 100644 --- a/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu @@ -25,8 +25,8 @@ std::vector fwd_cuda( bool use_time_mask, bool is_training, int heads, - torch::Tensor const& inputs_q, - torch::Tensor const& inputs_kv, + torch::Tensor const& inputs_q, + torch::Tensor const& inputs_kv, torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights_q, @@ -34,7 +34,7 @@ std::vector fwd_cuda( torch::Tensor const& output_weights, const uint8_t* pad_mask, float dropout_prob - ) + ) { const int embed_dim = inputs_q.size(2); const int sequences = inputs_q.size(1); @@ -55,8 +55,8 @@ std::vector fwd_cuda( const float alpha = 1.0; const float beta = 0.0; const float scale = 1.0 / sqrt(static_cast(head_dim)); - - // There is no reason to use more than one stream as every kernel is + + // There is no reason to use more than one stream as every kernel is // sequentially dependent cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); @@ -116,64 +116,64 @@ std::vector fwd_cuda( // Input Linear Q Fwd TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle, - CUBLAS_OP_T, + CUBLAS_OP_T, CUBLAS_OP_N, - output_lin_q_dim, - batches_q, + output_lin_q_dim, + batches_q, embed_dim, static_cast(&alpha), static_cast(input_weights_q.data_ptr()), - HIP_R_16F /*a_type*/, + HIP_R_16F /*a_type*/, embed_dim, //static_cast(inputs_q.data_ptr()), static_cast(lyr_nrm_results.data_ptr()), - HIP_R_16F /*b_type*/, - embed_dim, + HIP_R_16F /*b_type*/, + embed_dim, static_cast(&beta), q_lin_results_ptr, - HIP_R_16F /*c_type*/, + HIP_R_16F /*c_type*/, output_lin_q_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ )); - + // Input Linear KV Fwd TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_T, + CUBLAS_OP_T, CUBLAS_OP_N, - output_lin_kv_dim, - batches_kv, + output_lin_kv_dim, + batches_kv, embed_dim, static_cast(&alpha), static_cast(input_weights_kv.data_ptr()), - HIP_R_16F /*a_type*/, + HIP_R_16F /*a_type*/, embed_dim, static_cast(inputs_kv.data_ptr()), - HIP_R_16F /*b_type*/, - embed_dim, + HIP_R_16F /*b_type*/, + embed_dim, static_cast(&beta), k_lin_results_ptr, - HIP_R_16F /*c_type*/, + HIP_R_16F /*c_type*/, output_lin_kv_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ )); // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) - gemm_switch_fp32accum( a_layout_t, - b_layout_n, + gemm_switch_fp32accum( a_layout_t, + b_layout_n, k_seq_len, q_seq_len, head_dim, - scale, - static_cast(k_lin_results_ptr), - lead_dim_kv, - batch_stride_kv, + scale, + static_cast(k_lin_results_ptr), + lead_dim_kv, + batch_stride_kv, static_cast(q_lin_results_ptr), - lead_dim_q, - batch_stride_q, - beta, - static_cast(softmax_results_ptr), - k_seq_len, + lead_dim_q, + batch_stride_q, + beta, + static_cast(softmax_results_ptr), + k_seq_len, k_seq_len*q_seq_len, attn_batches ); @@ -210,49 +210,49 @@ std::vector fwd_cuda( } // Matmul2 - gemm_switch_fp32accum( a_layout_n, - b_layout_n, - head_dim, - q_seq_len, - k_seq_len, - alpha, - static_cast(v_lin_results_ptr), - lead_dim_kv, - batch_stride_kv, - (is_training) ? static_cast(dropout_results.data_ptr()) : static_cast(softmax_results.data_ptr()), - //static_cast(dropout_results.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - static_cast(matmul2_results.data_ptr()), - head_dim*attn_batches, + gemm_switch_fp32accum( a_layout_n, + b_layout_n, + head_dim, + q_seq_len, + k_seq_len, + alpha, + static_cast(v_lin_results_ptr), + lead_dim_kv, + batch_stride_kv, + (is_training) ? static_cast(dropout_results.data_ptr()) : static_cast(softmax_results.data_ptr()), + //static_cast(dropout_results.data_ptr()), + k_seq_len, + k_seq_len*q_seq_len, + beta, + static_cast(matmul2_results.data_ptr()), + head_dim*attn_batches, head_dim, attn_batches ); // Output Linear TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_T, + CUBLAS_OP_T, CUBLAS_OP_N, - embed_dim, - batches_q, + embed_dim, + batches_q, embed_dim, static_cast(&alpha), static_cast(output_weights.data_ptr()), - HIP_R_16F /*a_type*/, + HIP_R_16F /*a_type*/, embed_dim, static_cast(matmul2_results.data_ptr()), - HIP_R_16F /*b_type*/, - embed_dim, + HIP_R_16F /*b_type*/, + embed_dim, static_cast(&beta), static_cast(output_lin_results.data_ptr()), - HIP_R_16F /*c_type*/, + HIP_R_16F /*c_type*/, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ )); - // End-of-block Dropout-Add + // End-of-block Dropout-Add if (is_training) { apex_dropout_add_cuda( static_cast(output_lin_results.data_ptr()), @@ -353,103 +353,103 @@ std::vector bwd_cuda( char a_layout_n{'n'}; char a_layout_t{'t'}; char b_layout_n{'n'}; - char b_layout_t{'t'}; + char b_layout_t{'t'}; //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); - // Dropout Add Backward + // Dropout Add Backward apex_masked_scale_cuda( static_cast(output_grads.data_ptr()), static_cast(dropout_add_grads.data_ptr()), static_cast(dropout_add_mask.data_ptr()), total_tokens_q, (1.0 / (1.0 - dropout_prob))); - + // Output Linear Dgrad TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_N, CUBLAS_OP_N, - embed_dim, - batches_q, + CUBLAS_OP_N, + embed_dim, + batches_q, embed_dim, static_cast(&alpha), static_cast(output_weights.data_ptr()), - HIP_R_16F /*a_type*/, + HIP_R_16F /*a_type*/, embed_dim, static_cast(dropout_add_grads.data_ptr()), - HIP_R_16F /*b_type*/, - embed_dim, + HIP_R_16F /*b_type*/, + embed_dim, static_cast(&beta), static_cast(output_lin_grads.data_ptr()), - HIP_R_16F /*c_type*/, + HIP_R_16F /*c_type*/, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ )); - + // Output Linear Wgrad TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_N, + CUBLAS_OP_N, CUBLAS_OP_T, - embed_dim, embed_dim, - batches_q, + embed_dim, + batches_q, static_cast(&alpha), static_cast(matmul2_results.data_ptr()), - HIP_R_16F /*a_type*/, + HIP_R_16F /*a_type*/, embed_dim, static_cast(dropout_add_grads.data_ptr()), - HIP_R_16F /*b_type*/, - embed_dim, + HIP_R_16F /*b_type*/, + embed_dim, static_cast(&beta), static_cast(output_weight_grads.data_ptr()), - HIP_R_16F /*c_type*/, + HIP_R_16F /*c_type*/, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ )); - + // MatMul2 Dgrad1 - gemm_switch_fp32accum( a_layout_t, - b_layout_n, + gemm_switch_fp32accum( a_layout_t, + b_layout_n, k_seq_len, q_seq_len, head_dim, - alpha, + alpha, static_cast(v_lin_results_ptr), - lead_dim_kv, + lead_dim_kv, batch_stride_kv, static_cast(output_lin_grads.data_ptr()), - head_dim*attn_batches, - head_dim, - beta, + head_dim*attn_batches, + head_dim, + beta, static_cast(matmul2_grads.data_ptr()), - k_seq_len, + k_seq_len, k_seq_len*q_seq_len, attn_batches ); - + // Matmul2 Dgrad2 - gemm_switch_fp32accum( a_layout_n, - b_layout_t, - head_dim, - k_seq_len, - q_seq_len, - alpha, + gemm_switch_fp32accum( a_layout_n, + b_layout_t, + head_dim, + k_seq_len, + q_seq_len, + alpha, static_cast(output_lin_grads.data_ptr()), - head_dim*attn_batches, - head_dim, + head_dim*attn_batches, + head_dim, static_cast(dropout_results.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - v_lin_grads_ptr, - lead_dim_kv, - batch_stride_kv, + k_seq_len, + k_seq_len*q_seq_len, + beta, + v_lin_grads_ptr, + lead_dim_kv, + batch_stride_kv, attn_batches ); - // Apply Dropout Mask and Scale by Dropout Probability + // Apply Dropout Mask and Scale by Dropout Probability apex_masked_scale_cuda( static_cast(matmul2_grads.data_ptr()), static_cast(matmul2_grads.data_ptr()), @@ -467,75 +467,75 @@ std::vector bwd_cuda( assert(softmax_success); // Matmul1 Dgrad1 - gemm_switch_fp32accum( a_layout_n, - b_layout_n, - head_dim, - q_seq_len, - k_seq_len, - scale, - k_lin_results_ptr, - lead_dim_kv, - batch_stride_kv, + gemm_switch_fp32accum( a_layout_n, + b_layout_n, + head_dim, + q_seq_len, + k_seq_len, + scale, + k_lin_results_ptr, + lead_dim_kv, + batch_stride_kv, static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - q_lin_grads_ptr, - lead_dim_q, - batch_stride_q, + k_seq_len, + k_seq_len*q_seq_len, + beta, + q_lin_grads_ptr, + lead_dim_q, + batch_stride_q, attn_batches ); - + // Matmul1 Dgrad2 - gemm_switch_fp32accum( a_layout_n, - b_layout_t, - head_dim, - k_seq_len, - q_seq_len, - scale, - q_lin_results_ptr, - lead_dim_q, - batch_stride_q, + gemm_switch_fp32accum( a_layout_n, + b_layout_t, + head_dim, + k_seq_len, + q_seq_len, + scale, + q_lin_results_ptr, + lead_dim_q, + batch_stride_q, static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - k_lin_grads_ptr, - lead_dim_kv, + k_seq_len, + k_seq_len*q_seq_len, + beta, + k_lin_grads_ptr, + lead_dim_kv, batch_stride_kv, attn_batches ); - // Input Linear Q Dgrad + // Input Linear Q Dgrad TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_N, + CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, - batches_q, + batches_q, output_lin_q_dim, static_cast(&alpha), static_cast(input_weights_q.data_ptr()), - HIP_R_16F /*a_type*/, + HIP_R_16F /*a_type*/, embed_dim, static_cast(q_lin_grads_ptr), - HIP_R_16F /*b_type*/, - output_lin_q_dim, + HIP_R_16F /*b_type*/, + output_lin_q_dim, static_cast(&beta), //static_cast(input_q_grads.data_ptr()), static_cast(input_lin_q_grads.data_ptr()), - HIP_R_16F /*c_type*/, + HIP_R_16F /*c_type*/, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ )); - - // Input Linear Q Wgrad + + // Input Linear Q Wgrad TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_N, + CUBLAS_OP_N, CUBLAS_OP_T, - embed_dim, + embed_dim, output_lin_q_dim, - batches_q, + batches_q, static_cast(&alpha), static_cast(inputs_q.data_ptr()), HIP_R_16F /*a_type*/, @@ -545,41 +545,41 @@ std::vector bwd_cuda( output_lin_q_dim, static_cast(&beta), static_cast(input_weight_q_grads.data_ptr()), - HIP_R_16F /*c_type*/, + HIP_R_16F /*c_type*/, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ )); - - // Input Linear KV Dgrad + + // Input Linear KV Dgrad TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_N, + CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, - batches_kv, + batches_kv, output_lin_kv_dim, static_cast(&alpha), static_cast(input_weights_kv.data_ptr()), - HIP_R_16F /*a_type*/, + HIP_R_16F /*a_type*/, embed_dim, static_cast(k_lin_grads_ptr), - HIP_R_16F /*b_type*/, - output_lin_kv_dim, + HIP_R_16F /*b_type*/, + output_lin_kv_dim, static_cast(&beta), static_cast(input_kv_grads.data_ptr()), - HIP_R_16F /*c_type*/, + HIP_R_16F /*c_type*/, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ )); - - // Input Linear KV Wgrad + + // Input Linear KV Wgrad TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_N, + CUBLAS_OP_N, CUBLAS_OP_T, - embed_dim, + embed_dim, output_lin_kv_dim, - batches_kv, + batches_kv, static_cast(&alpha), static_cast(inputs_kv.data_ptr()), HIP_R_16F /*a_type*/, @@ -589,16 +589,16 @@ std::vector bwd_cuda( output_lin_kv_dim, static_cast(&beta), static_cast(input_weight_kv_grads.data_ptr()), - HIP_R_16F /*c_type*/, + HIP_R_16F /*c_type*/, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ )); - + // Fused Layer Norm Bwd with Residual Add HostLayerNormGradient( static_cast(input_lin_q_grads.data_ptr()), - static_cast(output_grads.data_ptr()), + static_cast(output_grads.data_ptr()), static_cast(lyr_nrm_mean.data_ptr()), static_cast(lyr_nrm_invvar.data_ptr()), inputs_q, @@ -611,7 +611,7 @@ std::vector bwd_cuda( static_cast(lyr_nrm_gamma_grads.data_ptr()), static_cast(lyr_nrm_beta_grads.data_ptr()) ); - + //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); return {input_q_grads, input_kv_grads, lyr_nrm_gamma_grads, @@ -620,5 +620,5 @@ std::vector bwd_cuda( } } // end namespace rocblas_gemmex -} // end namespace encdec_norm_add +} // end namespace encdec_norm_add } // end namespace multihead_attn diff --git a/apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp b/apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp index 809620e0d..0c4254a3c 100644 --- a/apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp +++ b/apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp @@ -493,7 +493,7 @@ std::vector bwd_cuda( // torch::Tensor const& input_biases, // torch::Tensor const& output_biases, torch::Tensor const &dropout_mask, float dropout_prob); - + std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, torch::Tensor const &inputs, torch::Tensor const &input_weights, @@ -588,7 +588,7 @@ std::vector bwd_cuda( // torch::Tensor const& input_biases, // torch::Tensor const& output_biases, torch::Tensor const &dropout_mask, float dropout_prob); - + std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, torch::Tensor const &inputs, torch::Tensor const &input_weights, diff --git a/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu b/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu index f1128da54..740b98d39 100644 --- a/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu @@ -21,7 +21,7 @@ namespace self_bias_additive_mask { namespace rocblas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, - int heads, torch::Tensor const& inputs, + int heads, torch::Tensor const& inputs, torch::Tensor const& input_weights, torch::Tensor const& output_weights, torch::Tensor const& input_biases, @@ -43,7 +43,7 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, const float beta_one = 1.0; const float scale = 1.0 / sqrt(static_cast(head_dim)); - // There is no reason to use more than one stream as every kernel is + // There is no reason to use more than one stream as every kernel is // sequentially dependent cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); @@ -84,46 +84,46 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, // Input Linear Fwd input_lin_results.copy_(input_biases); TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_T, + CUBLAS_OP_T, CUBLAS_OP_N, - output_lin_dim, - batches, + output_lin_dim, + batches, embed_dim, static_cast(&alpha), static_cast(input_weights.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, static_cast(inputs.data_ptr()), - HIP_R_16F, - embed_dim, + HIP_R_16F, + embed_dim, static_cast(&beta_one), q_lin_results_ptr, - HIP_R_16F, + HIP_R_16F, output_lin_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ )); - + // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) - gemm_switch_fp32accum( a_layout_t, - b_layout_n, + gemm_switch_fp32accum( a_layout_t, + b_layout_n, k_seq_len, q_seq_len, head_dim, - scale, - static_cast(k_lin_results_ptr), - lead_dim, - batch_stride, + scale, + static_cast(k_lin_results_ptr), + lead_dim, + batch_stride, static_cast(q_lin_results_ptr), - lead_dim, - batch_stride, - beta_zero, - static_cast(bmm1_results_ptr), - k_seq_len, - k_seq_len*q_seq_len, + lead_dim, + batch_stride, + beta_zero, + static_cast(bmm1_results_ptr), + k_seq_len, + k_seq_len*q_seq_len, attn_batches ); - + // Padded Softmax bool softmax_success = false; if (is_training) { @@ -148,22 +148,22 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, } // Matmul2 - gemm_switch_fp32accum( a_layout_n, - b_layout_n, - head_dim, - q_seq_len, - k_seq_len, - alpha, - static_cast(v_lin_results_ptr), - lead_dim, - batch_stride, - static_cast(dropout_results.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta_zero, - static_cast(matmul2_results.data_ptr()), - head_dim*attn_batches, - head_dim, + gemm_switch_fp32accum( a_layout_n, + b_layout_n, + head_dim, + q_seq_len, + k_seq_len, + alpha, + static_cast(v_lin_results_ptr), + lead_dim, + batch_stride, + static_cast(dropout_results.data_ptr()), + k_seq_len, + k_seq_len*q_seq_len, + beta_zero, + static_cast(matmul2_results.data_ptr()), + head_dim*attn_batches, + head_dim, attn_batches ); @@ -171,21 +171,21 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, // Output Linear TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_T, + CUBLAS_OP_T, CUBLAS_OP_N, - embed_dim, - batches, + embed_dim, + batches, embed_dim, static_cast(&alpha), static_cast(output_weights.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, static_cast(matmul2_results.data_ptr()), - HIP_R_16F, - embed_dim, + HIP_R_16F, + embed_dim, static_cast(&beta_one), static_cast(outputs.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ @@ -252,21 +252,21 @@ std::vector bwd_cuda( // Output Linear Dgrad TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_N, CUBLAS_OP_N, - embed_dim, - batches, + CUBLAS_OP_N, + embed_dim, + batches, embed_dim, static_cast(&alpha), static_cast(output_weights.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, static_cast(output_grads.data_ptr()), - HIP_R_16F, - embed_dim, + HIP_R_16F, + embed_dim, static_cast(&beta), static_cast(output_lin_grads.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ @@ -274,21 +274,21 @@ std::vector bwd_cuda( // Output Linear Wgrad TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_N, + CUBLAS_OP_N, CUBLAS_OP_T, - embed_dim, embed_dim, - batches, + embed_dim, + batches, static_cast(&alpha), static_cast(matmul2_results.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, static_cast(output_grads.data_ptr()), - HIP_R_16F, - embed_dim, + HIP_R_16F, + embed_dim, static_cast(&beta), static_cast(output_weight_grads.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ @@ -296,50 +296,50 @@ std::vector bwd_cuda( auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false); // MatMul2 Dgrad1 - gemm_switch_fp32accum( a_layout_t, - b_layout_n, + gemm_switch_fp32accum( a_layout_t, + b_layout_n, k_seq_len, q_seq_len, head_dim, - alpha, + alpha, static_cast(v_lin_results_ptr), - lead_dim, + lead_dim, batch_stride, static_cast(output_lin_grads.data_ptr()), - head_dim*attn_batches, - head_dim, - beta, + head_dim*attn_batches, + head_dim, + beta, static_cast(matmul2_grads.data_ptr()), - k_seq_len, + k_seq_len, k_seq_len*q_seq_len, attn_batches ); - + // Matmul2 Dgrad2 - gemm_switch_fp32accum( a_layout_n, - b_layout_t, - head_dim, - k_seq_len, - q_seq_len, - alpha, + gemm_switch_fp32accum( a_layout_n, + b_layout_t, + head_dim, + k_seq_len, + q_seq_len, + alpha, static_cast(output_lin_grads.data_ptr()), - head_dim*attn_batches, - head_dim, + head_dim*attn_batches, + head_dim, static_cast(dropout_results.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - v_lin_grads_ptr, - lead_dim, - batch_stride, + k_seq_len, + k_seq_len*q_seq_len, + beta, + v_lin_grads_ptr, + lead_dim, + batch_stride, attn_batches ); - // Apply Dropout Mask and Scale by Dropout Probability + // Apply Dropout Mask and Scale by Dropout Probability // Softmax Grad dispatch_masked_scale_softmax_backward_recompute( - static_cast(matmul2_grads.data_ptr()), - static_cast(matmul2_grads.data_ptr()), + static_cast(matmul2_grads.data_ptr()), + static_cast(matmul2_grads.data_ptr()), reinterpret_cast(bmm1_results.data_ptr()), reinterpret_cast(pad_mask.data_ptr()), static_cast(dropout_mask.data_ptr()), @@ -349,76 +349,76 @@ std::vector bwd_cuda( attn_batches*q_seq_len/sequences, attn_batches*q_seq_len, stream); - + // Matmul1 Dgrad1 - gemm_switch_fp32accum( a_layout_n, - b_layout_n, - head_dim, - q_seq_len, - k_seq_len, - scale, - k_lin_results_ptr, - lead_dim, - batch_stride, + gemm_switch_fp32accum( a_layout_n, + b_layout_n, + head_dim, + q_seq_len, + k_seq_len, + scale, + k_lin_results_ptr, + lead_dim, + batch_stride, static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - q_lin_grads_ptr, - lead_dim, - batch_stride, + k_seq_len, + k_seq_len*q_seq_len, + beta, + q_lin_grads_ptr, + lead_dim, + batch_stride, attn_batches ); - + // Matmul1 Dgrad2 - gemm_switch_fp32accum( a_layout_n, - b_layout_t, - head_dim, - k_seq_len, - q_seq_len, - scale, - q_lin_results_ptr, - lead_dim, - batch_stride, + gemm_switch_fp32accum( a_layout_n, + b_layout_t, + head_dim, + k_seq_len, + q_seq_len, + scale, + q_lin_results_ptr, + lead_dim, + batch_stride, static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - k_lin_grads_ptr, - lead_dim, - batch_stride, + k_seq_len, + k_seq_len*q_seq_len, + beta, + k_lin_grads_ptr, + lead_dim, + batch_stride, attn_batches ); - - // Input Linear Dgrad + + // Input Linear Dgrad TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_N, + CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, - batches, + batches, output_lin_dim, static_cast(&alpha), static_cast(input_weights.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, static_cast(input_lin_output_grads.data_ptr()), - HIP_R_16F, - output_lin_dim, + HIP_R_16F, + output_lin_dim, static_cast(&beta), static_cast(input_grads.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ )); - - // Input Linear Wgrad + + // Input Linear Wgrad TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_N, + CUBLAS_OP_N, CUBLAS_OP_T, - embed_dim, + embed_dim, output_lin_dim, - batches, + batches, static_cast(&alpha), static_cast(inputs.data_ptr()), HIP_R_16F, @@ -428,7 +428,7 @@ std::vector bwd_cuda( output_lin_dim, static_cast(&beta), static_cast(input_weight_grads.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ diff --git a/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu b/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu index 3b23ebb75..81a221fc5 100644 --- a/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu @@ -84,46 +84,46 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads, // Input Linear Fwd input_lin_results.copy_(input_biases); TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_T, + CUBLAS_OP_T, CUBLAS_OP_N, - output_lin_dim, - batches, + output_lin_dim, + batches, embed_dim, static_cast(&alpha), static_cast(input_weights.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, static_cast(inputs.data_ptr()), - HIP_R_16F, - embed_dim, + HIP_R_16F, + embed_dim, static_cast(&beta_one), q_lin_results_ptr, - HIP_R_16F, + HIP_R_16F, output_lin_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ )); // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) - gemm_switch_fp32accum( a_layout_t, - b_layout_n, + gemm_switch_fp32accum( a_layout_t, + b_layout_n, k_seq_len, q_seq_len, head_dim, - scale, - static_cast(k_lin_results_ptr), - lead_dim, - batch_stride, + scale, + static_cast(k_lin_results_ptr), + lead_dim, + batch_stride, static_cast(q_lin_results_ptr), - lead_dim, - batch_stride, - beta_zero, - static_cast(softmax_results_ptr), - k_seq_len, - k_seq_len*q_seq_len, + lead_dim, + batch_stride, + beta_zero, + static_cast(softmax_results_ptr), + k_seq_len, + k_seq_len*q_seq_len, attn_batches ); - + // Padded Softmax bool softmax_success = false; if (pad_mask == nullptr) { @@ -156,22 +156,22 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads, } // Matmul2 - gemm_switch_fp32accum( a_layout_n, - b_layout_n, - head_dim, - q_seq_len, - k_seq_len, - alpha, - static_cast(v_lin_results_ptr), - lead_dim, - batch_stride, - (is_training) ? static_cast(dropout_results.data_ptr()) : static_cast(softmax_results.data_ptr()) , - k_seq_len, - k_seq_len*q_seq_len, - beta_zero, - static_cast(matmul2_results.data_ptr()), - head_dim*attn_batches, - head_dim, + gemm_switch_fp32accum( a_layout_n, + b_layout_n, + head_dim, + q_seq_len, + k_seq_len, + alpha, + static_cast(v_lin_results_ptr), + lead_dim, + batch_stride, + (is_training) ? static_cast(dropout_results.data_ptr()) : static_cast(softmax_results.data_ptr()) , + k_seq_len, + k_seq_len*q_seq_len, + beta_zero, + static_cast(matmul2_results.data_ptr()), + head_dim*attn_batches, + head_dim, attn_batches ); @@ -179,21 +179,21 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads, // Output Linear TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_T, + CUBLAS_OP_T, CUBLAS_OP_N, - embed_dim, - batches, + embed_dim, + batches, embed_dim, static_cast(&alpha), static_cast(output_weights.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, static_cast(matmul2_results.data_ptr()), - HIP_R_16F, - embed_dim, + HIP_R_16F, + embed_dim, static_cast(&beta_one), static_cast(outputs.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ @@ -273,21 +273,21 @@ std::vector bwd_cuda( // Output Linear Dgrad TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_N, CUBLAS_OP_N, - embed_dim, - batches, + CUBLAS_OP_N, + embed_dim, + batches, embed_dim, static_cast(&alpha), static_cast(output_weights.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, static_cast(output_grads.data_ptr()), - HIP_R_16F, - embed_dim, + HIP_R_16F, + embed_dim, static_cast(&beta), static_cast(output_lin_grads.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ @@ -295,21 +295,21 @@ std::vector bwd_cuda( // Output Linear Wgrad TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_N, + CUBLAS_OP_N, CUBLAS_OP_T, - embed_dim, embed_dim, - batches, + embed_dim, + batches, static_cast(&alpha), static_cast(matmul2_results.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, static_cast(output_grads.data_ptr()), - HIP_R_16F, - embed_dim, + HIP_R_16F, + embed_dim, static_cast(&beta), static_cast(output_weight_grads.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ @@ -317,46 +317,46 @@ std::vector bwd_cuda( auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false); // MatMul2 Dgrad1 - gemm_switch_fp32accum( a_layout_t, - b_layout_n, + gemm_switch_fp32accum( a_layout_t, + b_layout_n, k_seq_len, q_seq_len, head_dim, - alpha, + alpha, static_cast(v_lin_results_ptr), - lead_dim, + lead_dim, batch_stride, static_cast(output_lin_grads.data_ptr()), - head_dim*attn_batches, - head_dim, - beta, + head_dim*attn_batches, + head_dim, + beta, static_cast(matmul2_grads.data_ptr()), - k_seq_len, + k_seq_len, k_seq_len*q_seq_len, attn_batches ); // Matmul2 Dgrad2 - gemm_switch_fp32accum( a_layout_n, - b_layout_t, - head_dim, - k_seq_len, - q_seq_len, - alpha, + gemm_switch_fp32accum( a_layout_n, + b_layout_t, + head_dim, + k_seq_len, + q_seq_len, + alpha, static_cast(output_lin_grads.data_ptr()), - head_dim*attn_batches, - head_dim, + head_dim*attn_batches, + head_dim, static_cast(dropout_results.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - v_lin_grads_ptr, - lead_dim, - batch_stride, + k_seq_len, + k_seq_len*q_seq_len, + beta, + v_lin_grads_ptr, + lead_dim, + batch_stride, attn_batches ); - // Apply Dropout Mask and Scale by Dropout Probability + // Apply Dropout Mask and Scale by Dropout Probability // Softmax Grad dispatch_masked_scale_softmax_backward_stream( static_cast(matmul2_grads.data_ptr()), @@ -367,73 +367,73 @@ std::vector bwd_cuda( attn_batches * q_seq_len, stream); // Matmul1 Dgrad1 - gemm_switch_fp32accum( a_layout_n, - b_layout_n, - head_dim, - q_seq_len, - k_seq_len, - scale, - k_lin_results_ptr, - lead_dim, - batch_stride, + gemm_switch_fp32accum( a_layout_n, + b_layout_n, + head_dim, + q_seq_len, + k_seq_len, + scale, + k_lin_results_ptr, + lead_dim, + batch_stride, static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, + k_seq_len, + k_seq_len*q_seq_len, beta, - q_lin_grads_ptr, - lead_dim, - batch_stride, + q_lin_grads_ptr, + lead_dim, + batch_stride, attn_batches ); // Matmul1 Dgrad2 - gemm_switch_fp32accum( a_layout_n, - b_layout_t, - head_dim, - k_seq_len, - q_seq_len, - scale, - q_lin_results_ptr, - lead_dim, - batch_stride, + gemm_switch_fp32accum( a_layout_n, + b_layout_t, + head_dim, + k_seq_len, + q_seq_len, + scale, + q_lin_results_ptr, + lead_dim, + batch_stride, static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - k_lin_grads_ptr, - lead_dim, + k_seq_len, + k_seq_len*q_seq_len, + beta, + k_lin_grads_ptr, + lead_dim, batch_stride, attn_batches ); - // Input Linear Dgrad + // Input Linear Dgrad TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_N, + CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, - batches, + batches, output_lin_dim, static_cast(&alpha), static_cast(input_weights.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, static_cast(input_lin_output_grads.data_ptr()), - HIP_R_16F, - output_lin_dim, + HIP_R_16F, + output_lin_dim, static_cast(&beta), static_cast(input_grads.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ )); - // Input Linear Wgrad + // Input Linear Wgrad TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_N, + CUBLAS_OP_N, CUBLAS_OP_T, - embed_dim, + embed_dim, output_lin_dim, - batches, + batches, static_cast(&alpha), static_cast(inputs.data_ptr()), HIP_R_16F, @@ -443,7 +443,7 @@ std::vector bwd_cuda( output_lin_dim, static_cast(&beta), static_cast(input_weight_grads.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ diff --git a/apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu b/apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu index 35795cd85..f8cef93bb 100644 --- a/apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu @@ -80,43 +80,43 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, // Input Linear Fwd TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_T, + CUBLAS_OP_T, CUBLAS_OP_N, - output_lin_dim, - batches, + output_lin_dim, + batches, embed_dim, static_cast(&alpha), static_cast(input_weights.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, static_cast(inputs.data_ptr()), - HIP_R_16F, - embed_dim, + HIP_R_16F, + embed_dim, static_cast(&beta), q_lin_results_ptr, - HIP_R_16F, + HIP_R_16F, output_lin_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ )); // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) - gemm_switch_fp32accum( a_layout_t, - b_layout_n, + gemm_switch_fp32accum( a_layout_t, + b_layout_n, k_seq_len, q_seq_len, head_dim, - scale, - static_cast(k_lin_results_ptr), - lead_dim, - batch_stride, + scale, + static_cast(k_lin_results_ptr), + lead_dim, + batch_stride, static_cast(q_lin_results_ptr), - lead_dim, - batch_stride, - beta, - static_cast(softmax_results_ptr), - k_seq_len, - k_seq_len*q_seq_len, + lead_dim, + batch_stride, + beta, + static_cast(softmax_results_ptr), + k_seq_len, + k_seq_len*q_seq_len, attn_batches ); @@ -152,42 +152,42 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, } // Matmul2 - gemm_switch_fp32accum( a_layout_n, - b_layout_n, - head_dim, - q_seq_len, - k_seq_len, - alpha, - static_cast(v_lin_results_ptr), - lead_dim, - batch_stride, - (is_training) ? static_cast(dropout_results.data_ptr()) : static_cast(softmax_results.data_ptr()) , - k_seq_len, - k_seq_len*q_seq_len, - beta, - static_cast(matmul2_results.data_ptr()), - head_dim*attn_batches, - head_dim, + gemm_switch_fp32accum( a_layout_n, + b_layout_n, + head_dim, + q_seq_len, + k_seq_len, + alpha, + static_cast(v_lin_results_ptr), + lead_dim, + batch_stride, + (is_training) ? static_cast(dropout_results.data_ptr()) : static_cast(softmax_results.data_ptr()) , + k_seq_len, + k_seq_len*q_seq_len, + beta, + static_cast(matmul2_results.data_ptr()), + head_dim*attn_batches, + head_dim, attn_batches ); // Output Linear TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_T, + CUBLAS_OP_T, CUBLAS_OP_N, - embed_dim, - batches, + embed_dim, + batches, embed_dim, static_cast(&alpha), static_cast(output_weights.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, static_cast(matmul2_results.data_ptr()), - HIP_R_16F, - embed_dim, + HIP_R_16F, + embed_dim, static_cast(&beta), static_cast(outputs.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ @@ -249,93 +249,93 @@ std::vector bwd_cuda( char a_layout_n{'n'}; char a_layout_t{'t'}; char b_layout_n{'n'}; - char b_layout_t{'t'}; + char b_layout_t{'t'}; // Output Linear Dgrad TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_N, CUBLAS_OP_N, - embed_dim, - batches, + CUBLAS_OP_N, + embed_dim, + batches, embed_dim, static_cast(&alpha), static_cast(output_weights.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, static_cast(output_grads.data_ptr()), - HIP_R_16F, - embed_dim, + HIP_R_16F, + embed_dim, static_cast(&beta), static_cast(output_lin_grads.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ )); - + // Output Linear Wgrad TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_N, + CUBLAS_OP_N, CUBLAS_OP_T, - embed_dim, embed_dim, - batches, + embed_dim, + batches, static_cast(&alpha), static_cast(matmul2_results.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, static_cast(output_grads.data_ptr()), - HIP_R_16F, - embed_dim, + HIP_R_16F, + embed_dim, static_cast(&beta), static_cast(output_weight_grads.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ )); - + // MatMul2 Dgrad1 - gemm_switch_fp32accum( a_layout_t, - b_layout_n, + gemm_switch_fp32accum( a_layout_t, + b_layout_n, k_seq_len, q_seq_len, head_dim, - alpha, + alpha, static_cast(v_lin_results_ptr), - lead_dim, + lead_dim, batch_stride, static_cast(output_lin_grads.data_ptr()), - head_dim*attn_batches, - head_dim, - beta, + head_dim*attn_batches, + head_dim, + beta, static_cast(matmul2_grads.data_ptr()), - k_seq_len, + k_seq_len, k_seq_len*q_seq_len, attn_batches ); - + // Matmul2 Dgrad2 - gemm_switch_fp32accum( a_layout_n, - b_layout_t, - head_dim, - k_seq_len, - q_seq_len, - alpha, + gemm_switch_fp32accum( a_layout_n, + b_layout_t, + head_dim, + k_seq_len, + q_seq_len, + alpha, static_cast(output_lin_grads.data_ptr()), - head_dim*attn_batches, - head_dim, + head_dim*attn_batches, + head_dim, static_cast(dropout_results.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - v_lin_grads_ptr, - lead_dim, - batch_stride, + k_seq_len, + k_seq_len*q_seq_len, + beta, + v_lin_grads_ptr, + lead_dim, + batch_stride, attn_batches ); - // Apply Dropout Mask and Scale by Dropout Probability + // Apply Dropout Mask and Scale by Dropout Probability apex_masked_scale_cuda( static_cast(matmul2_grads.data_ptr()), static_cast(matmul2_grads.data_ptr()), @@ -353,74 +353,74 @@ std::vector bwd_cuda( assert(softmax_success); // Matmul1 Dgrad1 - gemm_switch_fp32accum( a_layout_n, - b_layout_n, - head_dim, - q_seq_len, - k_seq_len, - scale, - k_lin_results_ptr, - lead_dim, - batch_stride, + gemm_switch_fp32accum( a_layout_n, + b_layout_n, + head_dim, + q_seq_len, + k_seq_len, + scale, + k_lin_results_ptr, + lead_dim, + batch_stride, static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - q_lin_grads_ptr, - lead_dim, - batch_stride, + k_seq_len, + k_seq_len*q_seq_len, + beta, + q_lin_grads_ptr, + lead_dim, + batch_stride, attn_batches ); - + // Matmul1 Dgrad2 - gemm_switch_fp32accum( a_layout_n, - b_layout_t, - head_dim, - k_seq_len, - q_seq_len, - scale, - q_lin_results_ptr, - lead_dim, - batch_stride, + gemm_switch_fp32accum( a_layout_n, + b_layout_t, + head_dim, + k_seq_len, + q_seq_len, + scale, + q_lin_results_ptr, + lead_dim, + batch_stride, static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - k_lin_grads_ptr, - lead_dim, + k_seq_len, + k_seq_len*q_seq_len, + beta, + k_lin_grads_ptr, + lead_dim, batch_stride, attn_batches ); - // Input Linear Dgrad + // Input Linear Dgrad TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_N, + CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, - batches, + batches, output_lin_dim, static_cast(&alpha), static_cast(input_weights.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, static_cast(q_lin_grads_ptr), - HIP_R_16F, - output_lin_dim, + HIP_R_16F, + output_lin_dim, static_cast(&beta), static_cast(input_grads.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ )); - - // Input Linear Wgrad + + // Input Linear Wgrad TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_N, + CUBLAS_OP_N, CUBLAS_OP_T, - embed_dim, + embed_dim, output_lin_dim, - batches, + batches, static_cast(&alpha), static_cast(inputs.data_ptr()), HIP_R_16F, @@ -430,15 +430,15 @@ std::vector bwd_cuda( output_lin_dim, static_cast(&beta), static_cast(input_weight_grads.data_ptr()), - HIP_R_16F, + HIP_R_16F, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ )); - - return { - input_grads, - input_weight_grads, + + return { + input_grads, + input_weight_grads, output_weight_grads }; } diff --git a/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu b/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu index 17150aea9..8ff929274 100644 --- a/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu @@ -102,44 +102,44 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, // Input Linear Fwd TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_T, + CUBLAS_OP_T, CUBLAS_OP_N, - output_lin_dim, - batches, + output_lin_dim, + batches, embed_dim, static_cast(&alpha), static_cast(input_weights.data_ptr()), - HIP_R_16F /*a_type*/, + HIP_R_16F /*a_type*/, embed_dim, //static_cast(inputs.data_ptr()), static_cast(lyr_nrm_results.data_ptr()), - HIP_R_16F /*b_type*/, - embed_dim, + HIP_R_16F /*b_type*/, + embed_dim, static_cast(&beta), q_lin_results_ptr, - HIP_R_16F /*c_type*/, + HIP_R_16F /*c_type*/, output_lin_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ )); // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) - gemm_switch_fp32accum( a_layout_t, - b_layout_n, + gemm_switch_fp32accum( a_layout_t, + b_layout_n, k_seq_len, q_seq_len, head_dim, - scale, - static_cast(k_lin_results_ptr), - lead_dim, - batch_stride, + scale, + static_cast(k_lin_results_ptr), + lead_dim, + batch_stride, static_cast(q_lin_results_ptr), - lead_dim, - batch_stride, - beta, - static_cast(softmax_results_ptr), - k_seq_len, - k_seq_len*q_seq_len, + lead_dim, + batch_stride, + beta, + static_cast(softmax_results_ptr), + k_seq_len, + k_seq_len*q_seq_len, attn_batches ); @@ -175,50 +175,50 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, } // Matmul2 - gemm_switch_fp32accum( a_layout_n, - b_layout_n, - head_dim, - q_seq_len, - k_seq_len, - alpha, - static_cast(v_lin_results_ptr), - lead_dim, - batch_stride, - (is_training) ? static_cast(dropout_results.data_ptr()) : static_cast(softmax_results.data_ptr()) , - //static_cast(dropout_results.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - static_cast(matmul2_results.data_ptr()), - head_dim*attn_batches, + gemm_switch_fp32accum( a_layout_n, + b_layout_n, + head_dim, + q_seq_len, + k_seq_len, + alpha, + static_cast(v_lin_results_ptr), + lead_dim, + batch_stride, + (is_training) ? static_cast(dropout_results.data_ptr()) : static_cast(softmax_results.data_ptr()) , + //static_cast(dropout_results.data_ptr()), + k_seq_len, + k_seq_len*q_seq_len, + beta, + static_cast(matmul2_results.data_ptr()), + head_dim*attn_batches, head_dim, attn_batches ); // Output Linear TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_T, + CUBLAS_OP_T, CUBLAS_OP_N, - embed_dim, - batches, + embed_dim, + batches, embed_dim, static_cast(&alpha), static_cast(output_weights.data_ptr()), - HIP_R_16F /*a_type*/, + HIP_R_16F /*a_type*/, embed_dim, static_cast(matmul2_results.data_ptr()), - HIP_R_16F /*b_type*/, - embed_dim, + HIP_R_16F /*b_type*/, + embed_dim, static_cast(&beta), static_cast(output_lin_results.data_ptr()), - HIP_R_16F /*c_type*/, + HIP_R_16F /*c_type*/, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ )); - - // End-of-block Dropout-Add + + // End-of-block Dropout-Add if (is_training) { apex_dropout_add_cuda( static_cast(output_lin_results.data_ptr()), @@ -300,8 +300,8 @@ std::vector bwd_cuda( char a_layout_n{'n'}; char a_layout_t{'t'}; char b_layout_n{'n'}; - char b_layout_t{'t'}; - + char b_layout_t{'t'}; + // Dropout Add Backward apex_masked_scale_cuda( static_cast(output_grads.data_ptr()), @@ -311,89 +311,89 @@ std::vector bwd_cuda( // Output Linear Dgrad TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_N, CUBLAS_OP_N, - embed_dim, - batches, + CUBLAS_OP_N, + embed_dim, + batches, embed_dim, static_cast(&alpha), static_cast(output_weights.data_ptr()), - HIP_R_16F /*a_type*/, + HIP_R_16F /*a_type*/, embed_dim, static_cast(dropout_add_grads.data_ptr()), - HIP_R_16F /*b_type*/, - embed_dim, + HIP_R_16F /*b_type*/, + embed_dim, static_cast(&beta), static_cast(output_lin_grads.data_ptr()), - HIP_R_16F /*c_type*/, + HIP_R_16F /*c_type*/, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ )); - + // Output Linear Wgrad TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_N, + CUBLAS_OP_N, CUBLAS_OP_T, - embed_dim, embed_dim, - batches, + embed_dim, + batches, static_cast(&alpha), static_cast(matmul2_results.data_ptr()), - HIP_R_16F /*a_type*/, + HIP_R_16F /*a_type*/, embed_dim, static_cast(dropout_add_grads.data_ptr()), - HIP_R_16F /*b_type*/, - embed_dim, + HIP_R_16F /*b_type*/, + embed_dim, static_cast(&beta), static_cast(output_weight_grads.data_ptr()), - HIP_R_16F /*c_type*/, + HIP_R_16F /*c_type*/, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ )); // MatMul2 Dgrad1 - gemm_switch_fp32accum( a_layout_t, - b_layout_n, + gemm_switch_fp32accum( a_layout_t, + b_layout_n, k_seq_len, q_seq_len, head_dim, - alpha, + alpha, static_cast(v_lin_results_ptr), - lead_dim, + lead_dim, batch_stride, static_cast(output_lin_grads.data_ptr()), - head_dim*attn_batches, - head_dim, - beta, + head_dim*attn_batches, + head_dim, + beta, static_cast(matmul2_grads.data_ptr()), - k_seq_len, + k_seq_len, k_seq_len*q_seq_len, attn_batches ); - + // Matmul2 Dgrad2 - gemm_switch_fp32accum( a_layout_n, - b_layout_t, - head_dim, - k_seq_len, - q_seq_len, - alpha, + gemm_switch_fp32accum( a_layout_n, + b_layout_t, + head_dim, + k_seq_len, + q_seq_len, + alpha, static_cast(output_lin_grads.data_ptr()), - head_dim*attn_batches, - head_dim, + head_dim*attn_batches, + head_dim, static_cast(dropout_results.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - v_lin_grads_ptr, - lead_dim, - batch_stride, + k_seq_len, + k_seq_len*q_seq_len, + beta, + v_lin_grads_ptr, + lead_dim, + batch_stride, attn_batches ); - // Apply Dropout Mask and Scale by Dropout Probability + // Apply Dropout Mask and Scale by Dropout Probability apex_masked_scale_cuda( static_cast(matmul2_grads.data_ptr()), static_cast(matmul2_grads.data_ptr()), @@ -411,75 +411,75 @@ std::vector bwd_cuda( assert(softmax_success); // Matmul1 Dgrad1 - gemm_switch_fp32accum( a_layout_n, - b_layout_n, - head_dim, - q_seq_len, - k_seq_len, - scale, - k_lin_results_ptr, - lead_dim, - batch_stride, + gemm_switch_fp32accum( a_layout_n, + b_layout_n, + head_dim, + q_seq_len, + k_seq_len, + scale, + k_lin_results_ptr, + lead_dim, + batch_stride, static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - q_lin_grads_ptr, - lead_dim, + k_seq_len, + k_seq_len*q_seq_len, + beta, + q_lin_grads_ptr, + lead_dim, batch_stride, attn_batches ); - + // Matmul1 Dgrad2 - gemm_switch_fp32accum( a_layout_n, - b_layout_t, - head_dim, - k_seq_len, - q_seq_len, - scale, - q_lin_results_ptr, - lead_dim, - batch_stride, + gemm_switch_fp32accum( a_layout_n, + b_layout_t, + head_dim, + k_seq_len, + q_seq_len, + scale, + q_lin_results_ptr, + lead_dim, + batch_stride, static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - k_lin_grads_ptr, - lead_dim, + k_seq_len, + k_seq_len*q_seq_len, + beta, + k_lin_grads_ptr, + lead_dim, batch_stride, attn_batches ); - // Input Linear Dgrad + // Input Linear Dgrad TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_N, + CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, - batches, + batches, output_lin_dim, static_cast(&alpha), static_cast(input_weights.data_ptr()), - HIP_R_16F /*a_type*/, + HIP_R_16F /*a_type*/, embed_dim, static_cast(q_lin_grads_ptr), - HIP_R_16F /*b_type*/, - output_lin_dim, + HIP_R_16F /*b_type*/, + output_lin_dim, static_cast(&beta), //static_cast(input_grads.data_ptr()), static_cast(input_lin_grads.data_ptr()), - HIP_R_16F /*c_type*/, + HIP_R_16F /*c_type*/, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ )); - - // Input Linear Wgrad + + // Input Linear Wgrad TORCH_CUDABLAS_CHECK(hipblasGemmEx( handle, - CUBLAS_OP_N, + CUBLAS_OP_N, CUBLAS_OP_T, - embed_dim, + embed_dim, output_lin_dim, - batches, + batches, static_cast(&alpha), //static_cast(inputs.data_ptr()), static_cast(lyr_nrm_results.data_ptr()), @@ -490,7 +490,7 @@ std::vector bwd_cuda( output_lin_dim, static_cast(&beta), static_cast(input_weight_grads.data_ptr()), - HIP_R_16F /*c_type*/, + HIP_R_16F /*c_type*/, embed_dim, HIPBLAS_COMPUTE_32F, HIPBLAS_GEMM_DEFAULT /*algo*/ @@ -517,5 +517,5 @@ std::vector bwd_cuda( } } // end namespace rocblas_gemmex -} // end namespace self_norm_add +} // end namespace self_norm_add } // end namespace multihead_attn diff --git a/apex/contrib/csrc/multihead_attn/softmax.cuh b/apex/contrib/csrc/multihead_attn/softmax.cuh index 6e7da0f71..174b57ce3 100644 --- a/apex/contrib/csrc/multihead_attn/softmax.cuh +++ b/apex/contrib/csrc/multihead_attn/softmax.cuh @@ -18,7 +18,7 @@ #include #include #include - + #ifdef USE_ROCM #define APEX_WARP_SHFL_XOR(mask, value, offset, width) __shfl_xor(value, offset, width) #else diff --git a/apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh b/apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh index 5d45efb3c..66bbf7b9c 100644 --- a/apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh +++ b/apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh @@ -47,27 +47,27 @@ void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k, TORCH_CUDABLAS_CHECK(hipblasGemmStridedBatchedEx( handle, - opa, - opb, - (int)m, - (int)n, + opa, + opb, + (int)m, + (int)n, (int)k, - (void*)&fAlpha, - a, - HIP_R_16F /*a_type*/, - (int)lda, + (void*)&fAlpha, + a, + HIP_R_16F /*a_type*/, + (int)lda, strideA, - b, - HIP_R_16F /*b_type*/, - (int)ldb, + b, + HIP_R_16F /*b_type*/, + (int)ldb, strideB, - (void*)&fBeta, - c, - HIP_R_16F /*c_type*/, - (int)ldc, + (void*)&fBeta, + c, + HIP_R_16F /*c_type*/, + (int)ldc, strideC, (int)batchCount, - HIPBLAS_COMPUTE_32F, + HIPBLAS_COMPUTE_32F, algo)); } @@ -136,7 +136,7 @@ void HgemmStridedBatched(char transa, char transb, long m, // gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA, // b, ldb, strideB, beta, c, ldc, strideC, batchCount); - gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA, + gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); } diff --git a/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh b/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh index 6d29420b2..a6c0b00be 100644 --- a/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh +++ b/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh @@ -22,8 +22,8 @@ namespace apex { namespace contrib { namespace nccl_p2p { at::Tensor get_unique_nccl_id(int n); int init_nccl_comm( - at::Tensor unique_nccl_id, - int my_rank, + at::Tensor unique_nccl_id, + int my_rank, int num_ranks ); void left_right_halo_exchange_inplace( @@ -38,7 +38,7 @@ std::vector left_right_halo_exchange( int handle, int left_rank, int right_rank, - at::Tensor left_output_halo, + at::Tensor left_output_halo, at::Tensor right_output_halo); void add_delay(int delay); }}} diff --git a/apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu b/apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu index 18b60264a..c5f7c087a 100644 --- a/apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu +++ b/apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu @@ -386,7 +386,7 @@ void fused_adam_cuda_mt( C10_CUDA_CHECK(cudaGetLastError()); } -template +template __device__ void convert(const FROM_T vi, TO_T& vo) { vo = static_cast(vi); @@ -522,7 +522,7 @@ __global__ void strided_check_finite_cuda_kernel( } } -template +template __global__ void maybe_cast_kernel( volatile int* overflow_flag, const FROM_T* p_in, diff --git a/apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu b/apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu index 3bb93b031..32b4eaf6b 100644 --- a/apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu +++ b/apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu @@ -164,7 +164,7 @@ struct LAMBStage2Functor MATH_T ratio = learning_rate; // apply adaptive learning rate to parameters with non-zero weight decay - if (decay != 0.0) + if (decay != 0.0) { float param_norm = per_tensor_param_norm[tensor_num]; float update_norm = per_tensor_update_norm[tensor_num]; diff --git a/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu b/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu index 95ee009b2..8754a562c 100644 --- a/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu +++ b/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu @@ -24,7 +24,7 @@ __device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int s ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; } -template +template __device__ void convert(const FROM_T vi, TO_T& vo) { vo = static_cast(vi); diff --git a/apex/contrib/csrc/peer_memory/peer_memory_cuda.cu b/apex/contrib/csrc/peer_memory/peer_memory_cuda.cu index 188900128..98ec88c7e 100644 --- a/apex/contrib/csrc/peer_memory/peer_memory_cuda.cu +++ b/apex/contrib/csrc/peer_memory/peer_memory_cuda.cu @@ -131,7 +131,7 @@ void tensor_strides(at::Tensor t, bool explicit_nhwc, int& stride_N, int& stride } } -template +template __device__ void __zero(T* dst) { *dst = T(0); @@ -146,8 +146,8 @@ __device__ void __zero(int4* dst) template __device__ void strided_copy_kernel( - T* dst, const int dst_stride_C, const int dst_stride_H, const int dst_stride_W, - const T* src, const int src_stride_C, const int src_stride_H, const int src_stride_W, + T* dst, const int dst_stride_C, const int dst_stride_H, const int dst_stride_W, + const T* src, const int src_stride_C, const int src_stride_H, const int src_stride_W, const int NC, const int NH, const int NW ) { @@ -179,7 +179,7 @@ __device__ void strided_copy_kernel( } } -template +template __device__ void checked_signal( volatile int* signal1_flag, volatile int* signal2_flag, const int v1, const int v2, const int v3, const int v4 @@ -325,7 +325,7 @@ __device__ void wait_for( r4 = __builtin_nontemporal_load(wait_flag + 3); #else asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(wait_flag) : "memory"); -#endif +#endif } while (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4); } cg::this_grid().sync(); // all threads wait for main diff --git a/apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh b/apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh index 4f0169f3d..689bbe1e0 100644 --- a/apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh +++ b/apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh @@ -17,7 +17,7 @@ #pragma once #include #ifndef _peer_memory_h_ -#define _peer_memory_h_ +#define _peer_memory_h_ namespace apex { namespace contrib { namespace peer_memory { int64_t allocate_raw(int64_t size); diff --git a/apex/contrib/csrc/transducer/transducer_joint.cpp b/apex/contrib/csrc/transducer/transducer_joint.cpp index 351e7cab7..dd2df88fd 100755 --- a/apex/contrib/csrc/transducer/transducer_joint.cpp +++ b/apex/contrib/csrc/transducer/transducer_joint.cpp @@ -50,9 +50,9 @@ std::vector transducer_joint_forward( if (packOutput) CHECK_INPUT(batchOffset); return transducer_joint_cuda_forward( - f, - g, - fLen, + f, + g, + fLen, gLen, batchOffset, packedBatch, @@ -81,8 +81,8 @@ std::vector transducer_joint_backward( if (packOutput) CHECK_INPUT(batchOffset); return transducer_joint_cuda_backward( - in, - fLen, + in, + fLen, gLen, batchOffset, maxFLen, diff --git a/apex/contrib/csrc/transducer/transducer_joint_kernel.cu b/apex/contrib/csrc/transducer/transducer_joint_kernel.cu index 7c1c7c291..122767e5c 100755 --- a/apex/contrib/csrc/transducer/transducer_joint_kernel.cu +++ b/apex/contrib/csrc/transducer/transducer_joint_kernel.cu @@ -44,7 +44,7 @@ inline int largestPowerOfTwo(int x){ /* Figure out vectorization type for masks. Similar to how PyTorch figures out acc_t here: -aten/src/ATen/AccumulateType.h +aten/src/ATen/AccumulateType.h */ template struct MaskVecType { }; @@ -60,10 +60,10 @@ using mvec_type = typename MaskVecType::type; // For fwd, batch offset and stride are different for packing and non-packing mode. struct OffsetCalFwd{ __device__ __forceinline__ OffsetCalFwd( - int64_t batch, - const int64_t *batchOffset, - int64_t maxFLen, - int64_t maxGLen, + int64_t batch, + const int64_t *batchOffset, + int64_t maxFLen, + int64_t maxGLen, int64_t gLen, int64_t hiddenSize, bool packOutput) : @@ -75,7 +75,7 @@ struct OffsetCalFwd{ hiddenSize(hiddenSize), packOutput(packOutput) {} - + int64_t batch; const int64_t *batchOffset; int64_t maxFLen; @@ -85,7 +85,7 @@ struct OffsetCalFwd{ bool packOutput; __device__ __forceinline__ int64_t getBatchOffset(){ - return packOutput ? ((batch==0) ? 0 : batchOffset[batch-1])*hiddenSize + return packOutput ? ((batch==0) ? 0 : batchOffset[batch-1])*hiddenSize : batch*maxFLen*maxGLen*hiddenSize; } @@ -93,7 +93,7 @@ struct OffsetCalFwd{ return packOutput ? gLen*hiddenSize : maxGLen*hiddenSize; } - + }; // Helper class to calculate pointer offset that can be shared by different flavors of kernels @@ -102,12 +102,12 @@ struct OffsetCalFwd{ // according to bwdFasterDim can lead to a unified implementation in the actual kernel. struct OffsetCalBwd{ __device__ __forceinline__ OffsetCalBwd( - int64_t batch, - const int64_t *batchOffset, - const int *fLen, + int64_t batch, + const int64_t *batchOffset, + const int *fLen, const int *gLen, - int64_t maxFLen, - int64_t maxGLen, + int64_t maxFLen, + int64_t maxGLen, int64_t hiddenSize, bool packOutput, bool bwdFasterDim) : @@ -133,7 +133,7 @@ struct OffsetCalBwd{ bool bwdFasterDim; // whether doing bwd on the faster moving dimension __device__ __forceinline__ int64_t getBatchOffset(){ - return packOutput ? ((batch==0) ? 0 : batchOffset[batch-1])*hiddenSize + return packOutput ? ((batch==0) ? 0 : batchOffset[batch-1])*hiddenSize : batch*maxFLen*maxGLen*hiddenSize; } @@ -148,7 +148,7 @@ struct OffsetCalBwd{ __device__ __forceinline__ auto getMyYLen() -> decltype(gLen[batch]){ return bwdFasterDim ? fLen[batch] : gLen[batch]; } - + __device__ __forceinline__ int64_t getStrideX(){ return bwdFasterDim ? hiddenSize : ((packOutput ? gLen[batch] : maxGLen) * hiddenSize); } @@ -160,7 +160,7 @@ struct OffsetCalBwd{ // Vanila transducer joint forward kernel -// Detail of this joint function can be found in: +// Detail of this joint function can be found in: // [1] Sequence Transduction with Recurrent Neural Networks. // f is a tensor of shape [batch, T, H] @@ -211,7 +211,7 @@ __global__ void transducer_joint_forward( } else if (packOutput == false and t < maxFLen and u < maxGLen){ // Need to write finite data to don't-care region because we instantiate the result tensor - // with torch::empty for performance reasons. Even though it is don't-care region, the + // with torch::empty for performance reasons. Even though it is don't-care region, the // contents need to be finite, otherwise could lead to NaN in WGRAD. // In packing mode, this write is no longer necessary as we remove the don't-care region // from the output. @@ -221,13 +221,13 @@ __global__ void transducer_joint_forward( if (h < hiddenSize){ mySum[h] = -1; } - } + } } } /* Tiled version of the joint forward kernel -Detail of this joint function can be found in: +Detail of this joint function can be found in: [1] Sequence Transduction with Recurrent Neural Networks. f is a tensor of shape [batch, T, H] @@ -236,17 +236,17 @@ the transducer joint does sum = f.unsqueeze(dim=2) + g.unsqueeze(dim=1) The resultant tensor is of shape [batch, T, U, H] Each thread is working on a tile of the shape of tileF x tileG in the result tensor. -The input for the tile is first loaded in the register and is reused tileG and tileF times. +The input for the tile is first loaded in the register and is reused tileG and tileF times. This joint function can optionally pack the output where the output tensor with a shape of [B, T, U, H] is packed into [B_packed, H]. Don't-care region (t > fLen) or (u > gLen) is removed. To enable packing, the starting offset for each batch need to be specified with batchOffset. -Optionally this joint function performs ReLU and/or dropout on the joint output, which is +Optionally this joint function performs ReLU and/or dropout on the joint output, which is controlled by arguments relu and dropout, respectively. philoxArgs is argument used for generating pseudorandom number. When at least one of operations in ReLU and dropout is activated, the joint -function is a masked operation, which is controlled by the template argument masked. In this case, +function is a masked operation, which is controlled by the template argument masked. In this case, masks are saved to backward. */ template @@ -261,7 +261,7 @@ __global__ void transducer_joint_tiled_forward( int64_t hiddenSize, int64_t hiddenPerBlock, bool packOutput, - bool relu, + bool relu, bool dropout, float p, at::PhiloxCudaState philoxArgs, @@ -289,18 +289,18 @@ __global__ void transducer_joint_tiled_forward( uint8_t *myMask = mask + myBatchOffset + t*strideF + u*hiddenSize + hOffset; // The following code is only needed for dropout. We try to bypass them as much as possible. - auto seeds = masked ? at::cuda::philox::unpack(philoxArgs) + auto seeds = masked ? at::cuda::philox::unpack(philoxArgs) : std::make_tuple(static_cast(0), static_cast(0)); - uint64_t tid = masked ? (static_cast(blockIdx.z)*gridDim.y*gridDim.x + + uint64_t tid = masked ? (static_cast(blockIdx.z)*gridDim.y*gridDim.x + blockIdx.y*gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x : 0; - Philox ph(std::get<0>(seeds), tid, std::get<1>(seeds)); - scalar_t scale = masked ? ((p == 0) ? 0 : 1 / p) : 0; + Philox ph(std::get<0>(seeds), tid, std::get<1>(seeds)); + scalar_t scale = masked ? ((p == 0) ? 0 : 1 / p) : 0; bool dropoutMask[U]; - if (t < myFLen and u < myGLen and hOffset+h < hiddenSize){ + if (t < myFLen and u < myGLen and hOffset+h < hiddenSize){ // register buffers for tiled input reuse - scalar_t fBuffer[tileF], gBuffer[tileG]; + scalar_t fBuffer[tileF], gBuffer[tileG]; for (int i = 0; i < tileF; ++i){ if (t + i < myFLen) fBuffer[i] = myF[i*hiddenSize + h]; @@ -370,7 +370,7 @@ Bwd operation (reduction) on one input tensor. Since the operation performed for tensors are exactly the same, only one kernel is needed, and the different indexing offsets and strides are handled by OffsetCalBwd. -When packing is enabled in the fwd op, unpacking is needed to restore the gradients in a +When packing is enabled in the fwd op, unpacking is needed to restore the gradients in a non-packed form. When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation, @@ -404,21 +404,21 @@ __device__ void transducer_joint_single_backward( extern __shared__ char smem8[]; auto smem = reinterpret_cast(smem8); - OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize, packOutput, + OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize, packOutput, bwdFasterDim); const auto maxXLen = offsetCal.getMaxXLen(); const auto myXLen = offsetCal.getMyXLen(); const auto myYLen = offsetCal.getMyYLen(); scalar_t *myInGrad = inGrad + batch*maxXLen*hiddenSize + x*hiddenSize + hOffset; - + if (x < myXLen){ - + const auto myBatchOffset = offsetCal.getBatchOffset(); const auto strideX = offsetCal.getStrideX(); const auto strideY = offsetCal.getStrideY(); const scalar_t *myGrad = grad + myBatchOffset + x*strideX + hOffset; const uint8_t *myMask = masked ? mask + myBatchOffset + x*strideX + hOffset : nullptr; - + // Each warp reduces numYPerWarp "y" first acc_t warpSum = 0; auto numYPerWarp = (myYLen+numWarp-1)/numWarp; @@ -428,7 +428,7 @@ __device__ void transducer_joint_single_backward( if (y < myYLen and (hOffset+lid) < hiddenSize) if (masked) warpSum += static_cast(myGrad[y*strideY + lid]) * myMask[y*strideY + lid] * scale; - else + else warpSum += myGrad[y*strideY + lid]; } @@ -458,8 +458,8 @@ __device__ void transducer_joint_single_backward( /* Actual bwd (reduction) kernel get launched. -Call transducer_joint_single_backward twice on two input tensors. -The two bwd ops are launched together, the first op uses blockIdx.y < maxFLen, and the second op +Call transducer_joint_single_backward twice on two input tensors. +The two bwd ops are launched together, the first op uses blockIdx.y < maxFLen, and the second op uses the rest. When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation, and mask contains the mask information. @@ -508,14 +508,14 @@ __global__ void transducer_joint_combined_backward( scale, gGrad, maxFLen); - } + } } /* Vectorized version of transducer_joint_single_backward Doing exact same operation as transducer_joint_single_backward except the load and store are vectorized. -When packing is enabled in the fwd op, unpacking is needed to restore the gradients in a +When packing is enabled in the fwd op, unpacking is needed to restore the gradients in a non-packed form. When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation, and mask contains the mask information. @@ -546,7 +546,7 @@ __device__ void transducer_joint_single_vec_backward( // Figure out the vectorization type for mask using mvec_t = mvec_type; - OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize, packOutput, + OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize, packOutput, bwdFasterDim); const auto maxXLen = offsetCal.getMaxXLen(); const auto myXLen = offsetCal.getMyXLen(); @@ -597,7 +597,7 @@ __device__ void transducer_joint_single_vec_backward( } } } - + // transpose partial sum in SMEM and reduce further using warpReduce for (int i = 0; i < V; ++i){ smem[lid*numWarp + wid] = warpSum[i]; @@ -620,7 +620,7 @@ __device__ void transducer_joint_single_vec_backward( // example of 4 warps (a, b, c, d) with 8 threads per warp // Each warp need 8 / 4 = 2 threads to write the results. if (lid % numWarp == 0 and hOffset+(wid*C10_WARP_SIZE/numWarp + lid/numWarp)*V < hiddenSize) - myInGradVec[wid*C10_WARP_SIZE/numWarp + lid/numWarp] = *outBufferVec; + myInGradVec[wid*C10_WARP_SIZE/numWarp + lid/numWarp] = *outBufferVec; } else if (wid == 0 and hOffset + lid*V < hiddenSize){ // Need to ensure the grad is zero for don't care region @@ -630,8 +630,8 @@ __device__ void transducer_joint_single_vec_backward( /* Vecotrized version of transducer_joint_combined_backward -Call transducer_joint_single_vec_backward twice on two input tensors. -The two bwd ops are launched together, the first op uses blockIdx.y < maxFLen, and the second op +Call transducer_joint_single_vec_backward twice on two input tensors. +The two bwd ops are launched together, the first op uses blockIdx.y < maxFLen, and the second op uses the rest. When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation, and mask contains the mask information. @@ -680,7 +680,7 @@ __global__ void transducer_joint_combined_vec_backward( scale, gGrad, maxFLen); - } + } } @@ -700,7 +700,7 @@ std::vector transducer_joint_cuda_forward( float dropoutProb, int tileSize){ - + auto tensorOpt = f.options(); auto dtype = f.scalar_type(); const auto batchSize = f.size(0); @@ -708,7 +708,7 @@ std::vector transducer_joint_cuda_forward( const auto maxGLen = g.size(1); const auto hiddenSize = f.size(2); bool masked = dropout or relu; - + int64_t *batchOffsetPtr = nullptr; torch::Tensor sum, mask; auto maskOpt = tensorOpt.dtype(torch::kUInt8); @@ -719,7 +719,7 @@ std::vector transducer_joint_cuda_forward( mask = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, maskOpt); } else{ - sum = torch::empty({packedBatch, hiddenSize}, tensorOpt); + sum = torch::empty({packedBatch, hiddenSize}, tensorOpt); batchOffsetPtr = batchOffset.data_ptr(); if (masked) mask = torch::empty({packedBatch, hiddenSize}, maskOpt); @@ -732,7 +732,7 @@ std::vector transducer_joint_cuda_forward( // Simple heuristics const int numThread = std::min(128, (static_cast(hiddenSize)+at::cuda::warp_size()-1) / at::cuda::warp_size() * at::cuda::warp_size()); - + if (opt == 0){ // vanilla kernel const int threads = numThread; @@ -741,41 +741,41 @@ std::vector transducer_joint_cuda_forward( AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_forward", ([&] { transducer_joint_forward <<>>( - f.data_ptr(), - g.data_ptr(), - fLen.data_ptr(), - gLen.data_ptr(), + f.data_ptr(), + g.data_ptr(), + fLen.data_ptr(), + gLen.data_ptr(), batchOffsetPtr, maxFLen, maxGLen, hiddenSize, packOutput, sum.data_ptr()); - })); + })); } if (opt == 1){ - // tiled version. For simplicity, assume tileF == tileG, even though the kernel can + // tiled version. For simplicity, assume tileF == tileG, even though the kernel can // support more general cases. const int threads = numThread; const int hiddenPerBlock = numThread; const int hiddenBlock = (hiddenSize + hiddenPerBlock - 1) / hiddenPerBlock; - const dim3 blocks( (maxGLen+tileSize-1)/tileSize * hiddenBlock, - (maxFLen+tileSize-1)/tileSize, + const dim3 blocks( (maxGLen+tileSize-1)/tileSize * hiddenBlock, + (maxFLen+tileSize-1)/tileSize, batchSize); - TORCH_CHECK(tileSize == 1 or tileSize == 2 or tileSize == 4, + TORCH_CHECK(tileSize == 1 or tileSize == 2 or tileSize == 4, "Expected tileSize to be in [1, 2, 4], but got ", tileSize); at::PhiloxCudaState rng_engine_inputs; if (masked){ - // set up PRG when the input is masked. rng_engine_inputs will be used as a space filler + // set up PRG when the input is masked. rng_engine_inputs will be used as a space filler // for non-masked calls. // Therefore no need to initialize. c10::optional gen_; - auto gen = at::get_generator_or_default(gen_, + auto gen = at::get_generator_or_default(gen_, at::cuda::detail::getDefaultCUDAGenerator()); - // counterOffset records how many cuRAND calls each thread makes. For a tiled kernel, - // each thread processes tileF * tileG output elements. + // counterOffset records how many cuRAND calls each thread makes. For a tiled kernel, + // each thread processes tileF * tileG output elements. int64_t counterOffset = tileSize * tileSize; { std::lock_guard lock(gen->mutex_); @@ -784,17 +784,17 @@ std::vector transducer_joint_cuda_forward( } AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_forward", ([&] { - void(*kernel)(const scalar_t*, const scalar_t*, const int*, const int*, const int64_t*, - int64_t, int64_t, int64_t, int64_t, bool, bool, bool, float, + void(*kernel)(const scalar_t*, const scalar_t*, const int*, const int*, const int64_t*, + int64_t, int64_t, int64_t, int64_t, bool, bool, bool, float, at::PhiloxCudaState, scalar_t*, uint8_t*); if (masked){ switch (tileSize){ case 2: - kernel = &transducer_joint_tiled_forward; break; case 4: - kernel = &transducer_joint_tiled_forward; break; } @@ -802,20 +802,20 @@ std::vector transducer_joint_cuda_forward( else{ switch (tileSize){ case 1: - kernel = &transducer_joint_tiled_forward; break; case 2: - kernel = &transducer_joint_tiled_forward; break; case 4: - kernel = &transducer_joint_tiled_forward; break; } } - + kernel<<>>( f.data_ptr(), g.data_ptr(), @@ -833,11 +833,11 @@ std::vector transducer_joint_cuda_forward( rng_engine_inputs, sum.data_ptr(), maskPtr); - })); + })); } - + C10_CUDA_CHECK(cudaGetLastError()); - if (masked) + if (masked) return {sum, mask}; else return {sum}; @@ -868,16 +868,16 @@ std::vector transducer_joint_cuda_backward( torch::Tensor fGrad = torch::empty({batchSize, maxFLen, hiddenSize}, tensorOpt); torch::Tensor gGrad = torch::empty({batchSize, maxGLen, hiddenSize}, tensorOpt); - int64_t *batchOffsetPtr = (!packOutput) ? nullptr : batchOffset.data_ptr(); + int64_t *batchOffsetPtr = (!packOutput) ? nullptr : batchOffset.data_ptr(); // The number "y" I would like each thread to work on const int workPerThread = 32; // Since the bwd for f and g have the same thread block size, we need to use the max of the two. int numWarp = largestPowerOfTwo((std::max(maxFLen, maxGLen) + workPerThread-1) / workPerThread); - // Would like to have at least 2 warps + // Would like to have at least 2 warps numWarp = std::max(2, numWarp); // cap on the maximum number of warps allowed - numWarp = std::min(maxNumWarp, numWarp); + numWarp = std::min(maxNumWarp, numWarp); // Need smem for transposing the partial sum. The partial sum is in a matrix of the shape // numWarp x warpSize @@ -887,7 +887,7 @@ std::vector transducer_joint_cuda_backward( AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_cuda_backward_kernel", ([&] { auto gradPtr = grad.data_ptr(); auto fLenPtr = fLen.data_ptr(); - auto gLenPtr = gLen.data_ptr(); + auto gLenPtr = gLen.data_ptr(); auto fGradPtr = fGrad.data_ptr(); auto gGradPtr = gGrad.data_ptr(); @@ -899,15 +899,15 @@ std::vector transducer_joint_cuda_backward( constexpr int vecAlignment = std::alignment_of::value; // if all input and output tensors meet the alignment requirement - bool memAlign = (reinterpret_cast(gradPtr) % vecAlignment == 0) - and (reinterpret_cast(fGradPtr) % vecAlignment == 0) + bool memAlign = (reinterpret_cast(gradPtr) % vecAlignment == 0) + and (reinterpret_cast(fGradPtr) % vecAlignment == 0) and (reinterpret_cast(gGradPtr) % vecAlignment == 0); if (vectFactor > 1 and hiddenSize%vectFactor == 0 and memAlign){ - // If vectorization helps and the alignment requirement is met, use the vectorized + // If vectorization helps and the alignment requirement is met, use the vectorized // kernel. For simplicity, hiddenSize needs to be a multiple vecFactor. const dim3 blocks( (hiddenSize+at::cuda::warp_size()*vectFactor-1)/(at::cuda::warp_size()*vectFactor), - maxFLen+maxGLen, + maxFLen+maxGLen, batchSize); if (masked){ transducer_joint_combined_vec_backward @@ -915,9 +915,9 @@ std::vector transducer_joint_cuda_backward( <<>>( gradPtr, maskPtr, - fLenPtr, - gLenPtr, - batchOffsetPtr, + fLenPtr, + gLenPtr, + batchOffsetPtr, maxFLen, maxGLen, hiddenSize, @@ -932,16 +932,16 @@ std::vector transducer_joint_cuda_backward( <<>>( gradPtr, maskPtr, - fLenPtr, - gLenPtr, - batchOffsetPtr, + fLenPtr, + gLenPtr, + batchOffsetPtr, maxFLen, maxGLen, hiddenSize, packOutput, scale, fGradPtr, - gGradPtr); + gGradPtr); } } else{ @@ -952,9 +952,9 @@ std::vector transducer_joint_cuda_backward( <<>>( gradPtr, maskPtr, - fLenPtr, - gLenPtr, - batchOffsetPtr, + fLenPtr, + gLenPtr, + batchOffsetPtr, maxFLen, maxGLen, hiddenSize, @@ -968,9 +968,9 @@ std::vector transducer_joint_cuda_backward( <<>>( gradPtr, maskPtr, - fLenPtr, - gLenPtr, - batchOffsetPtr, + fLenPtr, + gLenPtr, + batchOffsetPtr, maxFLen, maxGLen, hiddenSize, @@ -980,7 +980,7 @@ std::vector transducer_joint_cuda_backward( gGradPtr); } } - })); + })); return {fGrad, gGrad}; } diff --git a/apex/contrib/csrc/transducer/transducer_loss.cpp b/apex/contrib/csrc/transducer/transducer_loss.cpp index 91c956239..e304edf93 100644 --- a/apex/contrib/csrc/transducer/transducer_loss.cpp +++ b/apex/contrib/csrc/transducer/transducer_loss.cpp @@ -51,13 +51,13 @@ std::vector transducer_loss_forward( if (packedInput) CHECK_INPUT(batchOffset); return transducer_loss_cuda_forward( - x, - label, - fLen, - yLen, + x, + label, + fLen, + yLen, batchOffset, maxFLen, - blankIdx, + blankIdx, opt, packedInput); } diff --git a/apex/contrib/csrc/transducer/transducer_loss_kernel.cu b/apex/contrib/csrc/transducer/transducer_loss_kernel.cu index 295e14b3f..2b3584b48 100755 --- a/apex/contrib/csrc/transducer/transducer_loss_kernel.cu +++ b/apex/contrib/csrc/transducer/transducer_loss_kernel.cu @@ -15,15 +15,15 @@ __device__ __forceinline__ scalar_t logSumExp(scalar_t a, scalar_t b) { } // Vanilla transducer loss function (i.e. forward-backward algorithm) -// Detail of this loss function can be found in: +// Detail of this loss function can be found in: // [1] Sequence Transduction with Recurrent Neural Networks. // Forward (alpha) and backward (beta) path are launched together. Input is assumed to be converted // into log scale by the preceding log_softmax layer -// Diagonal wavefront advancing usually used in dynamic programming is leveraged here. +// Diagonal wavefront advancing usually used in dynamic programming is leveraged here. // alpha and beta are of acc_t type, as they are essentially accumulators. -// This loss function supports packed input where a tensor of shape [B, T, U, H] is packed into +// This loss function supports packed input where a tensor of shape [B, T, U, H] is packed into // [B_packed, H]. // Don't-care region (t > audLen) or (u > txtLen) is removed. // To support the packed input, the starting offsets for each batch need to be specified with @@ -48,18 +48,18 @@ __global__ void transducer_loss_forward( const int tid = threadIdx.x; const auto myFLen = audLen[batch]; // Note that start of the sentence is added as 1 here - const auto myGLen = txtLen[batch] + 1; + const auto myGLen = txtLen[batch] + 1; const auto myLabel = label + batch * (maxGLen-1); - const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) + const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) : batch * maxFLen * maxGLen; const int64_t myStrideT = packedInput ? myGLen : maxGLen; - const scalar_t* myX = x + myBatchOffset * dictSize; + const scalar_t* myX = x + myBatchOffset * dictSize; int u = tid; if (blockIdx.x == 0){ // alpha path acc_t* myAlpha = alpha + batch*maxFLen*maxGLen; - if (u == 0) + if (u == 0) myAlpha[0] = 0; __syncthreads(); @@ -71,7 +71,7 @@ __global__ void transducer_loss_forward( // Eq(16) in [1] if (u == 0){ // alpha(t, u) = alpha(t-1, u) * null(t-1, u) - myAlpha[t*maxGLen + u] = myAlpha[(t-1)*maxGLen] + myAlpha[t*maxGLen + u] = myAlpha[(t-1)*maxGLen] + myX[((t-1)*myStrideT) * dictSize + blankIdx]; } else if (t == 0){ @@ -80,9 +80,9 @@ __global__ void transducer_loss_forward( } else{ // alpha(t, u) = alpha(t-1, u) * null(t-1, u) + alpha(t, u-1) * y(t, u-1) - acc_t current = myAlpha[(t-1)*maxGLen + u] + acc_t current = myAlpha[(t-1)*maxGLen + u] + myX[((t-1)*myStrideT + u) * dictSize + blankIdx]; - acc_t next = myAlpha[t*maxGLen + u - 1] + acc_t next = myAlpha[t*maxGLen + u - 1] + myX[(t*myStrideT + u - 1) * dictSize + myLabel[u - 1]]; myAlpha[t*maxGLen + u] = logSumExp(next, current); } @@ -95,7 +95,7 @@ __global__ void transducer_loss_forward( // beta path acc_t* myBeta = beta + batch*maxFLen*maxGLen; if (u == 0){ - myBeta[(myFLen-1)*maxGLen + myGLen - 1] = myX[((myFLen-1)*myStrideT + myBeta[(myFLen-1)*maxGLen + myGLen - 1] = myX[((myFLen-1)*myStrideT + myGLen - 1) * dictSize + blankIdx]; } __syncthreads(); @@ -107,19 +107,19 @@ __global__ void transducer_loss_forward( // Eq(18) in [1] if (u == myGLen - 1){ // beta(t, u) = beta(t+1, u) * null(t, u) - myBeta[t*maxGLen + u] = myBeta[(t+1)*maxGLen + u] + myBeta[t*maxGLen + u] = myBeta[(t+1)*maxGLen + u] + myX[(t*myStrideT + u) * dictSize + blankIdx]; } else if (t == myFLen - 1){ // beta(t, u) = beta(t, u+1) * y(t, u) - myBeta[t*maxGLen + u] = myBeta[t*maxGLen + u + 1] + myBeta[t*maxGLen + u] = myBeta[t*maxGLen + u + 1] + myX[(t*myStrideT + u) * dictSize + myLabel[u]]; } else{ // beta(t, u) = beta(t+1, u)*null(t, u) + beta(t, u+1)*y(t, u) - acc_t current = myBeta[(t+1)*maxGLen + u] + acc_t current = myBeta[(t+1)*maxGLen + u] + myX[(t*myStrideT + u) * dictSize + blankIdx]; - acc_t next = myBeta[t*maxGLen + u + 1] + acc_t next = myBeta[t*maxGLen + u + 1] + myX[(t*myStrideT + u) * dictSize + myLabel[u]]; myBeta[t*maxGLen + u] = logSumExp(next, current); } @@ -128,7 +128,7 @@ __global__ void transducer_loss_forward( __syncthreads(); } if (tid == 0) - loss[batch] = -myBeta[0]; + loss[batch] = -myBeta[0]; } } @@ -140,14 +140,14 @@ __global__ void transducer_loss_forward( // For simplicity, this kernel currently only supports U <= maxThread, which should be the common // case. For cases where U > maxThread, the vanilla kernel is used as a fallback option. -// Detail of this loss function can be found in: +// Detail of this loss function can be found in: // [1] Sequence Transduction with Recurrent Neural Networks. // Forward (alpha) and backward (beta) path are launched together. Input is assumed to be converted // into log scale by the preceding log_softmax layer // Diagonal wavefront advancing usually used in dynamic programming is leveraged here. // alpha and beta are of acc_t type, as they are essentially accumulators. -// This loss function supports packed input where a tensor of shape [B, T, U, H] is packed into +// This loss function supports packed input where a tensor of shape [B, T, U, H] is packed into // [B_packed, H]. // Don't-care region (t > audLen) or (u > txtLen) is removed. // To support the packed input, the starting offsets for each batch need to be specified with @@ -172,10 +172,10 @@ __global__ void transducer_loss_batch_load_forward( int u = threadIdx.x; const auto myFLen = audLen[batch]; const auto myGLen = txtLen[batch] + 1; - const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) + const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) : batch * maxFLen * maxGLen; const int64_t myStrideT = packedInput ? myGLen : maxGLen; - const scalar_t* myX = x + myBatchOffset * dictSize; + const scalar_t* myX = x + myBatchOffset * dictSize; scalar_t next[batchLdSize], current[batchLdSize]; extern __shared__ char smem8[]; auto smem = reinterpret_cast(smem8); @@ -186,7 +186,7 @@ __global__ void transducer_loss_batch_load_forward( // two SMEM regions for double buffering read and write data to avoid data race acc_t * const sharedAlpha[2] = {smem, smem+maxGLen}; - sharedAlpha[0][u] = 0; + sharedAlpha[0][u] = 0; __syncthreads(); if (u == 0) @@ -302,21 +302,21 @@ __global__ void transducer_loss_batch_load_forward( sharedBetaWr[u] = prvStepBeta; myBeta[t*maxGLen + u] = prvStepBeta; } - + } __syncthreads(); } } if (u == 0) - loss[batch] = -prvStepBeta; + loss[batch] = -prvStepBeta; } } // Vanilla transudcer loss backward operation. -// Detail of this loss function can be found in: +// Detail of this loss function can be found in: // [1] Sequence Transduction with Recurrent Neural Networks. -// For this backward kernel, bwd op for the preceding softmax is assumed to be handled elsewhere, +// For this backward kernel, bwd op for the preceding softmax is assumed to be handled elsewhere, // hence only Eq(20) in [1] is implemented in this kernel. // Each thread block works on [batch, t, :, :] of data. Each thread works on a specific u at a time @@ -347,37 +347,37 @@ __global__ void transducer_loss_backward( const int batch = blockIdx.y; const int64_t myFLen = audLen[batch]; const int64_t myGLen = txtLen[batch] + 1; - const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) + const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) : batch * maxFLen * maxGLen; const int64_t myStrideT = packedInput ? myGLen : maxGLen; auto myX = x + (myBatchOffset + t*myStrideT)*dictSize; auto myAlpha = alpha + batch*maxFLen*maxGLen; auto myBeta = beta + batch*maxFLen*maxGLen; - auto myXGrad = xGrad + (myBatchOffset + t*myStrideT)*dictSize; + auto myXGrad = xGrad + (myBatchOffset + t*myStrideT)*dictSize; auto myLabel = label + batch*(maxGLen-1); int64_t u = tid; while (t < myFLen and u < myGLen){ // Do the update // loss = -ln(Pr(y*|x)) - acc_t grad = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0]; + acc_t grad = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0]; if (u != myGLen - 1) - myXGrad[u*dictSize + myLabel[u]] = -std::exp(grad + myBeta[t*maxGLen + u + 1] + myXGrad[u*dictSize + myLabel[u]] = -std::exp(grad + myBeta[t*maxGLen + u + 1] + myX[u*dictSize + myLabel[u]]); if (t == myFLen - 1 and u == myGLen - 1) myXGrad[u*dictSize + blankIdx] = -std::exp(grad + myX[u*dictSize + blankIdx]); else if (t != myFLen - 1) - myXGrad[u*dictSize + blankIdx] = -std::exp(grad + myBeta[(t+1)*maxGLen + u] - + myX[u*dictSize + blankIdx]); + myXGrad[u*dictSize + blankIdx] = -std::exp(grad + myBeta[(t+1)*maxGLen + u] + + myX[u*dictSize + blankIdx]); u += blockDim.x; } } // Fused transudcer loss backward operation. -// Detail of this loss function can be found in: +// Detail of this loss function can be found in: // [1] Sequence Transduction with Recurrent Neural Networks. -// The bwd op of the preceding softmax layer is fused in this kernel. +// The bwd op of the preceding softmax layer is fused in this kernel. // Each thread block works on [batch, t, u, :] of data. Each thread works on a specific h at a time // To support the packed input, the starting offsets for each batch need to be specified with @@ -398,22 +398,22 @@ __global__ void transducer_loss_fused_backward( int64_t maxGLen, bool packedInput, scalar_t* xGrad) { - + const int tid = threadIdx.x; const int u = blockIdx.x; const int t = blockIdx.y; const int batch = blockIdx.z; const int64_t myFLen = audLen[batch]; const int64_t myGLen = txtLen[batch] + 1; - const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) + const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) : batch * maxFLen * maxGLen; const int64_t myStrideT = packedInput ? myGLen : maxGLen; __shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared; - auto myXGrad = xGrad + (myBatchOffset + t*myStrideT +u)*dictSize; + auto myXGrad = xGrad + (myBatchOffset + t*myStrideT +u)*dictSize; - if (t < myFLen and u < myGLen){ - auto myX = x + (myBatchOffset + t*myStrideT +u)*dictSize; + if (t < myFLen and u < myGLen){ + auto myX = x + (myBatchOffset + t*myStrideT +u)*dictSize; auto myAlpha = alpha + batch*maxFLen*maxGLen; auto myBeta = beta + batch*maxFLen*maxGLen; auto myLabel = label + batch*(maxGLen-1); @@ -455,9 +455,9 @@ __global__ void transducer_loss_fused_backward( // Vectorized version of fused transudcer loss backward operation. -// Detail of this loss function can be found in: +// Detail of this loss function can be found in: // [1] Sequence Transduction with Recurrent Neural Networks. -// The bwd op of the preceding softmax layer is fused in this kernel. +// The bwd op of the preceding softmax layer is fused in this kernel. // Each thread block works on [batch, t, u, :] of data. Each thread works on a specific h at a time // To support the packed input, the starting offsets for each batch need to be specified with @@ -478,20 +478,20 @@ __global__ void transducer_loss_fused_vec_backward( int64_t maxGLen, bool packedInput, scalar_t* xGrad) { - + const int tid = threadIdx.x; const int u = blockIdx.x; const int t = blockIdx.y; const int batch = blockIdx.z; const int64_t myFLen = audLen[batch]; const int64_t myGLen = txtLen[batch] + 1; - const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) + const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) : batch * maxFLen * maxGLen; const int64_t myStrideT = packedInput ? myGLen : maxGLen; __shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared; - auto myXGrad = xGrad + (myBatchOffset + t*myStrideT +u)*dictSize; - auto myX = x + (myBatchOffset + t*myStrideT +u)*dictSize; + auto myXGrad = xGrad + (myBatchOffset + t*myStrideT +u)*dictSize; + auto myX = x + (myBatchOffset + t*myStrideT +u)*dictSize; auto myAlpha = alpha + batch*maxFLen*maxGLen; auto myBeta = beta + batch*maxFLen*maxGLen; auto myLabel = label + batch*(maxGLen-1); @@ -502,7 +502,7 @@ __global__ void transducer_loss_fused_vec_backward( auto myXGradVec = reinterpret_cast(myXGrad); auto myXBufferVec = reinterpret_cast(myXBuffer); auto myXGradBufferVec = reinterpret_cast(myXGradBuffer); - if (t < myFLen and u < myGLen){ + if (t < myFLen and u < myGLen){ // load and store shared variables in SMEM if (tid == 0){ commonFactor = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0]; @@ -541,7 +541,7 @@ __global__ void transducer_loss_fused_vec_backward( // Store myXGrad in a vector form myXGradVec[h0/V] = *myXGradBufferVec; - + } } else if (!packedInput){ @@ -570,18 +570,18 @@ std::vector transducer_loss_cuda_forward( const int maxGLen = label.size(1) + 1; const int dictSize = x.size(-1); - TORCH_CHECK(blankIdx >= 0 and blankIdx < dictSize, - "Expected blank index to be in the range of 0 to ", + TORCH_CHECK(blankIdx >= 0 and blankIdx < dictSize, + "Expected blank index to be in the range of 0 to ", dictSize-1, - ", but got ", + ", but got ", blankIdx); - TORCH_CHECK(opt == -1 or opt == 0 or opt == 1, - "Got an invalid optimization level ", + TORCH_CHECK(opt == -1 or opt == 0 or opt == 1, + "Got an invalid optimization level ", opt); // The data type of alpha and beta will be resolved at dispatch time, // hence defined here and assigned later - torch::Tensor alpha; + torch::Tensor alpha; torch::Tensor beta; torch::Tensor loss = torch::empty({batchSize}, tensorOpt); const auto deviceProperties = at::cuda::getCurrentDeviceProperties(); @@ -602,42 +602,42 @@ std::vector transducer_loss_cuda_forward( // if the required SMEM size or number threads exceeds the limit, fall back to the vanilla // kernel. const auto smemSize = 2*maxGLen*sizeof(acc_t); - const auto optFallBack = (maxGLen > maxThreadPerBlock or smemSize > maxSmemPerBlock) ? 0 + const auto optFallBack = (maxGLen > maxThreadPerBlock or smemSize > maxSmemPerBlock) ? 0 : (opt == -1) ? 1 : opt; const int threads = std::min(maxThreadPerBlock, maxGLen); - const dim3 blocks(2, batchSize, 1); + const dim3 blocks(2, batchSize, 1); if (optFallBack == 0) transducer_loss_forward<<>>( - x.data_ptr(), - label.data_ptr(), - audLen.data_ptr(), - txtLen.data_ptr(), + x.data_ptr(), + label.data_ptr(), + audLen.data_ptr(), + txtLen.data_ptr(), batchOffsetPtr, - dictSize, - blankIdx, + dictSize, + blankIdx, maxFLen, maxGLen, packedInput, - alpha.data_ptr(), - beta.data_ptr(), + alpha.data_ptr(), + beta.data_ptr(), loss.data_ptr()); else if (optFallBack == 1) transducer_loss_batch_load_forward <<>>( - x.data_ptr(), - label.data_ptr(), - audLen.data_ptr(), - txtLen.data_ptr(), + x.data_ptr(), + label.data_ptr(), + audLen.data_ptr(), + txtLen.data_ptr(), batchOffsetPtr, - dictSize, - blankIdx, + dictSize, + blankIdx, maxFLen, maxGLen, packedInput, - alpha.data_ptr(), - beta.data_ptr(), - loss.data_ptr()); + alpha.data_ptr(), + beta.data_ptr(), + loss.data_ptr()); })); C10_CUDA_CHECK(cudaGetLastError()); @@ -673,17 +673,17 @@ torch::Tensor transducer_loss_cuda_backward( const int warpSize = deviceProperties->warpSize; const auto batchOffsetPtr = packedInput ? batchOffset.data_ptr() : nullptr; cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - + if (fuseSoftmaxBackward){ - // alloc empty tensors for performance, hence need to ensure zeros are writtern to + // alloc empty tensors for performance, hence need to ensure zeros are writtern to // don't-care region in the kernel. xGrad = torch::empty_like(x); // Would like each thread to work on 4 hidden units - const int workPerThread = 4; + const int workPerThread = 4; // Don't want to have more than 128 threads per thread block const int maxThreadPerElmt = std::min(128, maxThreadPerBlock); - const int threads = std::min(maxThreadPerElmt, std::max(warpSize, + const int threads = std::min(maxThreadPerElmt, std::max(warpSize, (dictSize+workPerThread-1)/workPerThread)); const dim3 blocks(maxGLen, maxFLen, batchSize); @@ -694,44 +694,44 @@ torch::Tensor transducer_loss_cuda_backward( constexpr int vecAlignment = std::alignment_of::value; // if all input and output tensors meet the alignment requirement bool memAlign = reinterpret_cast(x.data_ptr()) % vecAlignment == 0 - and reinterpret_cast(xGrad.data_ptr()) + and reinterpret_cast(xGrad.data_ptr()) % vecAlignment == 0; if (vectFactor > 1 and dictSize%vectFactor == 0 and memAlign){ transducer_loss_fused_vec_backward - <<>>( - x.data_ptr(), + <<>>( + x.data_ptr(), lossGrad.data_ptr(), - audLen.data_ptr(), - txtLen.data_ptr(), + audLen.data_ptr(), + txtLen.data_ptr(), label.data_ptr(), - alpha.data_ptr(), - beta.data_ptr(), + alpha.data_ptr(), + beta.data_ptr(), batchOffsetPtr, - dictSize, - blankIdx, + dictSize, + blankIdx, maxFLen, maxGLen, packedInput, - xGrad.data_ptr()); + xGrad.data_ptr()); } else{ - transducer_loss_fused_backward<<>>( - x.data_ptr(), + transducer_loss_fused_backward<<>>( + x.data_ptr(), lossGrad.data_ptr(), - audLen.data_ptr(), - txtLen.data_ptr(), + audLen.data_ptr(), + txtLen.data_ptr(), label.data_ptr(), - alpha.data_ptr(), - beta.data_ptr(), + alpha.data_ptr(), + beta.data_ptr(), batchOffsetPtr, - dictSize, - blankIdx, + dictSize, + blankIdx, maxFLen, maxGLen, packedInput, - xGrad.data_ptr()); - + xGrad.data_ptr()); + } })); } @@ -744,17 +744,17 @@ torch::Tensor transducer_loss_cuda_backward( const dim3 blocks(maxFLen, batchSize); AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_loss_cuda_backward", ([&] { using acc_t = at::acc_type; - transducer_loss_backward<<>>( - x.data_ptr(), + transducer_loss_backward<<>>( + x.data_ptr(), lossGrad.data_ptr(), - audLen.data_ptr(), - txtLen.data_ptr(), + audLen.data_ptr(), + txtLen.data_ptr(), label.data_ptr(), - alpha.data_ptr(), - beta.data_ptr(), - batchOffsetPtr, - dictSize, - blankIdx, + alpha.data_ptr(), + beta.data_ptr(), + batchOffsetPtr, + dictSize, + blankIdx, maxFLen, maxGLen, packedInput, @@ -762,6 +762,6 @@ torch::Tensor transducer_loss_cuda_backward( })); } C10_CUDA_CHECK(cudaGetLastError()); - + return xGrad; } diff --git a/apex/contrib/examples/multihead_attn/perf_test_multihead_attn.py b/apex/contrib/examples/multihead_attn/perf_test_multihead_attn.py index f81522ab3..e28d3e2d3 100644 --- a/apex/contrib/examples/multihead_attn/perf_test_multihead_attn.py +++ b/apex/contrib/examples/multihead_attn/perf_test_multihead_attn.py @@ -62,48 +62,48 @@ for sequences in range(args.num_seqs_start, args.num_seqs_stop + args.num_seqs_inc, args.num_seqs_inc) : inputs = torch.randn(args.seq_length, sequences, args.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) grads = torch.randn_like(inputs) - + for trial in range(0, args.trials + args.warmup_trials) : layer_inputs = inputs evt_idx = trial - args.warmup_trials - + if evt_idx >= 0 : start_evt_fwd[evt_idx].record() - + for lyr_idx in range(0, args.layers) : if args.native : - outputs,_ = attn_layers[lyr_idx].forward(layer_inputs, - layer_inputs, - layer_inputs, - key_padding_mask=None, - need_weights=False, + outputs,_ = attn_layers[lyr_idx].forward(layer_inputs, + layer_inputs, + layer_inputs, + key_padding_mask=None, + need_weights=False, attn_mask=None) else : - outputs,_ = attn_layers[lyr_idx].forward(layer_inputs, - layer_inputs, + outputs,_ = attn_layers[lyr_idx].forward(layer_inputs, layer_inputs, - key_padding_mask=None, - need_weights=False, + layer_inputs, + key_padding_mask=None, + need_weights=False, attn_mask=None, is_training=True) layer_inputs = outputs - + if evt_idx >= 0 : start_evt_bwd[evt_idx].record() if not args.fwd : layer_inputs.backward(grads) - + if evt_idx >= 0 : stop_evt_bwd[evt_idx].record() - + torch.cuda.synchronize() elapsed_time_fwd = 0.0 elapsed_time_bwd = 0.0 for evt_idx in range(0, args.trials) : elapsed_time_fwd += start_evt_fwd[evt_idx].elapsed_time(start_evt_bwd[evt_idx]) elapsed_time_bwd += start_evt_bwd[evt_idx].elapsed_time(stop_evt_bwd[evt_idx]) - + print("[ {} Attn {} ]Total Tokens: {:4d} Sequences: {:3d} Sequence Length: {:3d} Fwd Time / Layer: {:.3f} ms Bwd Time / Layer: {:.3f} ms".format( 'Encdec' if args.encdec_attn else 'Self', \ 'Norm&Add' if args.norm_add else '', \ diff --git a/apex/contrib/fmha/fmha.py b/apex/contrib/fmha/fmha.py index 6aaca804a..8eab54209 100644 --- a/apex/contrib/fmha/fmha.py +++ b/apex/contrib/fmha/fmha.py @@ -1,6 +1,6 @@ ############################################################################### # Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. -# +# # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ # * Neither the name of the NVIDIA CORPORATION nor the # names of its contributors may be used to endorse or promote products # derived from this software without specific prior written permission. -# +# # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -45,7 +45,7 @@ def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training, zero_tensors): ctx.max_s = max_s ctx.zero_tensors = zero_tensors return context - + @staticmethod def backward(ctx, dout): qkv, S_dmask = ctx.saved_tensors diff --git a/apex/contrib/groupbn/batch_norm.py b/apex/contrib/groupbn/batch_norm.py index af0b7e9b2..c3eca3735 100644 --- a/apex/contrib/groupbn/batch_norm.py +++ b/apex/contrib/groupbn/batch_norm.py @@ -179,10 +179,10 @@ def __init__(self, num_features, fuse_relu=False, bn_group=1, torch_channels_las #FIXME: turn pair handles into an array if bn_group>1: local_rank = torch.distributed.get_rank() - world_size = torch.distributed.get_world_size() + world_size = torch.distributed.get_world_size() assert(world_size >= bn_group) assert(world_size % bn_group == 0) - + bn_sync_steps = 1 if (bn_group==4): bn_sync_steps = 2 diff --git a/apex/contrib/index_mul_2d/index_mul_2d.py b/apex/contrib/index_mul_2d/index_mul_2d.py index 1d34fe20c..23df5f362 100644 --- a/apex/contrib/index_mul_2d/index_mul_2d.py +++ b/apex/contrib/index_mul_2d/index_mul_2d.py @@ -49,9 +49,9 @@ def forward(ctx, in1: torch.Tensor, in2: torch.Tensor, idx1: torch.Tensor) -> to @staticmethod def backward(ctx, grad_out): - + in1, in2, idx1 = ctx.for_backwards - + grad_in1, grad_in2 = index_mul_2d_backward(in1, in2, idx1, grad_out) return grad_in1, grad_in2, None @@ -93,8 +93,8 @@ def forward(ctx, in1: torch.Tensor, in2: torch.Tensor, idx1: torch.Tensor, grad_out, in1, in2, - idx1) - + idx1) + ctx.for_backwards = (in1, in2, idx1, grad_out) return grad_in1, grad_in2 @@ -104,7 +104,7 @@ def backward(ctx, grad_grad_in1, grad_grad_in2): grad_grad_in1 = grad_grad_in1.contiguous() if not grad_grad_in2.is_contiguous(): grad_grad_in2 = grad_grad_in2.contiguous() - + assert grad_grad_in1.is_contiguous() assert grad_grad_in2.is_contiguous() @@ -135,7 +135,7 @@ def backward(ctx, grad_grad_in1, grad_grad_in2): grad_grad_in2, in1, in2, - idx1) + idx1) return grad_in1, grad_in2, None, grad_grad_out diff --git a/apex/contrib/multihead_attn/README.md b/apex/contrib/multihead_attn/README.md index bf0f3a07f..f1cfbb454 100644 --- a/apex/contrib/multihead_attn/README.md +++ b/apex/contrib/multihead_attn/README.md @@ -1,4 +1,4 @@ -# Fast Multihead Attention +# Fast Multihead Attention This implementation has two main features : * A C++ implementation to avoid the CPU overheads of Pytorch found with smaller batch sizes. diff --git a/apex/contrib/optimizers/distributed_fused_adam.py b/apex/contrib/optimizers/distributed_fused_adam.py index 65da11218..be03406e2 100644 --- a/apex/contrib/optimizers/distributed_fused_adam.py +++ b/apex/contrib/optimizers/distributed_fused_adam.py @@ -579,10 +579,10 @@ def __init__( "multi_tensor_copy to set dummy_overflow_buf to indicate " "whether there's gradient Inf/NaN, build APEX with " "`--deprecated_fused_adam` is essential.") - + if capturable: raise Exception("Distributed fused adam does not support cudagraph on ROCm") - + # If capturable for CUDA graph self.capturable: bool = capturable # If the optimizer is capturable then LR should be a tensor (on GPU) diff --git a/apex/contrib/optimizers/fp16_optimizer.py b/apex/contrib/optimizers/fp16_optimizer.py index 0cbb63b82..d6d2ce45a 100755 --- a/apex/contrib/optimizers/fp16_optimizer.py +++ b/apex/contrib/optimizers/fp16_optimizer.py @@ -104,7 +104,7 @@ def step(self, closure=None): for i, p in enumerate(group): fp16_grad.append(p.grad) fp16_grads.append(fp16_grad) - + # nan check self.overflow_buf.zero_() for fp16_grad in fp16_grads: diff --git a/apex/contrib/optimizers/fused_sgd.py b/apex/contrib/optimizers/fused_sgd.py index 83587c6a6..951e8a9a9 100644 --- a/apex/contrib/optimizers/fused_sgd.py +++ b/apex/contrib/optimizers/fused_sgd.py @@ -12,8 +12,8 @@ class FusedSGD(Optimizer): * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches. :class:`apex.contrib.optimizers.FusedSGD` should be used without AMP. - - :class:`apex.contrib.optimizers.FusedSGD` only works in the case where all parameters require grad. + + :class:`apex.contrib.optimizers.FusedSGD` only works in the case where all parameters require grad. Nesterov momentum is based on the formula from `On the importance of initialization and momentum in deep learning`__. @@ -111,7 +111,7 @@ def get_momentums(self, params): first_run = False momentums.append(param_state['momentum_buffer']) return momentums, first_run - + def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None): """Performs a single optimization step. Arguments: @@ -157,13 +157,13 @@ def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norm else: output_params_group = output_params - for group, grads_this_group, output_params_this_group in zip(self.param_groups, - grads_group, + for group, grads_this_group, output_params_this_group in zip(self.param_groups, + grads_group, output_params_group): - if grads_this_group is None or output_params_this_group is None: + if grads_this_group is None or output_params_this_group is None: raise RuntimeError('apex.contrib.optimizers.FusedSGD only works \ when all parameters require grad.') - + weight_decay = group['weight_decay'] momentum = group['momentum'] dampening = group['dampening'] @@ -171,7 +171,7 @@ def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norm lr = group['lr'] first_runs = [True, True] - + # output_params_this_group: original weights (either fp16 or fp32) # group['params']: master weights (fp32) diff --git a/apex/contrib/peer_memory/peer_halo_exchanger_1d.py b/apex/contrib/peer_memory/peer_halo_exchanger_1d.py index cc25693ce..f15b2aee1 100644 --- a/apex/contrib/peer_memory/peer_halo_exchanger_1d.py +++ b/apex/contrib/peer_memory/peer_halo_exchanger_1d.py @@ -59,7 +59,7 @@ def __call__(self, y, H_split=True, explicit_nhwc=False, numSM=1, diagnostics=Fa high_inp_halo = y[:,:,:,W+self.half_halo:W+2*self.half_halo] pm.push_pull_halos_1d( diagnostics, explicit_nhwc, numSM, - self.low_zero, low_out_halo, low_tx[self.peer_rank], high_tx[self.low_neighbor], low_inp_halo, + self.low_zero, low_out_halo, low_tx[self.peer_rank], high_tx[self.low_neighbor], low_inp_halo, self.high_zero, high_out_halo, high_tx[self.peer_rank], low_tx[self.high_neighbor], high_inp_halo, self.signals[self.low_neighbor], self.signals[self.high_neighbor], self.signals[self.peer_rank] ) diff --git a/apex/contrib/sparsity/README.md b/apex/contrib/sparsity/README.md index 34681188d..19bde3f26 100644 --- a/apex/contrib/sparsity/README.md +++ b/apex/contrib/sparsity/README.md @@ -31,7 +31,7 @@ for epoch in range(epochs): torch.save(...) ``` -The `prune_trained_model` step calculates the sparse mask and applies it to the weights. This is done once, i.e., sparse locations in the weights matrix remain fixed after this step. +The `prune_trained_model` step calculates the sparse mask and applies it to the weights. This is done once, i.e., sparse locations in the weights matrix remain fixed after this step. ## Generate a Sparse Network @@ -51,7 +51,7 @@ criterion = ... # compare ground truth with model predition; use the same criter optimizer = ... # optimize model parameters; use the same optimizer as used to generate the dense trained model lr_scheduler = ... # learning rate scheduler; use the same schedule as used to generate the dense trained model -from apex.contrib.sparsity import ASP +from apex.contrib.sparsity import ASP ASP.prune_trained_model(model, optimizer) #pruned a trained model x, y = DataLoader(args) @@ -62,7 +62,7 @@ for epoch in range(epochs): # train the pruned model for the same number of epoc loss.backward() optimizer.step() -torch.save(...) # saves the pruned checkpoint with sparsity masks +torch.save(...) # saves the pruned checkpoint with sparsity masks ``` ## Non-Standard Usage @@ -73,7 +73,7 @@ If your goal is to easily perpare a network for accelerated inference, please fo ASP.compute_sparse_masks() ``` -A more thorough example can be found in `./test/toy_problem.py`. +A more thorough example can be found in `./test/toy_problem.py`. ## Advanced Usage: Channel Permutation diff --git a/apex/contrib/sparsity/asp.py b/apex/contrib/sparsity/asp.py index 924024f08..19affa592 100644 --- a/apex/contrib/sparsity/asp.py +++ b/apex/contrib/sparsity/asp.py @@ -39,13 +39,13 @@ class ASP: @classmethod def init_model_for_pruning(cls, model, mask_calculator="m4n2_1d", verbosity=3, - whitelist=[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d], + whitelist=[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d], allowed_layer_names=None, disallowed_layer_names=[], allow_recompute_mask=False, custom_layer_dict={}, allow_permutation=True): """Call this method to modify your model to take advantage of sparse matrix multiplication. Note that this call alone only augments the model with additional buffers needed for sparse MMA, - it does not enable use of sparse MMA. + it does not enable use of sparse MMA. If you are starting with a fresh model: @@ -76,7 +76,7 @@ def init_model_for_pruning(cls, model, mask_calculator="m4n2_1d", Pruned weights are stored in CPU memory, hence this option does not increase GPU memory usage. custom_layer_dict Dictionary of additional layer paremeters to sparsify. e.g. {CustomLinear: ['weight']} allow_permutation If True, allow the input channel permutation to ease the influence of weight pruning. - + [Future] Support for allow_recompute_mask can be removed, it is not part of sparse inference recipe. """ assert (cls.__model is None), "ASP has been initialized already." @@ -91,7 +91,7 @@ def create_mask_from_pattern(param): else: cls.__calculate_mask = mask_calculator #user defined function - # function to extract variables that will be sparsified. + # function to extract variables that will be sparsified. # idea is that you will add one of these functions for each module type that can be sparsified. if torchvision_imported: print("[ASP] torchvision is imported, can work with the MaskRCNN/KeypointRCNN from torchvision.") @@ -139,10 +139,10 @@ def add_sparse_attributes(module_name, module): if p.dtype == torch.float16 and ((p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0): #For Conv2d dim= K x CRS; we prune along C print("[ASP] Auto skipping pruning %s::%s of size=%s and type=%s for sparsity" % (module_name, p_name, str(p.size()), str(p.dtype))) continue - + if cls.__verbosity >= 3: print("[ASP] Sparsifying %s::%s of size=%s and type=%s for sparsity" % (module_name, p_name, str(p.size()), str(p.dtype))) - + mask = torch.ones_like(p).bool() buffname = p_name.split(".")[-1] # buffer names cannot contain "." module.register_buffer('__%s_mma_mask' % buffname, mask) @@ -288,7 +288,7 @@ def is_sparsity_enabled(cls): return False elif total == sp50: return True - + @classmethod def prune_trained_model(cls, model, optimizer): # add mask buffers to model (init_model_for_pruning), augment optimizer (init_optimizer_for_pruning) and compute masks (compute_sparse_masks) diff --git a/apex/contrib/sparsity/permutation_search_kernels/CUDA_kernels/permutation_search_kernels.cu b/apex/contrib/sparsity/permutation_search_kernels/CUDA_kernels/permutation_search_kernels.cu index c7b053cc5..df6107a96 100644 --- a/apex/contrib/sparsity/permutation_search_kernels/CUDA_kernels/permutation_search_kernels.cu +++ b/apex/contrib/sparsity/permutation_search_kernels/CUDA_kernels/permutation_search_kernels.cu @@ -6,7 +6,7 @@ namespace py = pybind11; #define gpuErrchk(ans) { gpuAssert((ans), __FILE__, __LINE__); } inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true) { - if (code != cudaSuccess) + if (code != cudaSuccess) { fprintf(stderr,"GPUassert %d: %s %s %d\n", (int)code, cudaGetErrorString(code), file, line); if (abort) exit(code); @@ -336,7 +336,7 @@ int run_build_permute_map(py::array_t& py_matrix, gpuErrchk(cudaMemcpy( d_permutations, permutations, num_permutations*perm_length*sizeof(unsigned int), cudaMemcpyHostToDevice )); unsigned int group_offset = 0; - for (unsigned int l = 0; l < launches; ++l) + for (unsigned int l = 0; l < launches; ++l) { unsigned int groups_this_launch = (l < full_launches) ? MAX_GROUPS_PER_LAUNCH : final_launch; diff --git a/apex/contrib/sparsity/permutation_search_kernels/permutation_utilities.py b/apex/contrib/sparsity/permutation_search_kernels/permutation_utilities.py index 55a18e186..dce919bf4 100644 --- a/apex/contrib/sparsity/permutation_search_kernels/permutation_utilities.py +++ b/apex/contrib/sparsity/permutation_search_kernels/permutation_utilities.py @@ -15,7 +15,7 @@ print(f"Could not find permutation search CUDA kernels, falling back to CPU path") kernels_found = False -def use_gpu(initial_override = True): +def use_gpu(initial_override = True): global gpus_tested, gpus_found, kernels_found if not gpus_tested: if not initial_override: @@ -30,7 +30,7 @@ def use_gpu(initial_override = True): print(f"Could not find nvidia-smi, please check your cuda installation") gpus_tested = True - + return gpus_found > 0 and kernels_found ############################################################################################## @@ -83,7 +83,7 @@ def sum_after_2_to_4(matrix): def try_swap(matrix, dst, src): src_base = sum_after_2_to_4(matrix[...,int(src/4)*4:int(src/4)*4+4]) dst_base = sum_after_2_to_4(matrix[...,int(dst/4)*4:int(dst/4)*4+4]) - + # swap matrix[...,[src,dst]] = matrix[...,[dst,src]] @@ -93,7 +93,7 @@ def try_swap(matrix, dst, src): # swap back matrix[...,[src,dst]] = matrix[...,[dst,src]] - + return src_sum + dst_sum, (src_sum + dst_sum) - (src_base + dst_base) ############################################################################################## diff --git a/apex/contrib/sparsity/sparse_masklib.py b/apex/contrib/sparsity/sparse_masklib.py index 48deb633c..26bd193a2 100644 --- a/apex/contrib/sparsity/sparse_masklib.py +++ b/apex/contrib/sparsity/sparse_masklib.py @@ -53,13 +53,13 @@ def m4n2_1d(mat, density): Below 2d-masking related code is targeted more for training (from scratch). 2d-pruning of a weight tensor is done to accelerate DGRAD step during backprop phase of training algorithm. Acceleration comes from using SpMMA instructions in - Tensor Cores of NVIDIA Ampere GPU Architecture + Tensor Cores of NVIDIA Ampere GPU Architecture (note: this code does not do the acceleration, GPU kernels are required for this). 1d pruning of weight tensor helps speed up FPROP step by pruning in 2:4 pattern along the horizontal (logical) direction. During DGRAD step, weight tensor is transposed. 2d pruning functions below, mask weight tensor such that their transposed versions are also 2:4 sparse along the - horizontal (logical) direction. Thus, with 2d pruning, weight tensors are + horizontal (logical) direction. Thus, with 2d pruning, weight tensors are 2:4 sparse along row and column directions. """ @@ -179,6 +179,6 @@ def create_mask(tensor, pattern="m4n2_1d", density=0.5): t = t.permute(2,3,0,1).contiguous().view(shape[2]*shape[3]*shape[0], shape[1]) func = getattr(sys.modules[__name__], pattern, None) mask = func(t, density) - mask = mask.view(shape[2], shape[3], shape[0], shape[1]).permute(2,3,0,1).contiguous() + mask = mask.view(shape[2], shape[3], shape[0], shape[1]).permute(2,3,0,1).contiguous() return mask.view(shape).type(ttype) diff --git a/apex/contrib/test/conv_bias_relu/test_conv_bias_relu.py b/apex/contrib/test/conv_bias_relu/test_conv_bias_relu.py index f2a4492d2..f71f53112 100644 --- a/apex/contrib/test/conv_bias_relu/test_conv_bias_relu.py +++ b/apex/contrib/test/conv_bias_relu/test_conv_bias_relu.py @@ -19,7 +19,7 @@ class FusedDenseTest(unittest.TestCase): def setUp(self, seed=0): torch.manual_seed(seed) - + self.batch_size = random.randint(1, 64) self.in_channels = random.randint(1, 64) * 8 self.out_channels = random.randint(1, 64) * 8 @@ -106,26 +106,26 @@ def test_conv_bias_retinanet(self): in_channels = 256 out_channels = 2376 h, w = 100, 100 - + # Input in NHWC format with HALF precision x = torch.randn(batch_size, in_channels, h, w).cuda()\ .to(memory_format=torch.channels_last).half() x_ = x.clone() x.requires_grad_() x_.requires_grad_() - + # Conv layer - conv = torch.nn.Conv2d(in_channels, out_channels, 3, + conv = torch.nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1).cuda()\ .to(memory_format=torch.channels_last) conv_ = copy.deepcopy(conv) - + # Test with FP16 with torch.amp.autocast(device_type="cuda", dtype=torch.half): out = ConvBias(x, conv.weight, conv.bias.reshape(1, -1, 1, 1), 1, 1) loss = (out.float()**2).sum() / out.numel() loss.backward() - + # Reference with FP16 with torch.amp.autocast(device_type="cuda", dtype=torch.half): out_ = conv_(x_) diff --git a/apex/contrib/test/fmha/test_fmha.py b/apex/contrib/test/fmha/test_fmha.py index 00970eed2..d9d03d7de 100644 --- a/apex/contrib/test/fmha/test_fmha.py +++ b/apex/contrib/test/fmha/test_fmha.py @@ -1,6 +1,6 @@ ############################################################################### # Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. -# +# # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # * Redistributions of source code must retain the above copyright @@ -11,7 +11,7 @@ # * Neither the name of the NVIDIA CORPORATION nor the # names of its contributors may be used to endorse or promote products # derived from this software without specific prior written permission. -# +# # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -64,44 +64,44 @@ def run_test(self, s: int, b: int, zero_tensors: bool): torch.manual_seed(1234) torch.cuda.manual_seed(1234) - + dtype = torch.float16 device = torch.device('cuda') - h = 16 + h = 16 d = 64 - - slens = [s] * b + + slens = [s] * b a = torch.tensor(np.array([0] + slens), dtype=torch.int32) amask = torch.ones(b,h,s,s, dtype=dtype, device=device) seqlens = torch.tensor(slens, dtype=torch.int32, device=device) cu_seqlens = torch.cumsum(a, 0).to(dtype=torch.int32, device=device) total = cu_seqlens[-1].item() - + qkv = torch.randn((b,s,h,3,d), device=device, dtype=dtype) - + qkv_vs = qkv.permute(0,1,3,2,4).contiguous().view(b*s, 3, h,d) - + qkv.requires_grad = True - + if b < 4: ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, True, zero_tensors, None) else: ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, False, zero_tensors, None) ctx = ctx.view(b,s,h,d) - + ctx_ref = py_mha(qkv, amask, b,s,h,d) self.assertTrue(torch.allclose(ctx_ref.float(), ctx.float(), atol=1e-3)) - + labels = torch.randn_like(ctx_ref) diff = ctx_ref - labels l = (diff * diff).sum() / b l.backward() - - dw = ctx_ref.grad.permute(0,2,1,3) - + + dw = ctx_ref.grad.permute(0,2,1,3) + dw2 = dw.permute(0,2,1,3).clone().detach().contiguous() - + if b < 4: dqkv2, _, _ = mha.bwd_nl(dw2, qkv_vs, S_, cu_seqlens, 0.0, s, zero_tensors) else: diff --git a/apex/contrib/test/groupbn/test_groupbn_channel_last.py b/apex/contrib/test/groupbn/test_groupbn_channel_last.py index 5ae36e33a..f1c317268 100644 --- a/apex/contrib/test/groupbn/test_groupbn_channel_last.py +++ b/apex/contrib/test/groupbn/test_groupbn_channel_last.py @@ -35,13 +35,13 @@ def bn_nhwc_bwd_ref(grad_y, x, mu, ivar, gamma): grad_y = grad_y.permute(0, 2, 3, 1).contiguous() x = x.permute(0, 2, 3, 1).contiguous() sum_dim_c = (0, 1, 2) - grad_y_f32 = grad_y.float() + grad_y_f32 = grad_y.float() x_f32 = x.float() N = x.shape[0] * x.shape[1] * x.shape[2] # nhw ones = torch.ones(x.shape, dtype=torch.float32, device='cuda') xmu = x_f32 - mu - + xhat = xmu * ivar dbias = torch.sum(grad_y_f32, dim=sum_dim_c) diff --git a/apex/contrib/test/index_mul_2d/test_index_mul_2d.py b/apex/contrib/test/index_mul_2d/test_index_mul_2d.py index d8f37ea3c..e3c4e2bdb 100644 --- a/apex/contrib/test/index_mul_2d/test_index_mul_2d.py +++ b/apex/contrib/test/index_mul_2d/test_index_mul_2d.py @@ -17,7 +17,7 @@ class IndexMul2dTest(unittest.TestCase): def setUp(self, seed=0): torch.manual_seed(seed) - + self.input1_size = random.randint(1, 1000) self.input2_size = random.randint(1, 100000) self.feature_size = random.randint(1, 256) @@ -52,7 +52,7 @@ def test_index_mul_float(self): energy, self.input1_float, grad_outputs=torch.ones_like(energy), - create_graph=True, + create_graph=True, )[0] loss = (out.float()**2).sum() / out.numel() + (force.float()**2).sum() loss.backward() @@ -63,7 +63,7 @@ def test_index_mul_float(self): energy_, self.input1_float_, grad_outputs=torch.ones_like(energy), - create_graph=True, + create_graph=True, )[0] loss = (out_.float()**2).sum() / out_.numel() + (force_.float()**2).sum() loss.backward() @@ -80,7 +80,7 @@ def test_index_mul_half(self): energy, self.input1_half, grad_outputs=torch.ones_like(energy), - create_graph=True, + create_graph=True, )[0] loss = (out.float()**2).sum() / out.numel() + (force.float()**2).sum() loss.backward() @@ -91,11 +91,11 @@ def test_index_mul_half(self): energy_, self.input1_half_, grad_outputs=torch.ones_like(energy), - create_graph=True, + create_graph=True, )[0] loss = (out_.float()**2).sum() / out_.numel() + (force_.float()**2).sum() loss.backward() - + self.assertTrue(torch.allclose(self.input1_half, self.input1_half_, atol=1e-3, rtol=1e-3, equal_nan=True)) self.assertTrue(torch.allclose(self.input2_half, self.input2_half_, atol=1e-3, rtol=1e-3, equal_nan=True)) self.assertTrue(torch.allclose(self.input1_half.grad, self.input1_half_.grad, atol=1e-3, rtol=1e-3, equal_nan=True)) diff --git a/apex/contrib/test/multihead_attn/test_encdec_multihead_attn.py b/apex/contrib/test/multihead_attn/test_encdec_multihead_attn.py index 836fe8433..4d64d351b 100644 --- a/apex/contrib/test/multihead_attn/test_encdec_multihead_attn.py +++ b/apex/contrib/test/multihead_attn/test_encdec_multihead_attn.py @@ -15,35 +15,35 @@ def setUp(self, seed=1234): self.heads = 16 self.dropout_prob = 0.0 - self.ref_layer = EncdecMultiheadAttn(self.hidden_dim, - self.heads, - dropout=self.dropout_prob, - bias=False, - include_norm_add=False, + self.ref_layer = EncdecMultiheadAttn(self.hidden_dim, + self.heads, + dropout=self.dropout_prob, + bias=False, + include_norm_add=False, impl='default') self.ref_layer.cuda().half() self.ref_layer.reset_parameters() - self.ref_inputs_q = torch.randn(self.seq_length, self.sequences, self.hidden_dim, + self.ref_inputs_q = torch.randn(self.seq_length, self.sequences, self.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) - self.ref_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim, + self.ref_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) # Reset seed so parameters are identical torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) - - self.tst_layer = EncdecMultiheadAttn(self.hidden_dim, - self.heads, - dropout=self.dropout_prob, - bias=False, - include_norm_add=False, + + self.tst_layer = EncdecMultiheadAttn(self.hidden_dim, + self.heads, + dropout=self.dropout_prob, + bias=False, + include_norm_add=False, impl='fast') self.tst_layer.cuda().half() self.tst_layer.reset_parameters() - self.tst_inputs_q = torch.randn(self.seq_length, self.sequences, self.hidden_dim, + self.tst_inputs_q = torch.randn(self.seq_length, self.sequences, self.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) - self.tst_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim, + self.tst_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) def test_encdec_multihead_attn(self) : @@ -72,28 +72,28 @@ def test_encdec_multihead_attn(self) : self.assertTrue(torch.allclose(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5)) self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3)) self.assertTrue(torch.allclose(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3)) - + def test_encdec_multihead_attn_time_mask(self) : grads = torch.randn_like(self.tst_inputs_q) time_mask_byte = torch.triu(torch.ones(self.tst_inputs_q.size(0), self.tst_inputs_k.size(0), device=torch.device("cuda"), dtype=torch.uint8), 1) time_mask_bool = time_mask_byte.to(torch.bool) - - ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q, - self.ref_inputs_k, + + ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q, + self.ref_inputs_k, self.ref_inputs_k, - key_padding_mask=None, - need_weights=False, + key_padding_mask=None, + need_weights=False, attn_mask=time_mask_bool, is_training=True) - tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q, - self.tst_inputs_k, + tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q, + self.tst_inputs_k, self.tst_inputs_k, - key_padding_mask=None, - need_weights=False, + key_padding_mask=None, + need_weights=False, attn_mask=time_mask_byte, is_training=True) - + self.ref_inputs_q.backward(grads) self.tst_inputs_q.backward(grads) @@ -101,28 +101,28 @@ def test_encdec_multihead_attn_time_mask(self) : self.assertTrue(torch.allclose(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5)) self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3)) self.assertTrue(torch.allclose(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3)) - + def test_encdec_multihead_attn_pad_mask(self) : grads = torch.randn_like(self.tst_inputs_q) pad_mask_byte = torch.tril(torch.ones(self.tst_inputs_k.size(1), self.tst_inputs_k.size(0), device=torch.device("cuda"), dtype=torch.uint8), 1) pad_mask_bool = pad_mask_byte.to(torch.bool) - - ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q, - self.ref_inputs_k, + + ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q, + self.ref_inputs_k, self.ref_inputs_k, - key_padding_mask=pad_mask_bool, - need_weights=False, + key_padding_mask=pad_mask_bool, + need_weights=False, attn_mask=None, is_training=True) - tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q, - self.tst_inputs_k, + tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q, + self.tst_inputs_k, self.tst_inputs_k, - key_padding_mask=pad_mask_byte, - need_weights=False, + key_padding_mask=pad_mask_byte, + need_weights=False, attn_mask=None, is_training=True) - + self.ref_inputs_q.backward(grads) self.tst_inputs_q.backward(grads) diff --git a/apex/contrib/test/multihead_attn/test_encdec_multihead_attn_norm_add.py b/apex/contrib/test/multihead_attn/test_encdec_multihead_attn_norm_add.py index 2ab3009c2..9fd7460c4 100644 --- a/apex/contrib/test/multihead_attn/test_encdec_multihead_attn_norm_add.py +++ b/apex/contrib/test/multihead_attn/test_encdec_multihead_attn_norm_add.py @@ -15,57 +15,57 @@ def setUp(self, seed=1234): self.heads = 16 self.dropout_prob = 0.0 - self.ref_layer = EncdecMultiheadAttn(self.hidden_dim, - self.heads, - dropout=self.dropout_prob, - bias=False, - include_norm_add=True, + self.ref_layer = EncdecMultiheadAttn(self.hidden_dim, + self.heads, + dropout=self.dropout_prob, + bias=False, + include_norm_add=True, impl='default') self.ref_layer.cuda().half() self.ref_layer.reset_parameters() - self.ref_inputs_q = torch.randn(self.seq_length, self.sequences, self.hidden_dim, + self.ref_inputs_q = torch.randn(self.seq_length, self.sequences, self.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) - self.ref_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim, + self.ref_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) # Reset seed so parameters are identical torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) - - self.tst_layer = EncdecMultiheadAttn(self.hidden_dim, - self.heads, - dropout=self.dropout_prob, - bias=False, - include_norm_add=True, + + self.tst_layer = EncdecMultiheadAttn(self.hidden_dim, + self.heads, + dropout=self.dropout_prob, + bias=False, + include_norm_add=True, impl='fast') self.tst_layer.cuda().half() self.tst_layer.reset_parameters() - - self.tst_inputs_q = torch.randn(self.seq_length, self.sequences, self.hidden_dim, + + self.tst_inputs_q = torch.randn(self.seq_length, self.sequences, self.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) - self.tst_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim, + self.tst_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) def test_encdec_multihead_attn_norm_add(self) : grads = torch.randn_like(self.tst_inputs_q) - + for _ in range(5) : - ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q, - self.ref_inputs_k, + ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q, + self.ref_inputs_k, self.ref_inputs_k, - key_padding_mask=None, - need_weights=False, + key_padding_mask=None, + need_weights=False, attn_mask=None, is_training=True) - - tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q, - self.tst_inputs_k, + + tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q, + self.tst_inputs_k, self.tst_inputs_k, - key_padding_mask=None, - need_weights=False, + key_padding_mask=None, + need_weights=False, attn_mask=None, is_training=True) - + self.ref_inputs_q.backward(grads) self.tst_inputs_q.backward(grads) diff --git a/apex/contrib/test/multihead_attn/test_fast_self_multihead_attn_bias.py b/apex/contrib/test/multihead_attn/test_fast_self_multihead_attn_bias.py index b4bbf342b..b4a69711d 100644 --- a/apex/contrib/test/multihead_attn/test_fast_self_multihead_attn_bias.py +++ b/apex/contrib/test/multihead_attn/test_fast_self_multihead_attn_bias.py @@ -15,57 +15,57 @@ def setUp(self, seed=1234): self.heads = 16 self.dropout_prob = 0.0 - self.ref_layer = SelfMultiheadAttn(self.hidden_dim, - self.heads, - dropout=self.dropout_prob, - bias=True, - include_norm_add=False, - separate_qkv_params=True, - mask_additive=True, + self.ref_layer = SelfMultiheadAttn(self.hidden_dim, + self.heads, + dropout=self.dropout_prob, + bias=True, + include_norm_add=False, + separate_qkv_params=True, + mask_additive=True, impl='default') self.ref_layer.cuda().half() self.ref_layer.reset_parameters() - self.ref_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim, + self.ref_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) # Reset seed so parameters are identical torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) - - self.tst_layer = SelfMultiheadAttn(self.hidden_dim, - self.heads, - dropout=self.dropout_prob, - bias=True, - include_norm_add=False, - separate_qkv_params=True, - mask_additive=True, + + self.tst_layer = SelfMultiheadAttn(self.hidden_dim, + self.heads, + dropout=self.dropout_prob, + bias=True, + include_norm_add=False, + separate_qkv_params=True, + mask_additive=True, impl='fast') self.tst_layer.cuda().half() self.tst_layer.reset_parameters() - - self.tst_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim, + + self.tst_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) - + def test_self_multihead_attn_additive_mask(self) : grads = torch.randn_like(self.tst_inputs) mask = ((torch.randn(self.sequences, self.seq_length) > 0) * -10000.0).half().cuda() - ref_outputs,_ = self.ref_layer.forward(self.ref_inputs, - self.ref_inputs, + ref_outputs,_ = self.ref_layer.forward(self.ref_inputs, self.ref_inputs, - key_padding_mask=mask, - need_weights=False, + self.ref_inputs, + key_padding_mask=mask, + need_weights=False, attn_mask=None, is_training=True) - tst_outputs,_ = self.tst_layer.forward(self.tst_inputs, - self.tst_inputs, + tst_outputs,_ = self.tst_layer.forward(self.tst_inputs, + self.tst_inputs, self.tst_inputs, - key_padding_mask=mask, - need_weights=False, + key_padding_mask=mask, + need_weights=False, attn_mask=None, is_training=True) - + self.ref_inputs.backward(grads) self.tst_inputs.backward(grads) diff --git a/apex/contrib/test/multihead_attn/test_mha_fused_softmax.py b/apex/contrib/test/multihead_attn/test_mha_fused_softmax.py index 60d954184..5a064b942 100644 --- a/apex/contrib/test/multihead_attn/test_mha_fused_softmax.py +++ b/apex/contrib/test/multihead_attn/test_mha_fused_softmax.py @@ -16,20 +16,20 @@ def setUp(self, seed=1234): self.mask = (torch.randn(self.sequences,self.seq_length)>0).cuda() self.mask = self.mask.half()*-10000 - self.ref_inputs = torch.randn(self.heads * self.sequences, self.seq_length, self.seq_length, + self.ref_inputs = torch.randn(self.heads * self.sequences, self.seq_length, self.seq_length, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) - + self.tst_inputs = self.ref_inputs.clone().detach().requires_grad_(True) def test_fused_softmax(self) : grads = torch.randn_like(self.tst_inputs) y_ref = self.ref_inputs.view(self.sequences, self.heads, self.seq_length, self.seq_length) y_ref = y_ref + self.mask.unsqueeze(1).unsqueeze(2) - y_ref = y_ref.view(self.sequences*self.heads, self.seq_length, self.seq_length) + y_ref = y_ref.view(self.sequences*self.heads, self.seq_length, self.seq_length) y_ref = F.softmax(y_ref, dim=-1) - y_ref = torch._fused_dropout(y_ref, 1.0) - - y_tst = fast_mask_softmax_dropout_func(True, self.heads, self.tst_inputs, self.mask, True, 0.0) + y_ref = torch._fused_dropout(y_ref, 1.0) + + y_tst = fast_mask_softmax_dropout_func(True, self.heads, self.tst_inputs, self.mask, True, 0.0) y_ref[0].backward(grads) y_tst.backward(grads) diff --git a/apex/contrib/test/multihead_attn/test_self_multihead_attn.py b/apex/contrib/test/multihead_attn/test_self_multihead_attn.py index 10d779feb..8bd1a8e04 100644 --- a/apex/contrib/test/multihead_attn/test_self_multihead_attn.py +++ b/apex/contrib/test/multihead_attn/test_self_multihead_attn.py @@ -96,7 +96,7 @@ def test_self_multihead_attn_time_mask(self) : self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5)) self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3)) self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3)) - + def test_self_multihead_attn_pad_mask(self) : grads = torch.randn_like(self.tst_inputs) pad_mask_byte = torch.tril(torch.ones(self.tst_inputs.size(1), self.tst_inputs.size(0), device=torch.device("cuda"), dtype=torch.uint8), 1) diff --git a/apex/contrib/test/multihead_attn/test_self_multihead_attn_norm_add.py b/apex/contrib/test/multihead_attn/test_self_multihead_attn_norm_add.py index 125656fc6..f0fade261 100644 --- a/apex/contrib/test/multihead_attn/test_self_multihead_attn_norm_add.py +++ b/apex/contrib/test/multihead_attn/test_self_multihead_attn_norm_add.py @@ -15,53 +15,53 @@ def setUp(self, seed=1234): self.heads = 16 self.dropout_prob = 0.0 - self.ref_layer = SelfMultiheadAttn(self.hidden_dim, - self.heads, - dropout=self.dropout_prob, - bias=False, - include_norm_add=True, + self.ref_layer = SelfMultiheadAttn(self.hidden_dim, + self.heads, + dropout=self.dropout_prob, + bias=False, + include_norm_add=True, impl='default') self.ref_layer.cuda().half() self.ref_layer.reset_parameters() - self.ref_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim, + self.ref_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) # Reset seed so parameters are identical torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) - - self.tst_layer = SelfMultiheadAttn(self.hidden_dim, - self.heads, - dropout=self.dropout_prob, - bias=False, - include_norm_add=True, + + self.tst_layer = SelfMultiheadAttn(self.hidden_dim, + self.heads, + dropout=self.dropout_prob, + bias=False, + include_norm_add=True, impl='fast') self.tst_layer.cuda().half() self.tst_layer.reset_parameters() - - self.tst_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim, + + self.tst_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) def test_self_multihead_attn_norm_add(self) : grads = torch.randn_like(self.tst_inputs) for _ in range(0, 5) : - ref_outputs,_ = self.ref_layer.forward(self.ref_inputs, - self.ref_inputs, + ref_outputs,_ = self.ref_layer.forward(self.ref_inputs, + self.ref_inputs, self.ref_inputs, - key_padding_mask=None, - need_weights=False, + key_padding_mask=None, + need_weights=False, attn_mask=None, is_training=True) - - tst_outputs,_ = self.tst_layer.forward(self.tst_inputs, - self.tst_inputs, + + tst_outputs,_ = self.tst_layer.forward(self.tst_inputs, + self.tst_inputs, self.tst_inputs, - key_padding_mask=None, - need_weights=False, + key_padding_mask=None, + need_weights=False, attn_mask=None, is_training=True) - + self.ref_inputs.backward(grads) self.tst_inputs.backward(grads) diff --git a/apex/contrib/test/optimizers/test_distributed_fused_lamb.py b/apex/contrib/test/optimizers/test_distributed_fused_lamb.py index d8f56117a..688ec2bef 100644 --- a/apex/contrib/test/optimizers/test_distributed_fused_lamb.py +++ b/apex/contrib/test/optimizers/test_distributed_fused_lamb.py @@ -26,7 +26,7 @@ def forward(self, input_tensor, gt): return loss # A test for distributed fused Lamb optimizer: run several iterations and see if loss decreases -# There are two instances of the same test because based on `world_size` the optimizer decides what collectives operation to use. +# There are two instances of the same test because based on `world_size` the optimizer decides what collectives operation to use. # If torch.distributed.get_world_size() == torch.cuda.device_count() it uses only `all_gather`. # If torch.distributed.get_world_size() < torch.cuda.device_count() it uses both `all_gather` and `reduce_scatter`. class NcclDistributedFusedLAMB(NcclDistributedTestBase): @@ -69,10 +69,10 @@ def test_distributed_fused_lamb(self, no_copy, opt_kwargs): if 'full_ar' not in opt_kwargs: opt_kwargs['full_ar'] = gpu_count == torch.cuda.device_count() - # Aidyn-A: not sure what parameters are the best for testing purposes, - # setting up whatever I think appropriate. + # Aidyn-A: not sure what parameters are the best for testing purposes, + # setting up whatever I think appropriate. optimizer = DistributedFusedLAMB( - optimizer_grouped_parameters, + optimizer_grouped_parameters, lr=0.1, betas=(0.9, 0.9), eps=1e-6, diff --git a/apex/contrib/test/test_label_smoothing.py b/apex/contrib/test/test_label_smoothing.py index 70e9f3d09..b978b916f 100644 --- a/apex/contrib/test/test_label_smoothing.py +++ b/apex/contrib/test/test_label_smoothing.py @@ -64,13 +64,13 @@ def test_label_smoothing_function(self): for i in range(iters): logits, labels, half_to_float = self.gen_test_inputs( N, T, H, smoothing, padding_idx) - + # Run original softmax cross entropy with label smoothing logits.grad = None losses = label_smoothing_raw(logits, labels, padding_idx, smoothing) loss = losses.sum() loss.backward() - + ref_loss = loss.clone().detach() ref_grad = logits.grad.clone().detach() @@ -98,7 +98,7 @@ def test_label_smoothing_perf(self): logits, labels, half_to_float = self.gen_test_inputs( N, T, H, smoothing, padding_idx) - + # Run original softmax cross entropy with label smoothing torch.cuda.synchronize() ts = time.time() @@ -110,7 +110,7 @@ def test_label_smoothing_perf(self): torch.cuda.synchronize() print("Raw time {:.2f} s elapsed for {} iterations, norm {:.4f}".format( time.time() - ts, iters, logits.grad.norm())) - + # Run optimized softmax cross entropy with label smoothing torch.cuda.synchronize() ts = time.time() diff --git a/apex/contrib/test/transducer/test_transducer_joint.py b/apex/contrib/test/transducer/test_transducer_joint.py index 3a19482db..5761d4866 100755 --- a/apex/contrib/test/transducer/test_transducer_joint.py +++ b/apex/contrib/test/transducer/test_transducer_joint.py @@ -24,19 +24,19 @@ def gen_input(self, for_vector_kernel): self.f_tst = torch.randn((self.B, T_max, H), dtype=dtype, requires_grad=True, device=device) self.g_tst = torch.randn((self.B, U_max, H), dtype=dtype, requires_grad=True, device=device) self.h_grad = torch.randn(self.B, T_max, U_max, H, dtype=dtype, device=device) - self.f_len = torch.randint(T_min, T_max+1, (self.B,), dtype=torch.int, device=device) + self.f_len = torch.randint(T_min, T_max+1, (self.B,), dtype=torch.int, device=device) self.g_len = torch.randint(U_min, U_max+1, (self.B,), dtype=torch.int, device=device) self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max self.g_len[torch.randint(0, self.B, (1,)).item()] = U_max self.dropout_prob = 0.5 - # Make sure gradients from out-of-bound locations are zero. This should be guaranteed by + # Make sure gradients from out-of-bound locations are zero. This should be guaranteed by # the loss function for b in range(self.B): self.h_grad[b, self.f_len[b]:, :, :] = 0 self.h_grad[b, :, self.g_len[b]:, :] = 0 self.h_grad_packed = self._pack(self.h_grad, self.f_len, self.g_len) - + def _pack(self, x, f_len, g_len): B = x.size(0) @@ -60,10 +60,10 @@ def _unpack(self, x, f_len, g_len): my_f_len = f_len[b] my_g_len = g_len[b] for t in range(my_f_len): - x_unpacked[b, t, :my_g_len] = x[my_batch_offset + t*my_g_len : + x_unpacked[b, t, :my_g_len] = x[my_batch_offset + t*my_g_len : my_batch_offset + t*my_g_len + my_g_len] return x_unpacked - + def run_transducer_joint(self, for_vector_kernel, pack_output, relu, dropout): self.gen_input(for_vector_kernel=for_vector_kernel) # Generate reference @@ -71,24 +71,24 @@ def run_transducer_joint(self, for_vector_kernel, pack_output, relu, dropout): g_ref = self.g_tst.data.clone() f_ref.requires_grad = True g_ref.requires_grad = True - - my_joint = TransducerJoint(pack_output=pack_output, relu=relu, dropout=dropout, + + my_joint = TransducerJoint(pack_output=pack_output, relu=relu, dropout=dropout, dropout_prob=self.dropout_prob, probe_mask=True) if not pack_output: - h_tst = my_joint( f=self.f_tst, - g=self.g_tst, - f_len=self.f_len, + h_tst = my_joint( f=self.f_tst, + g=self.g_tst, + f_len=self.f_len, g_len=self.g_len) h_tst.backward(self.h_grad) if dropout: mask = my_joint.mask_probe[0] else: batch_offset = torch.cumsum(self.f_len * self.g_len, dim=0) - h_tst = my_joint( f=self.f_tst, - g=self.g_tst, - f_len=self.f_len, - g_len=self.g_len, - batch_offset=batch_offset, + h_tst = my_joint( f=self.f_tst, + g=self.g_tst, + f_len=self.f_len, + g_len=self.g_len, + batch_offset=batch_offset, packed_batch=batch_offset[-1]) h_tst.backward(self.h_grad_packed) if dropout: @@ -97,20 +97,20 @@ def run_transducer_joint(self, for_vector_kernel, pack_output, relu, dropout): # reference h_ref, f_grad_ref, g_grad_ref \ - = transducer_ref.transducer_joint_reference(f=f_ref, - g=g_ref, - h_grad=self.h_grad, - f_len=self.f_len, - g_len=self.g_len, + = transducer_ref.transducer_joint_reference(f=f_ref, + g=g_ref, + h_grad=self.h_grad, + f_len=self.f_len, + g_len=self.g_len, pack_output=pack_output, relu=relu, dropout=dropout, dropout_prob=self.dropout_prob, mask=mask if dropout else None) - + f_grad_tst = self.f_tst.grad g_grad_tst = self.g_tst.grad - + self.assertTrue(torch.allclose(h_ref, h_tst, atol=1e-5, rtol=1e-5)) self.assertTrue(torch.allclose(f_grad_ref, f_grad_tst, atol=1e-5, rtol=1e-5)) self.assertTrue(torch.allclose(g_grad_ref, g_grad_tst, atol=1e-4, rtol=1e-4)) diff --git a/apex/contrib/test/transducer/test_transducer_loss.py b/apex/contrib/test/transducer/test_transducer_loss.py index 82f5bd330..879ade2a8 100755 --- a/apex/contrib/test/transducer/test_transducer_loss.py +++ b/apex/contrib/test/transducer/test_transducer_loss.py @@ -18,10 +18,10 @@ def gen_input(self, scalar_t, for_vector_kernel): self.blank_idx = V - 1 device = "cuda" - self.x_tst = torch.randn((self.B, T_max, U_max, V), dtype=scalar_t, requires_grad=True, + self.x_tst = torch.randn((self.B, T_max, U_max, V), dtype=scalar_t, requires_grad=True, device=device) self.y = torch.randint(0, self.blank_idx, (self.B, U_max-1), dtype=torch.int, device=device) - self.f_len = torch.randint(T_min, T_max+1, (self.B,), dtype=torch.int, device=device) + self.f_len = torch.randint(T_min, T_max+1, (self.B,), dtype=torch.int, device=device) self.y_len = torch.randint(U_min-1, U_max, (self.B,), dtype=torch.int, device=device) self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max self.y_len[torch.randint(0, self.B, (1,)).item()] = U_max-1 @@ -31,11 +31,11 @@ def gen_input(self, scalar_t, for_vector_kernel): x_ref.requires_grad = True loss_grad = torch.ones(x_ref.size(0), dtype=x_ref.dtype, device=x_ref.device)/x_ref.size(0) _, _, self.grad_ref, self.loss_ref \ - = transducer_ref.transducer_loss_reference( x=x_ref, - label=self.y, - f_len=self.f_len, - y_len=self.y_len, - blank_idx=self.blank_idx, + = transducer_ref.transducer_loss_reference( x=x_ref, + label=self.y, + f_len=self.f_len, + y_len=self.y_len, + blank_idx=self.blank_idx, loss_grad=loss_grad) def _pack(self, x): @@ -50,7 +50,7 @@ def _pack(self, x): return x_packed, batch_offset def _unpack(self, x): - x_unpacked = torch.zeros(self.B, self.f_len.max(), self.y_len.max()+1, x.size(-1), + x_unpacked = torch.zeros(self.B, self.f_len.max(), self.y_len.max()+1, x.size(-1), dtype=x.dtype, device=x.device) for b in range(self.B): my_batch_offset = 0 if b == 0 else self.batch_offset[b-1] @@ -63,28 +63,28 @@ def _unpack(self, x): def run_transducer_loss(self, scalar_t, fuse_softmax_backward, packed_input, for_vector_kernel): self.gen_input(scalar_t, for_vector_kernel) - my_loss = TransducerLoss( fuse_softmax_backward=fuse_softmax_backward, - packed_input=packed_input) + my_loss = TransducerLoss( fuse_softmax_backward=fuse_softmax_backward, + packed_input=packed_input) if not packed_input: loss_tst = my_loss( x=self.x_tst, - label=self.y, - f_len=self.f_len, - y_len=self.y_len, + label=self.y, + f_len=self.f_len, + y_len=self.y_len, blank_idx=self.blank_idx) - loss_tst.mean().backward() + loss_tst.mean().backward() grad_tst = self.x_tst.grad else: loss_tst = my_loss( x=self.x_tst_packed, - label=self.y, - f_len=self.f_len, - y_len=self.y_len, + label=self.y, + f_len=self.f_len, + y_len=self.y_len, blank_idx=self.blank_idx, - batch_offset=self.batch_offset, + batch_offset=self.batch_offset, max_f_len=max(self.f_len)) loss_tst.mean().backward() grad_tst_packed = self.x_tst_packed.grad grad_tst = self._unpack(grad_tst_packed) - + return loss_tst, grad_tst def test_transducer_loss_fp32(self): diff --git a/apex/contrib/test/transducer/transducer_ref.py b/apex/contrib/test/transducer/transducer_ref.py index de342798e..eccebb0cf 100755 --- a/apex/contrib/test/transducer/transducer_ref.py +++ b/apex/contrib/test/transducer/transducer_ref.py @@ -23,7 +23,7 @@ def forward_alpha(x, label, f_len, y_len, blank_idx): for u in range(1, y_len[b]+1): curr_ = alpha[b, t-1, u] + x[b, t-1, u, blank_idx] next_ = alpha[b, t, u-1] + x[b, t, u-1, label[b, u-1]] - alpha[b, t, u] = log_sum_exp(curr_, next_) + alpha[b, t, u] = log_sum_exp(curr_, next_) return alpha def forward_beta(x, label, f_len, y_len, blank_idx): @@ -33,14 +33,14 @@ def forward_beta(x, label, f_len, y_len, blank_idx): for b in range(B): beta[b, f_len[b]-1, y_len[b]] = x[b, f_len[b]-1, y_len[b], blank_idx] for t in range(f_len[b]-2, -1, -1): - beta[b, t, y_len[b]] = beta[b, t+1, y_len[b]] + x[b, t, y_len[b], blank_idx] + beta[b, t, y_len[b]] = beta[b, t+1, y_len[b]] + x[b, t, y_len[b], blank_idx] for u in range(y_len[b]-1, -1, -1): beta[b, f_len[b]-1, u] = beta[b, f_len[b]-1, u+1] + x[b, f_len[b]-1, u, label[b, u]] for t in range(f_len[b]-2, -1, -1): for u in range(y_len[b]-1, -1, -1): - curr_ = beta[b, t+1, u] + x[b, t, u, blank_idx] + curr_ = beta[b, t+1, u] + x[b, t, u, blank_idx] next_ = beta[b, t, u+1] + x[b, t, u, label[b, u]] - beta[b, t, u] = log_sum_exp(curr_, next_) + beta[b, t, u] = log_sum_exp(curr_, next_) return beta def backward(x, label, f_len, y_len, alpha, beta, loss_grad, blank_idx): @@ -50,25 +50,25 @@ def backward(x, label, f_len, y_len, alpha, beta, loss_grad, blank_idx): common_factor = torch.log(loss_grad[b]) + alpha - beta[b, 0, 0] # next for u in range(y_len[b]): - grad[b, :f_len[b], u, label[b, u]] = -torch.exp(common_factor[b, :f_len[b], u] - + beta[b, :f_len[b], u+1] + grad[b, :f_len[b], u, label[b, u]] = -torch.exp(common_factor[b, :f_len[b], u] + + beta[b, :f_len[b], u+1] + x[b, :f_len[b], u, label[b, u]]) # current grad[b, :f_len[b]-1, :y_len[b]+1, blank_idx] \ - = -torch.exp(common_factor[b, :f_len[b]-1, :y_len[b]+1] - + beta[b, 1:f_len[b], :y_len[b]+1] + = -torch.exp(common_factor[b, :f_len[b]-1, :y_len[b]+1] + + beta[b, 1:f_len[b], :y_len[b]+1] + x[b, :f_len[b]-1, :y_len[b]+1, blank_idx]) grad[b, f_len[b]-1, y_len[b], blank_idx] = -torch.exp(common_factor[b, f_len[b]-1, y_len[b]] + x[b, f_len[b]-1, y_len[b], blank_idx]) - + return grad x_log = torch.nn.functional.log_softmax(x, dim=-1) alpha = forward_alpha(x_log, label, f_len, y_len, blank_idx) beta = forward_beta(x_log, label, f_len, y_len, blank_idx) - grad = backward(x_log, label, f_len, y_len, alpha, beta, + grad = backward(x_log, label, f_len, y_len, alpha, beta, loss_grad, blank_idx) x_log.backward(grad) loss = -beta[:, 0, 0] @@ -76,7 +76,7 @@ def backward(x, label, f_len, y_len, alpha, beta, loss_grad, blank_idx): return alpha, beta, x.grad, loss -def transducer_joint_reference(f, g, h_grad, f_len, g_len, pack_output, relu, dropout, +def transducer_joint_reference(f, g, h_grad, f_len, g_len, pack_output, relu, dropout, dropout_prob=0, mask=None): if dropout and mask == None: raise NotImplementedError("mask needs to supplied to test dropout.") @@ -100,7 +100,7 @@ def transducer_joint_reference(f, g, h_grad, f_len, g_len, pack_output, relu, dr h[b, f_len[b]:] = -1 h[b, :, g_len[b]:] = -1 - return h, f.grad, g.grad + return h, f.grad, g.grad # packing list_to_pack = [] diff --git a/apex/contrib/transducer/transducer.py b/apex/contrib/transducer/transducer.py index 784396275..cd92d3890 100755 --- a/apex/contrib/transducer/transducer.py +++ b/apex/contrib/transducer/transducer.py @@ -4,29 +4,29 @@ class TransducerJoint(torch.nn.Module): """Transducer joint - Detail of this loss function can be found in: Sequence Transduction with Recurrent Neural + Detail of this loss function can be found in: Sequence Transduction with Recurrent Neural Networks Arguments: - pack_output (bool, optional): whether to pack the output in a compact form with don't-care + pack_output (bool, optional): whether to pack the output in a compact form with don't-care data being removed. (default: False) - relu (bool, optional): apply ReLU to the output of the joint operation. Requires opt=1 + relu (bool, optional): apply ReLU to the output of the joint operation. Requires opt=1 (default: False) - dropout (bool, optional): apply dropout to the output of the joint operation. Requires opt=1 + dropout (bool, optional): apply dropout to the output of the joint operation. Requires opt=1 (default: False) - opt (int, optional): pick the optimization level in [0, 1]. opt=1 picks a tiled algorithm. + opt (int, optional): pick the optimization level in [0, 1]. opt=1 picks a tiled algorithm. (default: 1) - fwd_tile_size (int, optional): tile size used in forward operation. This argument will be - ignored if opt != 1. (default: 4) + fwd_tile_size (int, optional): tile size used in forward operation. This argument will be + ignored if opt != 1. (default: 4) dropout_prob (float, optional): dropout probability. (default: 0.0) probe_mask (bool, optional): a flag used to probe the mask generated by ReLU and/or dropout - operation. When this argument is set to True, the mask can be accessed through + operation. When this argument is set to True, the mask can be accessed through self.mask_probe. (default: false) """ - def __init__(self, pack_output=False, relu=False, dropout=False, opt=1, fwd_tile_size=4, + def __init__(self, pack_output=False, relu=False, dropout=False, opt=1, fwd_tile_size=4, dropout_prob=0, probe_mask=False): - super(TransducerJoint, self).__init__() + super(TransducerJoint, self).__init__() self.pack_output = pack_output self.relu = relu self.dropout = dropout @@ -49,44 +49,44 @@ def forward(self, f, g, f_len, g_len, batch_offset=None, packed_batch=0): f_len (tensor): length of transcription vector for each batch. g_len (tensor): length of prediction vector minus 1 for each batch. batch_offset (tensor, optional): tensor containing the offset of each batch - in the results. For example, batch offset can be obtained from: + in the results. For example, batch offset can be obtained from: batch_offset = torch.cumsum(f_len*g_len, dim=0) - This argument is required if pack_output == True, and is ignored if + This argument is required if pack_output == True, and is ignored if pack_output == False. (default: None) - packed_batch (int, optional): the batch size after packing. This argument is + packed_batch (int, optional): the batch size after packing. This argument is ignored if pack_output == False. (default: 0) """ my_batch_offset = batch_offset if self.pack_output else self.dummy_batch_offset if self.pack_output and (batch_offset is None or packed_batch == 0): raise Exception("Please specify batch_offset and packed_batch when packing is enabled") dropout = self.dropout and self.training # only dropout for training - return TransducerJointFunc.apply(f, g, f_len, g_len, self.pack_output, self.relu, dropout, - my_batch_offset, packed_batch, self.opt, + return TransducerJointFunc.apply(f, g, f_len, g_len, self.pack_output, self.relu, dropout, + my_batch_offset, packed_batch, self.opt, self.fwd_tile_size, self.dropout_prob, self.mask_probe) class TransducerLoss(torch.nn.Module): """Transducer loss - Detail of this loss function can be found in: Sequence Transduction with Recurrent Neural + Detail of this loss function can be found in: Sequence Transduction with Recurrent Neural Networks Arguments: fuse_softmax_backward (bool, optional) whether to fuse the backward of transducer loss with softmax. (default: True) - opt (int, optional): pick the optimization level in [0, 1]. opt=1 picks a more optimized + opt (int, optional): pick the optimization level in [0, 1]. opt=1 picks a more optimized algorithm. In some cases, opt=1 might fall back to opt=0. (default: 1) - packed_input (bool, optional): whether to pack the output in a compact form with don't-care + packed_input (bool, optional): whether to pack the output in a compact form with don't-care data being removed. (default: False) """ def __init__(self, fuse_softmax_backward=True, opt=1, packed_input=False): - super(TransducerLoss, self).__init__() + super(TransducerLoss, self).__init__() self.fuse_softmax_backward = fuse_softmax_backward self.opt = opt self.packed_input = packed_input self.dummy_batch_offset = torch.empty(0) - def forward(self, x, label, f_len, y_len, blank_idx, batch_offset=None, max_f_len=None, + def forward(self, x, label, f_len, y_len, blank_idx, batch_offset=None, max_f_len=None, debug_list=None): """Forward operation of transducer joint @@ -97,43 +97,43 @@ def forward(self, x, label, f_len, y_len, blank_idx, batch_offset=None, max_f_le y_len (tensor): lengths of the labels for each batch. blank_idx (int): index for the null symbol. batch_offset (tensor, optional): tensor containing the offset of each batch - in the input. For example, batch offset can be obtained from: + in the input. For example, batch offset can be obtained from: batch_offset = torch.cumsum(f_len*(y_len+1), dim=0) - This argument is required if packed_input == True, and is ignored if + This argument is required if packed_input == True, and is ignored if packed_input == False. (default: None) max_f_len (int, optional): maximum length of the input in the time dimension. - For example, it can be obtained as + For example, it can be obtained as max_f_len = max(f_len) - This argument is required if packed_input == True, and is ignored if + This argument is required if packed_input == True, and is ignored if packed_input == False. (default: None) (default: None) - debug_list (list, optional): when an empty list is supplied, Alpha and Beta generated - in the forward operation will be attached to this list for debug purpose. + debug_list (list, optional): when an empty list is supplied, Alpha and Beta generated + in the forward operation will be attached to this list for debug purpose. (default: None) """ if self.packed_input: if batch_offset is None or max_f_len is None: raise Exception("Please specify batch_offset and max_f_len when packing is \ - enabled") + enabled") my_batch_offset = batch_offset my_max_f_len = max_f_len else: my_batch_offset = self.dummy_batch_offset my_max_f_len = x.size(1) - return TransducerLossFunc.apply(x, label, f_len, y_len, my_batch_offset, my_max_f_len, - blank_idx, self.fuse_softmax_backward, debug_list, + return TransducerLossFunc.apply(x, label, f_len, y_len, my_batch_offset, my_max_f_len, + blank_idx, self.fuse_softmax_backward, debug_list, self.opt, self.packed_input) class TransducerLossFunc(torch.autograd.Function): @staticmethod - def forward(ctx, x, label, f_len, y_len, batch_offset, max_f_len, blank_idx, + def forward(ctx, x, label, f_len, y_len, batch_offset, max_f_len, blank_idx, fuse_softmax_backward, debug_list, opt, packed_input): if fuse_softmax_backward == False: with torch.enable_grad(): x = torch.nn.functional.log_softmax(x, dim=-1) else: x = torch.nn.functional.log_softmax(x, dim=-1) - alpha, beta, loss = transducer_loss_cuda.forward( x, label, f_len, y_len, batch_offset, + alpha, beta, loss = transducer_loss_cuda.forward( x, label, f_len, y_len, batch_offset, max_f_len, blank_idx, opt, packed_input) if debug_list == []: debug_list += [alpha, beta] @@ -148,8 +148,8 @@ def forward(ctx, x, label, f_len, y_len, batch_offset, max_f_len, blank_idx, @staticmethod def backward(ctx, loss_grad): x, alpha, beta, f_len, y_len, label, batch_offset = ctx.saved_tensors - x_grad = transducer_loss_cuda.backward( x, loss_grad, alpha, beta, f_len, y_len, label, - batch_offset, ctx.max_f_len, ctx.blank_idx, ctx.opt, + x_grad = transducer_loss_cuda.backward( x, loss_grad, alpha, beta, f_len, y_len, label, + batch_offset, ctx.max_f_len, ctx.blank_idx, ctx.opt, ctx.fuse_softmax_backward, ctx.packed_input) if ctx.fuse_softmax_backward == False: x_grad = x.backward(x_grad) @@ -157,9 +157,9 @@ def backward(ctx, loss_grad): class TransducerJointFunc(torch.autograd.Function): @staticmethod - def forward(ctx, f, g, f_len, g_len, pack_output, relu, dropout, batch_offset, packed_batch, + def forward(ctx, f, g, f_len, g_len, pack_output, relu, dropout, batch_offset, packed_batch, opt, fwd_tile_size, dropout_prob, mask_probe): - h = transducer_joint_cuda.forward(f, g, f_len, g_len, batch_offset, packed_batch, opt, + h = transducer_joint_cuda.forward(f, g, f_len, g_len, batch_offset, packed_batch, opt, pack_output, relu, dropout, dropout_prob, fwd_tile_size) masked = relu or dropout if masked: @@ -185,8 +185,8 @@ def backward(ctx, loss_grad): f_len, g_len, batch_offset = ctx.saved_tensors inp = [loss_grad] - f_grad, g_grad = transducer_joint_cuda.backward( inp, f_len, g_len, batch_offset, - ctx.max_f_len, ctx.max_g_len, + f_grad, g_grad = transducer_joint_cuda.backward( inp, f_len, g_len, batch_offset, + ctx.max_f_len, ctx.max_g_len, ctx.pack_output, ctx.scale) return f_grad, g_grad, None, None, None, None, None, None, None, None, None, None, None, \ diff --git a/apex/fp16_utils/README.md b/apex/fp16_utils/README.md index 941de1794..c423c8f85 100644 --- a/apex/fp16_utils/README.md +++ b/apex/fp16_utils/README.md @@ -9,7 +9,7 @@ fp16_optimizer.py contains `FP16_Optimizer`, a Python class designed to wrap an #### [word_language_model with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/word_language_model) -fp16_util.py contains a number of utilities to manually manage master parameters and loss scaling, if the user chooses. +fp16_util.py contains a number of utilities to manually manage master parameters and loss scaling, if the user chooses. #### [Manual management documentation](https://nvidia.github.io/apex/fp16_utils.html#manual-master-parameter-management) diff --git a/apex/fp16_utils/fp16_optimizer.py b/apex/fp16_utils/fp16_optimizer.py index 7c0dd397f..92f6ba5ab 100755 --- a/apex/fp16_utils/fp16_optimizer.py +++ b/apex/fp16_utils/fp16_optimizer.py @@ -11,9 +11,9 @@ # TODO: Update overflow check + downscale to use Carl's fused kernel. class FP16_Optimizer(object): - def __init__(self, - init_optimizer, - static_loss_scale=1.0, + def __init__(self, + init_optimizer, + static_loss_scale=1.0, dynamic_loss_scale=False, dynamic_loss_args=None, verbose=True): @@ -52,7 +52,7 @@ def __init__(self, # Reset existing state dict key to the new master param. # We still need to recast per-param state tensors, if any, to FP32. if param in self.optimizer.state: - self.optimizer.state[master_param] = self.optimizer.state.pop(param) + self.optimizer.state[master_param] = self.optimizer.state.pop(param) elif param.type() == 'torch.cuda.FloatTensor': self.maybe_print("FP16_Optimizer received torch.cuda.FloatTensor with {}" .format(param.size())) @@ -60,9 +60,9 @@ def __init__(self, param_group['params'][i] = param else: raise TypeError("Wrapped parameters must be either " - "torch.cuda.FloatTensor or torch.cuda.HalfTensor. " + "torch.cuda.FloatTensor or torch.cuda.HalfTensor. " "Received {}".format(param.type())) - + self.fp16_groups.append(fp16_params_this_group) self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group) self.fp32_from_fp32_groups.append(fp32_params_this_group) @@ -110,7 +110,7 @@ def __init__(self, def maybe_print(self, msg): if self.verbose: print(msg) - + def __getstate__(self): raise RuntimeError("FP16_Optimizer should be serialized using state_dict().") @@ -229,9 +229,9 @@ def state_dict(self): def load_state_dict(self, state_dict): """ - Loads a state_dict created by an earlier call to state_dict(). - If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, - whose parameters in turn came from ``model``, it is expected that the user + Loads a state_dict created by an earlier call to state_dict(). + If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, + whose parameters in turn came from ``model``, it is expected that the user will call ``model.load_state_dict()`` before ``fp16_optimizer_instance.load_state_dict()`` is called. @@ -252,17 +252,17 @@ def load_state_dict(self, state_dict): self.first_closure_call_this_step = state_dict['first_closure_call_this_step'] self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) # At this point, the optimizer's references to the model's fp32 parameters are up to date. - # The optimizer's hyperparameters and internal buffers are also up to date. + # The optimizer's hyperparameters and internal buffers are also up to date. # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still - # out of date. There are two options. - # 1: Refresh the master params from the model's fp16 params. + # out of date. There are two options. + # 1: Refresh the master params from the model's fp16 params. # This requires less storage but incurs precision loss. # 2: Save and restore the fp32 master copies separately. # We choose option 2. - # - # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device - # of their associated parameters, because it's possible those buffers might not exist yet in - # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been + # + # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device + # of their associated parameters, because it's possible those buffers might not exist yet in + # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been # constructed in the same way as the one whose state_dict we are loading, the same master params # are guaranteed to exist, so we can just copy_() from the saved master params. for current_group, saved_group in zip(self.fp32_from_fp16_groups, state_dict['fp32_from_fp16']): @@ -271,14 +271,14 @@ def load_state_dict(self, state_dict): def step(self, closure=None): # could add clip option. """ - If no closure is supplied, :attr:`step` should be called after + If no closure is supplied, :attr:`step` should be called after ``fp16_optimizer_obj.backward(loss)``. :attr:`step` updates the fp32 master copy of parameters using the optimizer supplied to :class:`FP16_Optimizer`'s constructor, then copies the updated fp32 params into the fp16 params originally referenced by :class:`FP16_Optimizer`'s constructor, so the user may immediately run another forward pass using their model. - If a closure is supplied, :attr:`step` may be called without a prior call to + If a closure is supplied, :attr:`step` may be called without a prior call to :attr:`backward(loss)`. This control flow is identical to `ordinary Pytorch optimizer use`_ with closures. However, the user should take care that any ``loss.backward()`` call within the closure @@ -289,7 +289,7 @@ def step(self, closure=None): # could add clip option. Example with closure:: - # optimizer is assumed to be an FP16_Optimizer object, previously constructed from an + # optimizer is assumed to be an FP16_Optimizer object, previously constructed from an # existing pytorch optimizer. for input, target in dataset: def closure(): @@ -319,7 +319,7 @@ def closure(): maybe_print("Gradient overflow. Skipping step, reducing " + "loss scale to {}".format(self.loss_scaler.loss_scale())) return - + if closure is not None: retval = self._step_with_closure(closure) else: @@ -343,7 +343,7 @@ def wrapped_closure(): self.first_closure_call_this_step = False else: # If self.optimizer.step() internally calls wrapped_closure more than once, - # it may update the fp32 params after each call. However, self.optimizer + # it may update the fp32 params after each call. However, self.optimizer # doesn't know about the fp16 params at all. If the fp32 params get updated, # we can't rely on self.optimizer to refresh the fp16 params. We need # to handle that manually: @@ -351,11 +351,11 @@ def wrapped_closure(): # Our API expects the user to give us ownership of the backward() call by # replacing all calls to loss.backward() with optimizer.backward(loss). # This requirement holds whether or not the call to backward() is made within a closure. - # If the user is properly calling optimizer.backward(loss) within "closure," + # If the user is properly calling optimizer.backward(loss) within "closure," # calling closure() here will give the fp32 master params fresh gradients - # for the optimizer to play with, so all wrapped_closure needs to do is call + # for the optimizer to play with, so all wrapped_closure needs to do is call # closure() and return the loss. - temp_loss = closure() + temp_loss = closure() while(self.overflow): scale = self.loss_scaler.loss_scale() # self._update_scale(self.overflow) # now done at the end of backward @@ -371,7 +371,7 @@ def wrapped_closure(): return retval def backward(self, loss, update_master_grads=True, retain_graph=False): - """ + """ :attr:`backward` performs the following conceptual steps: 1. fp32_loss = loss.float() (see first Note below) @@ -385,19 +385,19 @@ def backward(self, loss, update_master_grads=True, retain_graph=False): .. note:: :attr:`backward` internally converts the loss to fp32 before applying the loss scale. - This provides some additional safety against overflow if the user has supplied an - fp16 loss value. + This provides some additional safety against overflow if the user has supplied an + fp16 loss value. However, for maximum overflow safety, the user should - compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to + compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to :attr:`backward`. .. warning:: - The gradients found in a model's leaves after the call to - :attr:`backward` should not be regarded as valid in general, - because it's possible - they have been scaled (and in the case of dynamic loss scaling, - the scale factor may change over time). - If the user wants to inspect gradients after a call to :attr:`backward`, + The gradients found in a model's leaves after the call to + :attr:`backward` should not be regarded as valid in general, + because it's possible + they have been scaled (and in the case of dynamic loss scaling, + the scale factor may change over time). + If the user wants to inspect gradients after a call to :attr:`backward`, only the master gradients should be regarded as valid. These can be retrieved via :attr:`inspect_master_grad_data()`. @@ -412,22 +412,22 @@ def backward(self, loss, update_master_grads=True, retain_graph=False): optimizer.backward(loss) # Naive operation with multiple losses (technically valid, but less efficient): - # fp32 grads will be correct after the second call, but + # fp32 grads will be correct after the second call, but # the first call incurs an unnecessary fp16->fp32 grad copy. optimizer.backward(loss1) optimizer.backward(loss2) # More efficient way to handle multiple losses: - # The fp16->fp32 grad copy is delayed until fp16 grads from all + # The fp16->fp32 grad copy is delayed until fp16 grads from all # losses have been accumulated. optimizer.backward(loss1, update_master_grads=False) optimizer.backward(loss2, update_master_grads=False) optimizer.update_master_grads() - """ - # To consider: try multiple backward passes using retain_grad=True to find + """ + # To consider: try multiple backward passes using retain_grad=True to find # a loss scale that works. After you find a loss scale that works, do a final dummy - # backward pass with retain_graph=False to tear down the graph. Doing this would avoid - # discarding the iteration, but probably wouldn't improve overall efficiency. + # backward pass with retain_graph=False to tear down the graph. Doing this would avoid + # discarding the iteration, but probably wouldn't improve overall efficiency. scaled_loss = loss.float()*self.loss_scaler.loss_scale() scaled_loss.backward(retain_graph=retain_graph) if update_master_grads: @@ -436,8 +436,8 @@ def backward(self, loss, update_master_grads=True, retain_graph=False): def update_master_grads(self): # torch.cuda.nvtx.range_push("update_master_grads") """ - Copy the ``.grad`` attribute from stored references to fp16 parameters to - the ``.grad`` attribute of the fp32 master parameters that are directly + Copy the ``.grad`` attribute from stored references to fp16 parameters to + the ``.grad`` attribute of the fp32 master parameters that are directly updated by the optimizer. :attr:`update_master_grads` only needs to be called if ``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``. """ @@ -492,19 +492,19 @@ def update_master_grads(self): def inspect_master_grad_data(self): """ - When running with :class:`FP16_Optimizer`, + When running with :class:`FP16_Optimizer`, ``.grad`` attributes of a model's fp16 leaves should not be - regarded as truthful, because they might be scaled. + regarded as truthful, because they might be scaled. After a call to :attr:`fp16_optimizer_obj.backward(loss)`, if no overflow was encountered, the fp32 master params' ``.grad`` - attributes will contain valid gradients properly divided by the loss scale. However, - because :class:`FP16_Optimizer` flattens some parameters, accessing them may be + attributes will contain valid gradients properly divided by the loss scale. However, + because :class:`FP16_Optimizer` flattens some parameters, accessing them may be nonintuitive. :attr:`inspect_master_grad_data` allows those gradients to be viewed with shapes corresponding to their associated model leaves. Returns: List of lists (one list for each parameter group). The list for each parameter group - is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group. + is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group. """ if self.overflow: print("Warning: calling FP16_Optimizer.inspect_master_grad_data while in an overflow state. " diff --git a/apex/fp16_utils/fp16util.py b/apex/fp16_utils/fp16util.py index dcdc3447a..7d2f8251a 100644 --- a/apex/fp16_utils/fp16util.py +++ b/apex/fp16_utils/fp16util.py @@ -135,7 +135,7 @@ def prep_param_lists(model, flat_master=False): def model_grads_to_master_grads(model_params, master_params, flat_master=False): """ - Copy model gradients to master gradients. + Copy model gradients to master gradients. Args: model_params: List of model parameters created by :func:`prep_param_lists`. @@ -164,7 +164,7 @@ def master_params_to_model_params(model_params, master_params, flat_master=False master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`. """ if flat_master: - for model, master in zip(model_params, + for model, master in zip(model_params, _unflatten_dense_tensors(master_params[0].data, model_params)): model.data.copy_(master) else: diff --git a/apex/fp16_utils/loss_scaler.py b/apex/fp16_utils/loss_scaler.py index b9f32fe01..06158ecbe 100644 --- a/apex/fp16_utils/loss_scaler.py +++ b/apex/fp16_utils/loss_scaler.py @@ -12,7 +12,7 @@ class LossScaler: Class that manages a static loss scale. This class is intended to interact with :class:`FP16_Optimizer`, and should not be directly manipulated by the user. - Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to + Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to :class:`FP16_Optimizer`'s constructor. Args: @@ -47,7 +47,7 @@ def backward(self, loss, retain_graph=False): class DynamicLossScaler: """ Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler` - indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of + indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of :class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler` operates, because the default options can be changed using the the ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor. @@ -55,18 +55,18 @@ class DynamicLossScaler: Loss scaling is designed to combat the problem of underflowing gradients encountered at long times when training fp16 networks. Dynamic loss scaling begins by attempting a very high loss scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are - encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has + encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has occurred. :class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch, - and :class:`DynamicLossScaler` adjusts the loss scale to a lower value. + and :class:`DynamicLossScaler` adjusts the loss scale to a lower value. If a certain number of iterations occur without overflowing gradients detected, :class:`DynamicLossScaler` increases the loss scale once more. - In this way :class:`DynamicLossScaler` attempts to "ride the edge" of + In this way :class:`DynamicLossScaler` attempts to "ride the edge" of always using the highest loss scale possible without incurring overflow. Args: init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.` - scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``. + scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``. scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale. """ @@ -91,8 +91,8 @@ def has_overflow(self, params): # `x` is a torch.Tensor def _has_inf_or_nan(x): try: - # if x is half, the .float() incurs an additional deep copy, but it's necessary if - # Pytorch's .sum() creates a one-element tensor of the same type as x + # if x is half, the .float() incurs an additional deep copy, but it's necessary if + # Pytorch's .sum() creates a one-element tensor of the same type as x # (which is true for some recent version of pytorch). cpu_sum = float(x.float().sum()) # More efficient version that can be used if .sum() returns a Python scalar @@ -130,8 +130,8 @@ def scale_gradient(self, module, grad_in, grad_out): def backward(self, loss, retain_graph=False): scaled_loss = loss*self.loss_scale scaled_loss.backward(retain_graph=retain_graph) - -############################################################## + +############################################################## # Example usage below here -- assuming it's in a separate file ############################################################## """ @@ -167,10 +167,10 @@ def backward(self, loss, retain_graph=False): # Run backprop optimizer.zero_grad() loss.backward() - + # Check for overflow has_overflow = DynamicLossScaler.has_overflow(parameters) - + # If no overflow, unscale grad and update as usual if not has_overflow: for param in parameters: diff --git a/apex/fused_dense/fused_dense.py b/apex/fused_dense/fused_dense.py index 97377a423..82b24bcde 100644 --- a/apex/fused_dense/fused_dense.py +++ b/apex/fused_dense/fused_dense.py @@ -2,7 +2,7 @@ from torch import nn import fused_dense_cuda from apex._autocast_utils import _cast_if_autocast_enabled -import math +import math #implements fused GEMM+bias in forward pass using mlp_cuda from apex class FusedDenseFunc(torch.autograd.Function): @@ -87,7 +87,7 @@ def forward(self, input): return fused_dense_function(input, self.weight, self.bias) else: return dense_no_bias_function(input, self.weight) - + def reset_parameters(self): nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) @@ -95,15 +95,15 @@ def reset_parameters(self): fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 nn.init.uniform_(self.bias, -bound, bound) - -#======================================================================================= -# -#======================================================================================= + +#======================================================================================= +# +#======================================================================================= class FusedDenseGeluDense(nn.Module): ''' https://zeta.apac.ai/en/latest/zeta/nn/modules/fused_gelu_dense/ module combines dense layers with GELU activations in a single neural network layer. - layer consists of two dense sub-layers, each followed by a GELU activation function. + layer consists of two dense sub-layers, each followed by a GELU activation function. It takes an input tensor and passes it through these sub-layers to produce the final output. Parameters: dim (int): Input dimension. diff --git a/apex/optimizers/fused_lars.py b/apex/optimizers/fused_lars.py index 3e60b2cce..542f26af0 100644 --- a/apex/optimizers/fused_lars.py +++ b/apex/optimizers/fused_lars.py @@ -40,7 +40,7 @@ def __init__(self, params, lr=required, momentum=0, dampening=0, self._dummy_overflow_buf = torch.cuda.IntTensor(1).zero_() else: raise RuntimeError('apex.optimizers.FusedLARS requires cuda extensions') - + def __setstate__(self, state): super(FusedLARS, self).__setstate__(state) for group in self.param_groups: @@ -97,7 +97,7 @@ def step(self, closure=None): nesterov = group['nesterov'] lr = group['lr'] is_skipped = group['is_skipped'] - + # For each group, there are 3 possible combinations we need to consider: # grad_type, param_to_update_type, momentum_type, requires_fp16_model_copy # 1. fp16, fp16, fp16, No diff --git a/apex/optimizers/fused_mixed_precision_lamb.py b/apex/optimizers/fused_mixed_precision_lamb.py index 7ecda4f51..f44f1e20f 100644 --- a/apex/optimizers/fused_mixed_precision_lamb.py +++ b/apex/optimizers/fused_mixed_precision_lamb.py @@ -14,7 +14,7 @@ def __init__(self, params, lr=1e-3, step=0, bias_correction=True, reduced_precision_dtype=None): if amsgrad: raise RuntimeError('FusedLAMB does not support the AMSGrad variant.') - + # The learning rate (lr) and optimizer step (step) should be located on device # in order to faciliated device sync free execution defaults = dict(lr=torch.tensor(lr, dtype=torch.float32), @@ -44,7 +44,7 @@ def __init__(self, params, lr=1e-3, step=0, bias_correction=True, # Mixed Precision support self.reduced_precision_dtype = reduced_precision_dtype self.param_groups_full_precision = [] - + self._step_supports_amp_scaling = True self.adam_w_mode = 1 if adam_w_mode else 0 self.use_nvlamb = use_nvlamb @@ -125,9 +125,9 @@ def _setup_full_precision_params(self): for p in param_list ], }) - + # add_param_groups() is overridden because default items can be tensors. The - # parent version does not clone the default item, so two param groups can + # parent version does not clone the default item, so two param groups can # accidentally point to the same default item value where they can differ # given they are in separate groups. def add_param_group(self, param_group): @@ -159,7 +159,7 @@ def step(self, closure=None, grad_scaler=None): if p.grad is None: continue grad_list.append(p.grad) - + # Overflow check of gradients device = self.param_groups[0]["params"][0].device found_inf = ( @@ -176,7 +176,7 @@ def step(self, closure=None, grad_scaler=None): else: scale = torch.ones((1,), device=device) inv_scale = torch.ones((1,), device=device) - + # grad_norm is of scaled gradients. # So, multiply `max_grad_norm` by scale. max_grad_norm = self.defaults['max_grad_norm'] * scale diff --git a/apex/parallel/LARC.py b/apex/parallel/LARC.py index 4a93fcd65..23bab53fd 100644 --- a/apex/parallel/LARC.py +++ b/apex/parallel/LARC.py @@ -5,10 +5,10 @@ class LARC(object): """ :class:`LARC` is a pytorch implementation of both the scaling and clipping variants of LARC, - in which the ratio between gradient and parameter magnitudes is used to calculate an adaptive + in which the ratio between gradient and parameter magnitudes is used to calculate an adaptive local learning rate for each individual parameter. The algorithm is designed to improve convergence of large batch training. - + See https://arxiv.org/abs/1708.03888 for calculation of the local learning rate. In practice it modifies the gradients of parameters as a proxy for modifying the learning rate @@ -62,7 +62,7 @@ def param_groups(self): @param_groups.setter def param_groups(self, value): self.optim.param_groups = value - + def state_dict(self): return self.optim.state_dict() diff --git a/apex/transformer/functional/fused_rope.py b/apex/transformer/functional/fused_rope.py index e74906151..ac0c0b1ce 100644 --- a/apex/transformer/functional/fused_rope.py +++ b/apex/transformer/functional/fused_rope.py @@ -116,7 +116,7 @@ def forward( h = t.shape[2] d = t.shape[3] # t is of shape [s, b, h, d] - # freqs is of shape [s, 1, 1, d] + # freqs is of shape [s, 1, 1, d] act_options = {'dtype': t.dtype, 'device': t.device, 'requires_grad': False} if transpose_output_memory: @@ -130,18 +130,18 @@ def forward( return output - + @staticmethod def backward( ctx, grad_output: torch.Tensor ) -> Tuple[Union[torch.Tensor, None], ...]: - (freqs,) = ctx.saved_tensors + (freqs,) = ctx.saved_tensors s = grad_output.shape[0] b = grad_output.shape[1] h = grad_output.shape[2] d = grad_output.shape[3] - + act_options = {'dtype': grad_output.dtype, 'device': grad_output.device, 'requires_grad': False} if ctx.transpose_output_memory: grad_input = torch.empty((b, s, h, d), **act_options).transpose(0, 1) @@ -195,7 +195,7 @@ def forward( transpose_output_memory: bool = False, ) -> torch.Tensor: raise ValueError("Invalid forward implementation.") - + @staticmethod def backward( ctx, grad_output: torch.Tensor @@ -228,8 +228,8 @@ def backward( grad_output, cos_, sin_, ctx.transpose_output_memory ) return grad_input, None, None, None - -class FusedRoPECachedFuncAiter(FusedRoPECachedFunc): + +class FusedRoPECachedFuncAiter(FusedRoPECachedFunc): @staticmethod def forward( ctx, @@ -243,7 +243,7 @@ def forward( h = t.shape[2] d = t.shape[3] # t is of shape [s, b, h, d] - # freqs is of shape [s, 1, 1, d] + # freqs is of shape [s, 1, 1, d] act_options = {'dtype': t.dtype, 'device': t.device, 'requires_grad': False} if transpose_output_memory: @@ -267,7 +267,7 @@ def backward( b = grad_output.shape[1] h = grad_output.shape[2] d = grad_output.shape[3] - + act_options = {'dtype': grad_output.dtype, 'device': grad_output.device, 'requires_grad': False} if ctx.transpose_output_memory: grad_input = torch.empty((b, s, h, d), **act_options).transpose(0, 1) @@ -320,7 +320,7 @@ def forward( freqs: torch.Tensor, ) -> torch.Tensor: raise ValueError("Invalid forward implementation.") - + @staticmethod def backward( ctx, grad_output: torch.Tensor @@ -383,7 +383,7 @@ def backward( h = grad_output.shape[1] d = grad_output.shape[2] # t is of shape [t, h, d] - + act_options = {'dtype': grad_output.dtype, 'device': grad_output.device, 'requires_grad': False} grad_input = torch.empty((t, h, d), **act_options) aiter.rope_thd_bwd_impl(grad_input, grad_output, cu_seqlens, freqs, 0, False, False) @@ -488,7 +488,7 @@ def forward( cos_w: torch.Tensor, sin_w: torch.Tensor, ) -> torch.Tensor: - + s = t.shape[0] h = t.shape[2] d = t.shape[3] @@ -509,12 +509,12 @@ def backward( ) -> Tuple[Union[torch.Tensor, None], ...]: cos_h, sin_h, cos_w, sin_w = ctx.saved_tensors - + s = grad_output.shape[0] h = grad_output.shape[2] d = grad_output.shape[3] # t is of shape [s, ih* iw, h, d] - + act_options = {'dtype': grad_output.dtype, 'device': grad_output.device, 'requires_grad': False} grad_input = torch.empty((s, ctx.img_h * ctx.img_w, h, d), **act_options) aiter.rope_2d_bwd_impl(grad_input, grad_output, cos_h, sin_h, cos_w, sin_w, ctx.img_h, ctx.img_w, 0, False, False) diff --git a/compatibility/fused_layer_norm_cuda.py b/compatibility/fused_layer_norm_cuda.py index 2722e0252..8252203cc 100644 --- a/compatibility/fused_layer_norm_cuda.py +++ b/compatibility/fused_layer_norm_cuda.py @@ -22,11 +22,11 @@ def _load_module(self): finally: self._loading = False return self._loaded_module - + def __getattr__(self, name): if name.startswith("_"): raise AttributeError(f"module fused_layer_norm_cuda has no attribute '{name}'") - + module = self._load_module() return getattr(module, name) @@ -36,9 +36,9 @@ def __dir__(self): return dir(module) except: return [] - + def __repr__(self): return "" - + #replace module with lazy loader sys.modules[__name__] = _FusedLayerCudaModule() \ No newline at end of file diff --git a/compatibility/mlp_cuda.py b/compatibility/mlp_cuda.py index 4c873d560..9375a6e67 100644 --- a/compatibility/mlp_cuda.py +++ b/compatibility/mlp_cuda.py @@ -22,11 +22,11 @@ def _load_module(self): finally: self._loading = False return self._loaded_module - + def __getattr__(self, name): if name.startswith("_"): raise AttributeError(f"module mlp_cuda has no attribute '{name}'") - + module = self._load_module() return getattr(module, name) @@ -36,9 +36,9 @@ def __dir__(self): return dir(module) except: return [] - + def __repr__(self): return "" - + #replace module with lazy loader sys.modules[__name__] = _MLPCudaModule() \ No newline at end of file diff --git a/csrc/fused_dense_cuda.cu b/csrc/fused_dense_cuda.cu index 15c076f68..0dc854e37 100644 --- a/csrc/fused_dense_cuda.cu +++ b/csrc/fused_dense_cuda.cu @@ -159,7 +159,7 @@ int gemm_lt( * 1. Set the Data type of matrix elements. * 3. Set the layout: Size/shape of the matrix. This depends if transpose is needed or not. * 4. Set the leading dimentions - * + * */ hipblasLtMatrixLayout_t matA = nullptr, matB = nullptr, matC = nullptr; @@ -189,10 +189,10 @@ int gemm_lt( ldb = k; } // NN - CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matA, dtype_a, trans_a == HIPBLAS_OP_T ? k : m, + CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matA, dtype_a, trans_a == HIPBLAS_OP_T ? k : m, trans_a == HIPBLAS_OP_T ? m : k, lda)); - CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matB, dtype_b, trans_b == HIPBLAS_OP_T ? n : k, + CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matB, dtype_b, trans_b == HIPBLAS_OP_T ? n : k, trans_b == HIPBLAS_OP_T ? k : n, ldb)); CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matC, dtype_c, m, n, m)); @@ -200,7 +200,7 @@ int gemm_lt( /* ============================================================================================ * Matmul desc: * 1. Create operation descriptor with compute data type - * 2. Set transpose operation + * 2. Set transpose operation */ hipblasLtMatmulDesc_t matmulDesc = nullptr; @@ -215,16 +215,16 @@ int gemm_lt( CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmulDesc, desc_computeType, desc_dataType)); - CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSA, + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(trans_a))); - CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSB, + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(trans_b))); /* ============================================================================================ * Configure epilogue - * 1. Set mat-mul post-ops: bias, bgradb, gelu. - * 2. + * 1. Set mat-mul post-ops: bias, bgradb, gelu. + * 2. */ hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT; @@ -246,16 +246,16 @@ int gemm_lt( { epilogue = HIPBLASLT_EPILOGUE_GELU_BIAS; } - CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &d_bias, sizeof(d_bias))); - CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &dtype_bias, sizeof(dtype_bias))); - CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &d_gelu, sizeof(d_gelu))); - CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelu, sizeof(ld_gelu))); - // CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, + // CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, // &dtype_gelu, sizeof(dtype_gelu))); } else if (use_bias) @@ -268,10 +268,10 @@ int gemm_lt( { epilogue = HIPBLASLT_EPILOGUE_BIAS; } - CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &d_bias, sizeof(d_bias))); - CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &dtype_bias, sizeof(dtype_bias))); } else if (use_gelu) @@ -284,15 +284,15 @@ int gemm_lt( { epilogue = HIPBLASLT_EPILOGUE_GELU_AUX; } - CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &d_gelu, sizeof(d_gelu))); - CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelu, sizeof(ld_gelu))); - // CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, + // CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, // &dtype_gelu, sizeof(dtype_gelu))); } - CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE, + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); /* ============================================================================================ @@ -310,8 +310,8 @@ int gemm_lt( hipblasLtMatmulHeuristicResult_t heuristicResult[request_solutions]; CHECK_HIPBLASLT_ERROR(hipblasLtMatmulPreferenceCreate(&pref)); - CHECK_HIPBLASLT_ERROR(hipblasLtMatmulAlgoGetHeuristic(handle, matmulDesc, matA, matB, matC, matC, - pref, request_solutions, heuristicResult, + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulAlgoGetHeuristic(handle, matmulDesc, matA, matB, matC, matC, + pref, request_solutions, heuristicResult, &returnedAlgoCount)); if (returnedAlgoCount == 0) @@ -326,7 +326,7 @@ int gemm_lt( } hipMalloc(&workspace, workspace_size); - CHECK_HIPBLASLT_ERROR(hipblasLtMatmulPreferenceSetAttribute(pref, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulPreferenceSetAttribute(pref, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace, sizeof(workspace_size))); /* ============================================================================================ @@ -337,14 +337,14 @@ int gemm_lt( void *d_c = static_cast(C.data_ptr()); CHECK_HIPBLASLT_ERROR(hipblasLtMatmul(handle, matmulDesc, alpha, d_a, matA, - d_b, matB, beta, static_cast(d_c), + d_b, matB, beta, static_cast(d_c), matC, d_c, matC, &heuristicResult[0].algo, workspace, workspace_size, stream)); #if DEBUG - std::cout << "\nTensor-A:\n" << A - << "\nTensor-B:\n" << B - << "\nTensor-C:\n" << C + std::cout << "\nTensor-A:\n" << A + << "\nTensor-B:\n" << B + << "\nTensor-C:\n" << C << "\nTensor-Bias:\n" << bias << std::endl; std::cout << "\nSizes: A[" << A.size(0) << "," << A.size(1) << "]" << std::endl; std::cout << "\nSizes: B[" << B.size(0) << "," << B.size(1) << "]" << std::endl; @@ -376,7 +376,7 @@ hipblasStatus_t gemm_bias( hipblasOperation_t transa, hipblasOperation_t transb, #if DEBUG std::cout << "gemm_bias " << std::endl; #endif - return hipblasGemmEx(handle, transa, transb, m, n, k, alpha, A, DataType, lda, B, DataType, + return hipblasGemmEx(handle, transa, transb, m, n, k, alpha, A, DataType, lda, B, DataType, ldb, beta, C, DataType, ldc, ComputeType, CUBLAS_GEMM_DEFAULT); } @@ -394,7 +394,7 @@ at::Tensor linear_bias_forward(at::Tensor input, at::Tensor weight, at::Tensor b at::Tensor dummy_gelu = at::empty({0}, torch::device(torch::kCUDA).dtype(input.scalar_type())); - // ********************************************************************************** + // ********************************************************************************** // output[batch_size, out_features] = input[batch_size, in_features] * weight[out_features,in_features] + bias[out_features] // ********************************************************************************** auto output = at::zeros({batch_size, out_features}, torch::device(torch::kCUDA).dtype(input.scalar_type())); @@ -406,7 +406,7 @@ at::Tensor linear_bias_forward(at::Tensor input, at::Tensor weight, at::Tensor b } else { DISPATCH_TYPES(input.scalar_type(), "linear_bias_forward", [&] { auto result = gemm_bias( - HIPBLAS_OP_T, HIPBLAS_OP_N, out_features, batch_size, in_features, + HIPBLAS_OP_T, HIPBLAS_OP_N, out_features, batch_size, in_features, &alpha, &beta, weight.data_ptr(), input.data_ptr(), output.data_ptr()); if (result != 0) { fprintf(stderr, "INVALID RESULT for linear_bias_forward\n"); } }); @@ -418,7 +418,7 @@ at::Tensor linear_bias_forward(at::Tensor input, at::Tensor weight, at::Tensor b /**************************************************************************** * In the backward pass, we compute the gradients of the loss with respect to input, weight, and bias. * The key matrix operations are: - * 1. Gradient of Input : grad_input[batch_size, in_features] = output[batch_size, out_features] * weight[out_features,in_features] + * 1. Gradient of Input : grad_input[batch_size, in_features] = output[batch_size, out_features] * weight[out_features,in_features] * 2. Gradient of Weights: grad_weight[out_features,in_features] = input[batch_size, in_features] * output[batch_size, out_features] * 3. Gradient of Bias : grad_bias=sum(dY) **************************************************************************/ @@ -434,7 +434,7 @@ std::vector linear_bias_backward(at::Tensor input, at::Tensor weight auto dummy_gelu = at::empty({0}, torch::device(torch::kCUDA).dtype(input.scalar_type())); auto grad_weight = at::zeros({out_features,in_features}, torch::device(torch::kCUDA).dtype(input.scalar_type())); auto grad_input = at::zeros({batch_size, in_features}, torch::device(torch::kCUDA).dtype(input.scalar_type())); - + #if DEBUG std::cout << "linear_bias_backward " << std::endl; #endif @@ -447,7 +447,7 @@ std::vector linear_bias_backward(at::Tensor input, at::Tensor weight // ********************************************************************************** // Gradient of Weights: - // grad_weight[out_features,in_features] = input[batch_size, in_features](T) * output[batch_size, out_features] + // grad_weight[out_features,in_features] = input[batch_size, in_features](T) * output[batch_size, out_features] // ********************************************************************************** CHECK_HIPBLASLT_ERROR(gemm_lt(HIPBLAS_OP_N, HIPBLAS_OP_T, &alpha, &beta, input, output, grad_weight, grad_bias, dummy_gelu, true, false, false)); @@ -459,7 +459,7 @@ std::vector linear_bias_backward(at::Tensor input, at::Tensor weight } else { DISPATCH_TYPES(input.scalar_type(), "linear_bias_forward", [&] { auto result = gemm_bias( - HIPBLAS_OP_N, HIPBLAS_OP_T, in_features, out_features, batch_size, + HIPBLAS_OP_N, HIPBLAS_OP_T, in_features, out_features, batch_size, &alpha, &beta, input.data_ptr(), output.data_ptr(), grad_weight.data_ptr()); if (result != 0) { fprintf(stderr, "INVALID RESULT for linear_bias_forward\n"); } }); @@ -480,11 +480,11 @@ std::vector linear_bias_backward(at::Tensor input, at::Tensor weight * [GELU] https://pytorch.org/docs/stable/generated/torch.nn.GELU.html * * module combines dense layers with GELU activations in a single neural network layer. - * layer consists of two dense sub-layers, each followed by a GELU activation function. + * layer consists of two dense sub-layers, each followed by a GELU activation function. * It takes an input tensor and passes it through these sub-layers to produce the final output. * * layer consists of the following internal layers: - * dense1: The first dense layer. + * dense1: The first dense layer. * output[batch_size, hidden_features] = input[batch_size, in_features] * weight[hidden_features,in_features] + bias[hidden_features] * activation: The GELU(Gaussian Error Linear Units) activation function. * dense2: The second dense layer. @@ -493,16 +493,16 @@ std::vector linear_bias_backward(at::Tensor input, at::Tensor weight * input (torch.Tensor): (∗,Hin ) where ∗ is batch_size and Hin=in_features * weight (torch.Tensor): the learnable weights of the module of shape(out_features,in_features). * bias (torch.Tensor): the learnable bias of the module of shape(out_features) - * + * * Output: (*,Hout ) where all but the last dimension are the same shape as the input and Hout = out_features. * **************************************************************************/ -std::vector linear_gelu_linear_forward(at::Tensor input, at::Tensor weight, at::Tensor bias, +std::vector linear_gelu_linear_forward(at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor weight2, at::Tensor bias2) { const float alpha = 1.0, beta = 0.0; - int64_t batch_size = input.size(0); // input[batch_size, in_features] + int64_t batch_size = input.size(0); // input[batch_size, in_features] int64_t in_features = input.size(1); // bias[hidden_features] and bias2[out_features] int64_t hidden_features = weight.size(0); // weight[hidden_features, in_features] int64_t out_features = weight2.size(0); // weight2[out_features, hidden_features] @@ -538,16 +538,16 @@ std::vector linear_gelu_linear_forward(at::Tensor input, at::Tenso * The key matrix operations are: * For second gemm * 1. Gradient of Input (dX): grad_output[batch_size, hidden_features] = output2[batch_size,out_features] ⋅ weight2[out_features, hidden_features] - * 2. Gradient of Weights (dW): grad_weight[hidden_features, in_features] = output[batch_size, hidden_features](T) ⋅ output2[batch_size,out_features] + * 2. Gradient of Weights (dW): grad_weight[hidden_features, in_features] = output[batch_size, hidden_features](T) ⋅ output2[batch_size,out_features] * For First gemm * 1. Gradient of Input (dX): grad_input[batch_size, in_features] = output[batch_size, hidden_features] ⋅ weight[hidden_features,in_features](T) - * 2. Gradient of Weights (dW): grad_weight[hidden_features, in_features] = input[batch_size, in_features](T) ⋅ output[batch_size, hidden_features] + * 2. Gradient of Weights (dW): grad_weight[hidden_features, in_features] = input[batch_size, in_features](T) ⋅ output[batch_size, hidden_features] **************************************************************************/ std::vector linear_gelu_linear_backward(at::Tensor input, at::Tensor gelu, at::Tensor output, at::Tensor weight, at::Tensor weight2, at::Tensor output2) { const float alpha = 1.0, beta = 0.0; - + int64_t batch_size = input.size(0); int64_t in_features = input.size(1); int64_t hidden_features = weight.size(0); @@ -573,7 +573,7 @@ std::vector linear_gelu_linear_backward(at::Tensor input, at::Tensor // ********************************************************************************** // Gradient For second gemm : // grad_output[batch_size, hidden_features] = output2[batch_size,out_features] ⋅ weight2[out_features, hidden_features] - // grad_weight[out_features,in_features] = input[batch_size, in_features](T) * output[batch_size, out_features] + // grad_weight[out_features,in_features] = input[batch_size, in_features](T) * output[batch_size, out_features] // ********************************************************************************** CHECK_HIPBLASLT_ERROR(gemm_lt(HIPBLAS_OP_N, HIPBLAS_OP_N, &alpha, &beta, weight2, output2, grad_output, grad_bias2, dummy_gelu, false, false, false)); CHECK_HIPBLASLT_ERROR(gemm_lt(HIPBLAS_OP_N, HIPBLAS_OP_T, &alpha, &beta, output2, output, grad_weight2, grad_bias2, dummy_gelu, true, false, false)); @@ -582,7 +582,7 @@ std::vector linear_gelu_linear_backward(at::Tensor input, at::Tensor // ********************************************************************************** // Gradient For First gemm : // grad_input [batch_size, in_features] = output[batch_size, out_features] * Weight[out_features,in_features] - // grad_weight[out_features,in_features] = input[batch_size, in_features](T) * output[batch_size, out_features] + // grad_weight[out_features,in_features] = input[batch_size, in_features](T) * output[batch_size, out_features] // ********************************************************************************** CHECK_HIPBLASLT_ERROR(gemm_lt(HIPBLAS_OP_N, HIPBLAS_OP_N, &alpha, &beta, weight, output, grad_input, grad_bias2, dummy_gelu, false, false, false)); CHECK_HIPBLASLT_ERROR(gemm_lt(HIPBLAS_OP_N, HIPBLAS_OP_T, &alpha, &beta, output, input, grad_weight, grad_bias2, dummy_gelu, true, false, false)); diff --git a/csrc/layer_norm_cuda_kernel.cu b/csrc/layer_norm_cuda_kernel.cu index 706ec8162..efd0d6793 100644 --- a/csrc/layer_norm_cuda_kernel.cu +++ b/csrc/layer_norm_cuda_kernel.cu @@ -74,7 +74,7 @@ void cuWelfordMuSigma2( const int i1, U& mu, U& sigma2, - U* buf, + U* buf, const int GPU_WARP_SIZE, bool rms_only) { @@ -107,7 +107,7 @@ void cuWelfordMuSigma2( } } - + for (; l < n2; ++l) { U curr = static_cast(lvals[l]); @@ -121,11 +121,11 @@ void cuWelfordMuSigma2( // intra-warp reductions if(USE_ROCM){ #pragma unroll - for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) { + for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) { U sigma2B = WARP_SHFL_DOWN(sigma2, stride); if (!rms_only) { U muB = WARP_SHFL_DOWN(mu, stride); - U countB = WARP_SHFL_DOWN(count, stride); + U countB = WARP_SHFL_DOWN(count, stride); cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); } else { cuChanRMSOnlineSum(sigma2B, sigma2); @@ -985,14 +985,14 @@ void HostApplyLayerNorm( // Optimization for ROCm MI100 threads.y = 1; #endif - + const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); int nshared = threads.y > 1 ? threads.y*sizeof(U)+(threads.y/2)*sizeof(U) : 0; - + cuApplyLayerNorm<<>>( output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta, warp_size); } @@ -1220,7 +1220,7 @@ void HostRMSNormGradient( epsilon, true); }); - + const dim3 threads3(32,8,1); const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1); diff --git a/csrc/megatron/fused_bias_swiglu_cuda.cu b/csrc/megatron/fused_bias_swiglu_cuda.cu index 6f5e54961..bee7679f9 100644 --- a/csrc/megatron/fused_bias_swiglu_cuda.cu +++ b/csrc/megatron/fused_bias_swiglu_cuda.cu @@ -8,10 +8,10 @@ __device__ __forceinline__ float silu(float x) { // CUDA kernel for Fused Bias SwiGLU with chunking template -__global__ void fused_bias_swiglu_kernel(const T* __restrict__ input, - const T* __restrict__ bias, - T* __restrict__ output, - int half_dim, +__global__ void fused_bias_swiglu_kernel(const T* __restrict__ input, + const T* __restrict__ bias, + T* __restrict__ output, + int half_dim, int max_index) { int output_idx = blockIdx.x * blockDim.x + threadIdx.x; int row_idx = output_idx / half_dim; @@ -31,10 +31,10 @@ __global__ void fused_bias_swiglu_kernel(const T* __restrict__ input, // CUDA Kernel: Computes the backward pass for fused bias SwiGLU template __global__ void fused_bias_swiglu_backward_kernel( - const T* __restrict__ grad_output, - const T* __restrict__ input, - const T* __restrict__ bias, - T* __restrict__ grad_input, + const T* __restrict__ grad_output, + const T* __restrict__ input, + const T* __restrict__ bias, + T* __restrict__ grad_input, int half_dim, int max_index) { int output_idx = blockIdx.x * blockDim.x + threadIdx.x; @@ -80,7 +80,7 @@ torch::Tensor fused_bias_swiglu_forward(torch::Tensor input, torch::Tensor bias) int threads = prop.maxThreadsPerBlock; int blocks = (batch_size * half_dim + threads - 1) / threads; blocks = min(blocks, prop.maxGridSize[0]); - + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "fused_bias_swiglu_forward", [&] { fused_bias_swiglu_kernel<<>>( diff --git a/csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu b/csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu index 24e5f0294..a2c7bf62a 100644 --- a/csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu +++ b/csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu @@ -54,7 +54,7 @@ void gemmex_wrapper_fp16( at::BFloat16* D, void* d_workspace, int64_t max_workspace_size, - cudaStream_t stream) + cudaStream_t stream) { cublasLtMatrixLayout_t matA, matB, matC, matD; CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matA, CUDA_R_16BF, m, k, m)); @@ -148,7 +148,7 @@ void gemmex_wrapper_fp16( at::Half* D, void* d_workspace, int64_t max_workspace_size, - cudaStream_t stream) + cudaStream_t stream) { cublasLtMatrixLayout_t matA, matB, matC, matD; CHECK_CUBLASLT_ERROR(cublasLtMatrixLayoutCreate(&matA, CUDA_R_16F, m, k, m)); @@ -249,14 +249,14 @@ void wgrad_gemm_accum_fp16_cuda(T *input, T *d_output, T *d_weight,int in_dim, i batch_count, alpha, beta, - input, //da + input, //da d_output, //db d_weight, //dc d_weight, //dd d_workspace, max_workspace_size, stream); -} +} template void wgrad_gemm_accum_fp16_cuda(at::Half *input, at::Half *d_output, at::Half *d_weight, int in_dim, int hidden_dim, int out_dim); template void wgrad_gemm_accum_fp16_cuda(at::BFloat16 *input, at::BFloat16 *d_output, at::BFloat16 *d_weight, int in_dim, int hidden_dim, int out_dim); diff --git a/csrc/megatron/fused_weight_gradient_dense_cuda.cu b/csrc/megatron/fused_weight_gradient_dense_cuda.cu index f2f762eb5..220064124 100644 --- a/csrc/megatron/fused_weight_gradient_dense_cuda.cu +++ b/csrc/megatron/fused_weight_gradient_dense_cuda.cu @@ -361,7 +361,7 @@ template void wgrad_gemm_accum_fp32_cuda(float *input, float *d_output, f void wgrad_gemm_accum_fp32_cuda_stub( at::Tensor &input, at::Tensor &d_output, - at::Tensor &d_weight) + at::Tensor &d_weight) { at::Tensor input_2d, d_output_2d; // input tensor: collapse to the first dim diff --git a/csrc/megatron/generic_scaled_masked_softmax.h b/csrc/megatron/generic_scaled_masked_softmax.h index 79fbc561d..aa52b481b 100644 --- a/csrc/megatron/generic_scaled_masked_softmax.h +++ b/csrc/megatron/generic_scaled_masked_softmax.h @@ -66,12 +66,12 @@ __device__ __forceinline__ acc_t warp_reduce_new(acc_t val) { template __global__ void scaled_masked_softmax_warp_backward_new( output_t *gradInput, //[batches, attn_heads, q_len, k_len] - input_t *grad, + input_t *grad, const input_t *output, //[batches, attn_heads, q_len, k_len] - acc_t scale, + acc_t scale, int element_count) { - int threads_per_block = blockDim.x; + int threads_per_block = blockDim.x; //the first element_count*2 elements are used for cache, the last 128 is used for reduction extern __shared__ acc_t shared_data[]; input_t *local_data = (input_t *)shared_data; @@ -86,7 +86,7 @@ __global__ void scaled_masked_softmax_warp_backward_new( int local_idx = threadIdx.x; int lane = threadIdx.x % C10_WARP_SIZE; int wid = threadIdx.x / C10_WARP_SIZE; - int warps_per_thread_block = threads_per_block / C10_WARP_SIZE; + int warps_per_thread_block = threads_per_block / C10_WARP_SIZE; // load the data to local data acc_t val = 0.0; @@ -99,7 +99,7 @@ __global__ void scaled_masked_softmax_warp_backward_new( __syncthreads(); } - // find the sum + // find the sum for (int i = local_idx; i < (element_count - 1) / C10_WARP_SIZE + 1; i += threads_per_block){ shared[i] = 0.0; } @@ -147,7 +147,7 @@ __global__ void scaled_masked_softmax_warp_backward_new( val = shared[0]; #pragma unroll for (int i = local_idx; i < element_count; i += threads_per_block){ - gradInput[offset + i] = (output_t)(scale*(local_data[i] - output_data[i]*val)); + gradInput[offset + i] = (output_t)(scale*(local_data[i] - output_data[i]*val)); } } @@ -155,12 +155,12 @@ __global__ void scaled_masked_softmax_warp_backward_new( template void dispatch_scaled_masked_softmax_backward_new( - output_t *grad_input, - input_t *grad, - const input_t *output, - const acc_t scale, - int query_seq_len, - int key_seq_len, + output_t *grad_input, + input_t *grad, + const input_t *output, + const acc_t scale, + int query_seq_len, + int key_seq_len, int batches, int attn_heads) { @@ -186,33 +186,33 @@ void dispatch_scaled_masked_softmax_backward_new( * Extended softmax (from native aten pytorch) with following additional features * 1) input scaling * 2) Explicit masking - */ + */ template __global__ void scaled_masked_softmax_warp_forward_new( - output_t *dst, + output_t *dst, const input_t *src, - const uint8_t *mask, - const acc_t scale, + const uint8_t *mask, + const acc_t scale, int query_len, // query_len int attn_heads, int element_count, // key_len - int pad_batches) // mask batch size + int pad_batches) // mask batch size { // min threawds_per_block has to be bigger than 128 - int threads_per_block = blockDim.x; + int threads_per_block = blockDim.x; // the first element_count is used for cache, the last 128 is used for reduction extern __shared__ acc_t local_data[]; // maximum shared cached 128, enough for 4096 elements reduction into 4096/32= 128 elements acc_t *shared = &(local_data[element_count]); - // number of 1024 threads reductions + // number of 1024 threads reductions int num_reductions = (element_count - 1) / threads_per_block + 1; int offset = blockIdx.x * element_count; int mask_offset; - int query_id = blockIdx.x % query_len; + int query_id = blockIdx.x % query_len; if (pad_batches == 1){ - // broadcaste the mask tensor - mask_offset = query_id * element_count; + // broadcaste the mask tensor + mask_offset = query_id * element_count; } else{ int mask_batch_id = blockIdx.x / attn_heads / query_len; @@ -222,7 +222,7 @@ __global__ void scaled_masked_softmax_warp_forward_new( int local_idx = threadIdx.x; int lane = threadIdx.x % C10_WARP_SIZE; int wid = threadIdx.x / C10_WARP_SIZE; - int warps_per_thread_block = threads_per_block / C10_WARP_SIZE; + int warps_per_thread_block = threads_per_block / C10_WARP_SIZE; // load the data to local data for (int i = local_idx; i < element_count; i += threads_per_block) @@ -300,7 +300,7 @@ __global__ void scaled_masked_softmax_warp_forward_new( local_data[i] = std::exp(local_data[i] - reduced_val); } - // find the sum + // find the sum for (int i = local_idx; i < (element_count - 1) / C10_WARP_SIZE + 1; i += threads_per_block){ shared[i] = 0.0; } @@ -356,12 +356,12 @@ __global__ void scaled_masked_softmax_warp_forward_new( template void dispatch_scaled_masked_softmax_forward_new( - output_t *dst, - const input_t *src, + output_t *dst, + const input_t *src, const uint8_t *mask, - const input_t scale, - int query_seq_len, - int key_seq_len, + const input_t scale, + int query_seq_len, + int key_seq_len, int batches, int attn_heads, int pad_batches) diff --git a/csrc/megatron/generic_scaled_masked_softmax_cuda.cu b/csrc/megatron/generic_scaled_masked_softmax_cuda.cu index 93cd94b30..68f55afdd 100644 --- a/csrc/megatron/generic_scaled_masked_softmax_cuda.cu +++ b/csrc/megatron/generic_scaled_masked_softmax_cuda.cu @@ -44,9 +44,9 @@ torch::Tensor fwd_cuda( TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); - // Output + // Output auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = + torch::Tensor softmax_results = torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); // Softmax Intermediate Result Ptr @@ -72,8 +72,8 @@ torch::Tensor fwd_cuda( } torch::Tensor bwd_cuda( - torch::Tensor const& output_grads_, - torch::Tensor const& softmax_results_, + torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, float scale_factor) { auto output_grads = output_grads_.contiguous(); @@ -86,7 +86,7 @@ torch::Tensor bwd_cuda( const int key_seq_len = output_grads.size(3); auto act_options = output_grads.options(); - torch::Tensor input_grad = + torch::Tensor input_grad = torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); void* output_grads_ptr = static_cast(output_grads.data_ptr()); @@ -96,8 +96,8 @@ torch::Tensor bwd_cuda( output_grads_.scalar_type(), "dispatch_scaled_masked_softmax_backward", dispatch_scaled_masked_softmax_backward_new( - reinterpret_cast(static_cast(input_grad.data_ptr())), - reinterpret_cast(output_grads_ptr), + reinterpret_cast(static_cast(input_grad.data_ptr())), + reinterpret_cast(output_grads_ptr), reinterpret_cast(softmax_results.data_ptr()), scale_factor, query_seq_len, @@ -105,7 +105,7 @@ torch::Tensor bwd_cuda( batches, attn_heads); ); - + //backward pass is completely in-place return input_grad; } diff --git a/csrc/megatron/scaled_masked_softmax.h b/csrc/megatron/scaled_masked_softmax.h index 2674e1f54..1fba0d1b6 100644 --- a/csrc/megatron/scaled_masked_softmax.h +++ b/csrc/megatron/scaled_masked_softmax.h @@ -207,18 +207,18 @@ __global__ void scaled_softmax_warp_forward( * Extended softmax (from native aten pytorch) with following additional features * 1) input scaling * 2) Explicit masking - */ + */ template __global__ void scaled_masked_softmax_warp_forward( - output_t *dst, + output_t *dst, const input_t *src, - const uint8_t *mask, - const acc_t scale, - int micro_batch_size, + const uint8_t *mask, + const acc_t scale, + int micro_batch_size, int element_count, - int pad_batches) + int pad_batches) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and // warp_size of method warp_softmax_forward_kernel. constexpr int next_power_of_two = 1 << log2_elements; constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; @@ -227,7 +227,7 @@ __global__ void scaled_masked_softmax_warp_forward( constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) + // gridDim/blockIdx = (seq_len, attn_heads, batches) int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH; int pad_first_batch = 0; if (pad_batches != 1) { // bert style @@ -246,8 +246,8 @@ __global__ void scaled_masked_softmax_warp_forward( int local_idx = threadIdx.x; long int thread_offset_src_dst = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - long int thread_offset_mask = pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - src += thread_offset_src_dst; + long int thread_offset_mask = pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + src += thread_offset_src_dst; dst += thread_offset_src_dst; mask += thread_offset_mask; @@ -329,24 +329,24 @@ __global__ void scaled_masked_softmax_warp_forward( for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { out[element] = elements[i][it + element] / sum[i]; } - copy_vector(dst + i * element_count + it * WARP_SIZE, out); + copy_vector(dst + i * element_count + it * WARP_SIZE, out); } else { break; - } + } } } } template __global__ void scaled_masked_softmax_warp_backward( - output_t *gradInput, - input_t *grad, + output_t *gradInput, + input_t *grad, const input_t *output, - acc_t scale, - int micro_batch_size, + acc_t scale, + int micro_batch_size, int element_count) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and // warp_size of method warp_softmax_backward_kernel. constexpr int next_power_of_two = 1 << log2_elements; constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; @@ -355,9 +355,9 @@ __global__ void scaled_masked_softmax_warp_backward( constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) + // gridDim/blockIdx = (seq_len, attn_heads, batches) int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - + // micro_batch_size might not be a multiple of WARP_BATCH. Check how // many batches have to computed within this WARP. int local_batches = micro_batch_size - first_batch; @@ -397,10 +397,10 @@ __global__ void scaled_masked_softmax_warp_backward( for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; } - } + } } } - + acc_t sum[WARP_BATCH]; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { @@ -428,7 +428,7 @@ __global__ void scaled_masked_softmax_warp_backward( out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); } copy_vector(gradInput + i * element_count + it * WARP_SIZE, out); - } + } } } } @@ -551,12 +551,12 @@ void dispatch_scaled_softmax_forward( template void dispatch_scaled_masked_softmax_forward( - output_t *dst, - const input_t *src, + output_t *dst, + const input_t *src, const uint8_t *mask, - const input_t scale, - int query_seq_len, - int key_seq_len, + const input_t scale, + int query_seq_len, + int key_seq_len, int batches, int attn_heads, int pad_batches) @@ -645,12 +645,12 @@ void dispatch_scaled_masked_softmax_forward( template void dispatch_scaled_masked_softmax_backward( - output_t *grad_input, - input_t *grad, - const input_t *output, - const acc_t scale, - int query_seq_len, - int key_seq_len, + output_t *grad_input, + input_t *grad, + const input_t *output, + const acc_t scale, + int query_seq_len, + int key_seq_len, int batches, int attn_heads) { diff --git a/csrc/megatron/scaled_masked_softmax_cpu.cpp b/csrc/megatron/scaled_masked_softmax_cpu.cpp index dd471a0bb..08e5b46f8 100644 --- a/csrc/megatron/scaled_masked_softmax_cpu.cpp +++ b/csrc/megatron/scaled_masked_softmax_cpu.cpp @@ -22,12 +22,12 @@ namespace fused_softmax { namespace scaled_masked_softmax { torch::Tensor fwd_cuda( - torch::Tensor const& input, + torch::Tensor const& input, torch::Tensor const& mask, float scale_factor); torch::Tensor bwd_cuda( - torch::Tensor const& output_grads, + torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor); @@ -43,7 +43,7 @@ torch::Tensor fwd( float scale_factor) { AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), + (input.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); @@ -51,7 +51,7 @@ torch::Tensor fwd( } torch::Tensor bwd( - torch::Tensor const& output_grads, + torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor) { @@ -59,10 +59,10 @@ torch::Tensor bwd( AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), + (output_grads.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), + (softmax_results.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); return bwd_cuda(output_grads, softmax_results, scale_factor); @@ -81,8 +81,8 @@ int get_batch_per_block( } // end namespace multihead_attn PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", - &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, + m.def("forward", + &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, "Self Multihead Attention scaled, time masked softmax -- Forward."); m.def("backward", diff --git a/csrc/megatron/scaled_masked_softmax_cuda.cu b/csrc/megatron/scaled_masked_softmax_cuda.cu index 053d071ed..31d0f2759 100644 --- a/csrc/megatron/scaled_masked_softmax_cuda.cu +++ b/csrc/megatron/scaled_masked_softmax_cuda.cu @@ -51,9 +51,9 @@ torch::Tensor fwd_cuda( TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); - // Output + // Output auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = + torch::Tensor softmax_results = torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); // Softmax Intermediate Result Ptr @@ -79,10 +79,10 @@ torch::Tensor fwd_cuda( } torch::Tensor bwd_cuda( - torch::Tensor const& output_grads_, - torch::Tensor const& softmax_results_, + torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, float scale_factor) { - + auto output_grads = output_grads_.contiguous(); auto softmax_results = softmax_results_.contiguous(); @@ -99,8 +99,8 @@ torch::Tensor bwd_cuda( output_grads_.scalar_type(), "dispatch_scaled_masked_softmax_backward", dispatch_scaled_masked_softmax_backward( - reinterpret_cast(output_grads_ptr), - reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), reinterpret_cast(softmax_results.data_ptr()), scale_factor, query_seq_len, @@ -108,7 +108,7 @@ torch::Tensor bwd_cuda( batches, attn_heads); ); - + //backward pass is completely in-place return output_grads; } diff --git a/csrc/megatron/scaled_softmax_cpu.cpp b/csrc/megatron/scaled_softmax_cpu.cpp index c8f6d28cc..89fc476b8 100644 --- a/csrc/megatron/scaled_softmax_cpu.cpp +++ b/csrc/megatron/scaled_softmax_cpu.cpp @@ -23,11 +23,11 @@ namespace fused_softmax { namespace scaled_softmax { torch::Tensor fwd_cuda( - torch::Tensor const& input, + torch::Tensor const& input, float scale_factor); torch::Tensor bwd_cuda( - torch::Tensor const& output_grads, + torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor); @@ -36,14 +36,14 @@ torch::Tensor fwd( float scale_factor) { TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); TORCH_CHECK((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), + (input.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); return fwd_cuda(input, scale_factor); } torch::Tensor bwd( - torch::Tensor const& output_grads, + torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor) { @@ -51,10 +51,10 @@ torch::Tensor bwd( TORCH_CHECK(softmax_results.dim() == 4, "expected 3D tensor"); TORCH_CHECK((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), + (output_grads.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); TORCH_CHECK((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), + (softmax_results.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); return bwd_cuda(output_grads, softmax_results, scale_factor); @@ -65,10 +65,10 @@ torch::Tensor bwd( } // end namespace multihead_attn PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", - &multihead_attn::fused_softmax::scaled_softmax::fwd, + m.def("forward", + &multihead_attn::fused_softmax::scaled_softmax::fwd, "Self Multihead Attention scaled, softmax -- Forward.", py::call_guard()); - m.def("backward", + m.def("backward", &multihead_attn::fused_softmax::scaled_softmax::bwd, "Self Multihead Attention scaled, softmax -- Backward.", py::call_guard()); } diff --git a/csrc/megatron/scaled_softmax_cuda.cu b/csrc/megatron/scaled_softmax_cuda.cu index 1bcaff36b..52d54a396 100644 --- a/csrc/megatron/scaled_softmax_cuda.cu +++ b/csrc/megatron/scaled_softmax_cuda.cu @@ -40,9 +40,9 @@ torch::Tensor fwd_cuda( TORCH_INTERNAL_ASSERT(key_seq_len <= 16384); TORCH_INTERNAL_ASSERT(query_seq_len > 1); - // Output + // Output auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = + torch::Tensor softmax_results = torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); // Softmax Intermediate Result Ptr @@ -65,10 +65,10 @@ torch::Tensor fwd_cuda( } torch::Tensor bwd_cuda( - torch::Tensor const& output_grads_, - torch::Tensor const& softmax_results_, + torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, float scale_factor) { - + auto output_grads = output_grads_.contiguous(); auto softmax_results = softmax_results_.contiguous(); @@ -85,8 +85,8 @@ torch::Tensor bwd_cuda( output_grads_.scalar_type(), "dispatch_scaled_masked_softmax_backward", dispatch_scaled_masked_softmax_backward( - reinterpret_cast(output_grads_ptr), - reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), reinterpret_cast(softmax_results.data_ptr()), scale_factor, query_seq_len, @@ -94,7 +94,7 @@ torch::Tensor bwd_cuda( batches, attn_heads); ); - + //backward pass is completely in-place return output_grads; } diff --git a/csrc/megatron/scaled_upper_triang_masked_softmax.h b/csrc/megatron/scaled_upper_triang_masked_softmax.h index 562350af2..096d61b0e 100644 --- a/csrc/megatron/scaled_upper_triang_masked_softmax.h +++ b/csrc/megatron/scaled_upper_triang_masked_softmax.h @@ -34,7 +34,7 @@ __device__ __inline__ void copy_vector(c10::BFloat16 *dst, con template <> __device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } - + template <> __device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } @@ -113,14 +113,14 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) { */ template __global__ void scaled_upper_triang_masked_softmax_warp_forward( - output_t *dst, - const input_t *src, - const acc_t scale, - int micro_batch_size, - int stride, - int element_count) + output_t *dst, + const input_t *src, + const acc_t scale, + int micro_batch_size, + int stride, + int element_count) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and // warp_size of method warp_softmax_forward_kernel. constexpr int next_power_of_two = 1 << log2_elements; constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; @@ -129,7 +129,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward( constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; - int local_seq = blockIdx.x + 1; + int local_seq = blockIdx.x + 1; int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE; // micro_batch_size might not be a multiple of WARP_BATCH. Check how @@ -195,7 +195,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward( if (it < warp_iteration_limit) { elements[i][it] = std::exp((elements[i][it] - max_value[i])); sum[i] += elements[i][it]; - } + } } } warp_reduce(sum); @@ -212,7 +212,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward( if (element_index < local_seq) { - #pragma unroll + #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { if (element_index + element < local_seq) { out[element] = elements[i][it + element] / sum[i]; @@ -225,22 +225,22 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward( copy_zero_vector(dst + i * element_count * stride + it * WARP_SIZE); } else { break; - } + } } } } template __global__ void scaled_upper_triang_masked_softmax_warp_backward( - output_t *gradInput, - input_t *grad, + output_t *gradInput, + input_t *grad, const input_t *output, - acc_t scale, - int micro_batch_size, - int stride, + acc_t scale, + int micro_batch_size, + int stride, int element_count) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and // warp_size of method warp_softmax_backward_kernel. constexpr int next_power_of_two = 1 << log2_elements; constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; @@ -249,8 +249,8 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward( constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; - int local_seq = blockIdx.x + 1; - + int local_seq = blockIdx.x + 1; + // micro_batch_size might not be a multiple of WARP_BATCH. Check how // many batches have to computed within this WARP. int local_batches = micro_batch_size - first_batch; @@ -297,7 +297,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward( } } } - + acc_t sum[WARP_BATCH]; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { @@ -325,7 +325,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward( out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); } copy_vector(gradInput + i * element_count * stride + it * WARP_SIZE, out); - } + } } } } @@ -334,11 +334,11 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward( template void dispatch_scaled_upper_triang_masked_softmax_forward( - output_t *dst, - const input_t *src, - const input_t scale, - int softmax_elements, - int softmax_elements_stride, + output_t *dst, + const input_t *src, + const input_t scale, + int softmax_elements, + int softmax_elements_stride, int attn_batches) { TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 16384 ); @@ -436,12 +436,12 @@ void dispatch_scaled_upper_triang_masked_softmax_forward( template void dispatch_scaled_upper_triang_masked_softmax_backward( - output_t *grad_input, - input_t *grad, - const input_t *output, - const acc_t scale, - int softmax_elements, - int softmax_elements_stride, + output_t *grad_input, + input_t *grad, + const input_t *output, + const acc_t scale, + int softmax_elements, + int softmax_elements_stride, int attn_batches) { TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 16384 ); diff --git a/csrc/megatron/scaled_upper_triang_masked_softmax_cpu.cpp b/csrc/megatron/scaled_upper_triang_masked_softmax_cpu.cpp index 12cec7f67..ddf870987 100644 --- a/csrc/megatron/scaled_upper_triang_masked_softmax_cpu.cpp +++ b/csrc/megatron/scaled_upper_triang_masked_softmax_cpu.cpp @@ -22,25 +22,25 @@ namespace fused_softmax { namespace scaled_upper_triang_masked_softmax { torch::Tensor fwd_cuda( - torch::Tensor const& input, + torch::Tensor const& input, float scale_factor); torch::Tensor bwd_cuda( - torch::Tensor const& output_grads, + torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor); torch::Tensor fwd(torch::Tensor const& input, float scale_factor) { AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), + (input.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); return fwd_cuda(input, scale_factor); } torch::Tensor bwd( - torch::Tensor const& output_grads, + torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor) { @@ -48,10 +48,10 @@ torch::Tensor bwd( AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), + (output_grads.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), + (softmax_results.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); return bwd_cuda(output_grads, softmax_results, scale_factor); @@ -62,10 +62,10 @@ torch::Tensor bwd( } // end namespace multihead_attn PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", + m.def("forward", &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd, "Self Multihead Attention scaled, time masked softmax -- Forward."); - m.def("backward", + m.def("backward", &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd, "Self Multihead Attention scaled, time masked softmax -- Backward."); } diff --git a/csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu b/csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu index 7cec7f8e3..3a0ce7cc3 100644 --- a/csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu +++ b/csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu @@ -29,7 +29,7 @@ namespace fused_softmax { namespace scaled_upper_triang_masked_softmax { torch::Tensor fwd_cuda( - torch::Tensor const& input, + torch::Tensor const& input, float scale_factor) { // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] @@ -37,9 +37,9 @@ torch::Tensor fwd_cuda( const int seq_len = input.size(1); TORCH_INTERNAL_ASSERT(seq_len <= 16384); - // Output + // Output auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = + torch::Tensor softmax_results = torch::empty({attn_batches, seq_len, seq_len}, act_options); // Softmax Intermediate Result Ptr @@ -59,13 +59,13 @@ torch::Tensor fwd_cuda( ); return softmax_results; } - + torch::Tensor bwd_cuda( - torch::Tensor const& output_grads_, - torch::Tensor const& softmax_results_, + torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, float scale_factor) { - + auto output_grads = output_grads_.contiguous(); auto softmax_results = softmax_results_.contiguous(); @@ -81,15 +81,15 @@ torch::Tensor bwd_cuda( output_grads_.scalar_type(), "dispatch_scaled_upper_triang_masked_softmax_backward", dispatch_scaled_upper_triang_masked_softmax_backward( - reinterpret_cast(output_grads_ptr), - reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), reinterpret_cast(softmax_results.data_ptr()), scale_factor, seq_len, seq_len, attn_batches); ); - + //backward pass is completely in-place return output_grads; } diff --git a/csrc/mlp_cuda.cu b/csrc/mlp_cuda.cu index 1b67ad739..46cfb17af 100644 --- a/csrc/mlp_cuda.cu +++ b/csrc/mlp_cuda.cu @@ -1432,7 +1432,7 @@ int mlp_bp( #endif #endif #endif - + int* y_offsets = (int*)malloc(num_layers * sizeof(int)); get_y_offsets(batch_size, num_layers, output_features, y_offsets); diff --git a/csrc/multi_tensor_l2norm_kernel.cu b/csrc/multi_tensor_l2norm_kernel.cu index 66189112f..36edae335 100644 --- a/csrc/multi_tensor_l2norm_kernel.cu +++ b/csrc/multi_tensor_l2norm_kernel.cu @@ -436,9 +436,9 @@ void multi_tensor_norm_out_cuda( // logic, but keeping it simple for now auto ret = at::empty({1}, output.options()); - // Adding the following device guard since it happens sometimes that the - // tensors are on one device and the cuda stream is on another device which - // results in ILLEGAL MEM ACCESS error. + // Adding the following device guard since it happens sometimes that the + // tensors are on one device and the cuda stream is on another device which + // results in ILLEGAL MEM ACCESS error. const at::cuda::OptionalCUDAGuard device_guard(device_of(output)); auto stream = at::cuda::getCurrentCUDAStream(); cleanup_v2<<>>( diff --git a/csrc/multi_tensor_lars.cu b/csrc/multi_tensor_lars.cu index bc9bbee2f..da9363da3 100644 --- a/csrc/multi_tensor_lars.cu +++ b/csrc/multi_tensor_lars.cu @@ -49,10 +49,10 @@ struct LARSFunctor bool wd_after_momentum, float scale, const bool is_skipped) { - + // Early exit if we don't need to do anything if (*noop_gmem) return; - + int tensor_loc = tl.block_to_tensor[blockIdx.x]; int chunk_idx = tl.block_to_chunk[blockIdx.x]; int n = tl.sizes[tensor_loc]; @@ -137,7 +137,7 @@ struct LARSFunctor } weight_in[i] = static_cast(incoming_weights[ii]); - + // if necessary, write out an fp16 copy of the weights if(N == 4) model_weights_out[i] = static_cast(weight_in[i]); diff --git a/csrc/syncbn.cpp b/csrc/syncbn.cpp index 578a6e653..d98cfd074 100644 --- a/csrc/syncbn.cpp +++ b/csrc/syncbn.cpp @@ -4,12 +4,12 @@ #include // returns {mean,biased_var} -// implemented using welford +// implemented using welford std::vector welford_mean_var_CUDA(const at::Tensor input); // reduces array of mean/var across processes // returns global {mean,inv_std,biased_var} -// implemented using welford +// implemented using welford std::vector welford_parallel_CUDA(const at::Tensor mean_feature_nodes, const at::Tensor var_biased_feature_nodes, const at::Tensor numel, @@ -47,7 +47,7 @@ at::Tensor batchnorm_backward_CUDA(const at::Tensor grad_output, const at::Tensor count); // returns {mean, biased_var} -// implemented using welford +// implemented using welford // expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL std::vector welford_mean_var_c_last_CUDA(const at::Tensor input); diff --git a/csrc/welford.cu b/csrc/welford.cu index fabee1999..5a07e71f1 100644 --- a/csrc/welford.cu +++ b/csrc/welford.cu @@ -1476,7 +1476,7 @@ at::Tensor batchnorm_backward_c_last_CUDA( stride); ); } - + return grad_input; } diff --git a/docs/source/amp.rst b/docs/source/amp.rst index 4bc140518..ea9de7248 100644 --- a/docs/source/amp.rst +++ b/docs/source/amp.rst @@ -187,10 +187,10 @@ In order to get bitwise accuracy, we recommend the following workflow:: # Initialization opt_level = 'O1' model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level) - + # Train your model ... - + # Save checkpoint checkpoint = { 'model': model.state_dict(), @@ -199,17 +199,17 @@ In order to get bitwise accuracy, we recommend the following workflow:: } torch.save(checkpoint, 'amp_checkpoint.pt') ... - + # Restore model = ... optimizer = ... checkpoint = torch.load('amp_checkpoint.pt') - + model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level) model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) amp.load_state_dict(checkpoint['amp']) - + # Continue training ... diff --git a/docs/source/fp16_utils.rst b/docs/source/fp16_utils.rst index b6b3da5f8..c8c3232ce 100644 --- a/docs/source/fp16_utils.rst +++ b/docs/source/fp16_utils.rst @@ -4,13 +4,13 @@ apex.fp16_utils =================================== -This submodule contains utilities designed to streamline the mixed precision training recipe -presented by NVIDIA `on Parallel Forall`_ and in GTC 2018 Sessions -`Training Neural Networks with Mixed Precision: Theory and Practice`_ and +This submodule contains utilities designed to streamline the mixed precision training recipe +presented by NVIDIA `on Parallel Forall`_ and in GTC 2018 Sessions +`Training Neural Networks with Mixed Precision: Theory and Practice`_ and `Training Neural Networks with Mixed Precision: Real Examples`_. For Pytorch users, Real Examples in particular is recommended. -Full runnable Python scripts demonstrating ``apex.fp16_utils`` +Full runnable Python scripts demonstrating ``apex.fp16_utils`` can be found on the Github page: | `Simple FP16_Optimizer demos`_ diff --git a/docs/source/index.rst b/docs/source/index.rst index c7efc1681..fc054c83e 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -45,7 +45,7 @@ Some other useful material, including GTC 2019 and Pytorch DevCon 2019 Slides, c fp16_util .. RNN - + Indices and tables ================== diff --git a/examples/dcgan/main_amp.py b/examples/dcgan/main_amp.py index be1a2894f..bf6f03886 100644 --- a/examples/dcgan/main_amp.py +++ b/examples/dcgan/main_amp.py @@ -12,7 +12,7 @@ import torchvision.transforms as transforms import torchvision.utils as vutils -try: +try: from apex import amp except ImportError: raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") diff --git a/examples/imagenet/README.md b/examples/imagenet/README.md index 257d4a78d..698c58f1f 100644 --- a/examples/imagenet/README.md +++ b/examples/imagenet/README.md @@ -158,7 +158,7 @@ each process prior to initializing your model or any other tensors. More information can be found in the docs for the Pytorch multiprocess launcher module [torch.distributed.launch](https://pytorch.org/docs/stable/distributed.html#launch-utility). -`main_amp.py` is written to interact with +`main_amp.py` is written to interact with [torch.distributed.launch](https://pytorch.org/docs/master/distributed.html#launch-utility), which spawns multiprocess jobs using the following syntax: ``` diff --git a/examples/simple/distributed/README.md b/examples/simple/distributed/README.md index 0d939cbbf..1939ef5d0 100644 --- a/examples/simple/distributed/README.md +++ b/examples/simple/distributed/README.md @@ -3,7 +3,7 @@ [torch.nn.parallel.DistributedDataParallel](https://pytorch.org/docs/stable/nn.html#distributeddataparallel) and the Pytorch multiprocess launcher script, [torch.distributed.launch](https://pytorch.org/docs/master/distributed.html#launch-utility). -The use of `Amp` with DistributedDataParallel does not need to change from ordinary +The use of `Amp` with DistributedDataParallel does not need to change from ordinary single-process use. The only gotcha is that wrapping your model with `DistributedDataParallel` must come after the call to `amp.initialize`. Test via ```bash diff --git a/op_builder/__init__.py b/op_builder/__init__.py index 726ec6f4d..a6b4bc8a7 100644 --- a/op_builder/__init__.py +++ b/op_builder/__init__.py @@ -45,7 +45,7 @@ def _builder(): return builder # this is for the import statement such as 'from apex.op_builder import FusedLayerNormBuilder' to work -# reflect builder names and add builder closure, such as 'apex.op_builder.FusedLayerNormBuilder()' creates op builder +# reflect builder names and add builder closure, such as 'apex.op_builder.FusedLayerNormBuilder()' creates op builder for _, module_name, _ in pkgutil.iter_modules([os.path.dirname(this_module.__file__)]): if module_name != 'all_ops' and module_name != 'builder': module = importlib.import_module(f".{module_name}", package=op_builder_dir) diff --git a/op_builder/all_ops.py b/op_builder/all_ops.py index e18dbdd71..403f7e12b 100644 --- a/op_builder/all_ops.py +++ b/op_builder/all_ops.py @@ -83,5 +83,5 @@ def get_op_builder(self, class_name): # append builder to __op_builders__ list builder = builder_utils.create_op_builder(member_name) __op_builders__.append(builder) - + ALL_OPS = {op.name: op for op in __op_builders__ if op is not None} \ No newline at end of file diff --git a/op_builder/amp_C.py b/op_builder/amp_C.py index 41f029fcb..dc14d01a3 100644 --- a/op_builder/amp_C.py +++ b/op_builder/amp_C.py @@ -33,7 +33,7 @@ def sources(self): def include_paths(self): return ['csrc/'] - + def cxx_args(self): args = super().cxx_args() return args + self.version_dependent_macros() diff --git a/op_builder/apex_C.py b/op_builder/apex_C.py index b02526e77..8f2cac843 100644 --- a/op_builder/apex_C.py +++ b/op_builder/apex_C.py @@ -19,7 +19,7 @@ def sources(self): def include_paths(self): return ['csrc/' ] - + def libraries_args(self): args = super().libraries_args() return args \ No newline at end of file diff --git a/op_builder/builder.py b/op_builder/builder.py index 60e490b2b..3f3c32b15 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -586,6 +586,7 @@ def jit_load(self, verbose=True): if self.is_rocm_pytorch(): cxx_args.append("-D__HIP_PLATFORM_AMD__=1") + cxx_args.append("-DUSE_ROCM") os.environ["PYTORCH_ROCM_ARCH"] = self.get_rocm_gpu_arch() cxx_args.append('-DROCM_WAVEFRONT_SIZE=%s' % self.get_rocm_wavefront_size()) @@ -692,10 +693,10 @@ def version_dependent_macros(self): version_ge_1_5 = [] if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4): version_ge_1_5 = ['-DVERSION_GE_1_5'] - + version_dependent_macro_args = version_ge_1_1 + version_ge_1_3 + version_ge_1_5 if self.is_rocm_pytorch() and (self.torch_version()[0] >= 6): - version_dependent_macro_args += ["-DHIPBLAS_V2"] + version_dependent_macro_args += ["-DHIPBLAS_V2"] return version_dependent_macro_args @@ -781,6 +782,7 @@ def nvcc_args(self): args += [ '-std=c++17', '-U__HIP_NO_HALF_OPERATORS__', '-U__HIP_NO_HALF_CONVERSIONS__', '-U__HIP_NO_HALF2_OPERATORS__', + '-DUSE_ROCM', '-DROCM_VERSION_MAJOR=%s' % ROCM_MAJOR, '-DROCM_VERSION_MINOR=%s' % ROCM_MINOR ] diff --git a/op_builder/distributed_adam.py b/op_builder/distributed_adam.py index ef453bee9..700cb7c63 100644 --- a/op_builder/distributed_adam.py +++ b/op_builder/distributed_adam.py @@ -21,7 +21,7 @@ def sources(self): def include_paths(self): return ['contrib/csrc/', 'csrc'] - + def cxx_args(self): args = super().cxx_args() return args + self.version_dependent_macros() diff --git a/op_builder/distributed_lamb.py b/op_builder/distributed_lamb.py index 74d77d129..7e5f848ca 100644 --- a/op_builder/distributed_lamb.py +++ b/op_builder/distributed_lamb.py @@ -21,7 +21,7 @@ def sources(self): def include_paths(self): return ['contrib/csrc/', 'csrc'] - + def cxx_args(self): args = super().cxx_args() return args + self.version_dependent_macros() diff --git a/op_builder/fast_multihead_attn.py b/op_builder/fast_multihead_attn.py index 0f2f8b52f..8ed6251eb 100644 --- a/op_builder/fast_multihead_attn.py +++ b/op_builder/fast_multihead_attn.py @@ -29,13 +29,13 @@ def include_paths(self): return ['csrc/', 'contrib/csrc/', 'contrib/csrc/multihead_attn'] - + def cxx_args(self): args = super().cxx_args() return args + self.version_dependent_macros() + self.generator_args() def nvcc_args(self): - nvcc_flags = ['-O3'] + self.version_dependent_macros() + self.generator_args() + nvcc_flags = ['-O3'] + self.version_dependent_macros() + self.generator_args() if not self.is_rocm_pytorch(): nvcc_flags += ['-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', diff --git a/op_builder/focal_loss.py b/op_builder/focal_loss.py index 98a21330a..c32b5423d 100644 --- a/op_builder/focal_loss.py +++ b/op_builder/focal_loss.py @@ -20,7 +20,7 @@ def sources(self): def include_paths(self): return ['contrib/csrc/' ] - + def cxx_args(self): args = super().cxx_args() return args + self.version_dependent_macros() diff --git a/op_builder/fused_adam.py b/op_builder/fused_adam.py index f335368d8..c34730fa6 100644 --- a/op_builder/fused_adam.py +++ b/op_builder/fused_adam.py @@ -21,7 +21,7 @@ def sources(self): def include_paths(self): return ['contrib/csrc/', 'csrc'] - + def cxx_args(self): args = super().cxx_args() return args + self.version_dependent_macros() diff --git a/op_builder/fused_conv_bias_relu.py b/op_builder/fused_conv_bias_relu.py index 997cfb32d..47935f078 100644 --- a/op_builder/fused_conv_bias_relu.py +++ b/op_builder/fused_conv_bias_relu.py @@ -20,7 +20,7 @@ def sources(self): return ["contrib/csrc/conv_bias_relu/conv_bias_relu.cpp"] def include_paths(self): - paths = ['contrib/csrc/'] + paths = ['contrib/csrc/'] if not self.is_rocm_pytorch(): paths.append("apex/contrib/csrc/cudnn-frontend/include") return paths diff --git a/op_builder/fused_lamb.py b/op_builder/fused_lamb.py index 02a0b6fe7..500ce46f3 100644 --- a/op_builder/fused_lamb.py +++ b/op_builder/fused_lamb.py @@ -22,7 +22,7 @@ def sources(self): def include_paths(self): return ['contrib/csrc/', 'csrc'] - + def cxx_args(self): args = super().cxx_args() return args + self.version_dependent_macros() diff --git a/op_builder/mlp.py b/op_builder/mlp.py index c6a177721..c69ce04f9 100644 --- a/op_builder/mlp.py +++ b/op_builder/mlp.py @@ -29,4 +29,4 @@ def nvcc_args(self): nvcc_flags = ['-O3'] + self.version_dependent_macros() if self.is_rocm_pytorch(): nvcc_flags.extend(self.backward_pass_guard_args()) - return nvcc_flags \ No newline at end of file + return nvcc_flags \ No newline at end of file diff --git a/op_builder/nccl_allocator.py b/op_builder/nccl_allocator.py index 320e76476..13f326170 100644 --- a/op_builder/nccl_allocator.py +++ b/op_builder/nccl_allocator.py @@ -19,7 +19,7 @@ def sources(self): def include_paths(self): return ['contrib/csrc/'] - + def cxx_args(self): args = super().cxx_args() return args + self.version_dependent_macros() + self.generator_args() diff --git a/op_builder/nccl_p2p.py b/op_builder/nccl_p2p.py index 37772572e..d97588a13 100644 --- a/op_builder/nccl_p2p.py +++ b/op_builder/nccl_p2p.py @@ -20,7 +20,7 @@ def sources(self): def include_paths(self): return ['contrib/csrc/'] - + def cxx_args(self): args = super().cxx_args() return args + self.version_dependent_macros() + self.generator_args() \ No newline at end of file diff --git a/op_builder/peer_memory.py b/op_builder/peer_memory.py index c869f0be6..7f3a8e087 100644 --- a/op_builder/peer_memory.py +++ b/op_builder/peer_memory.py @@ -20,7 +20,7 @@ def sources(self): def include_paths(self): return ['contrib/csrc/'] - + def cxx_args(self): args = super().cxx_args() return args + self.version_dependent_macros() + self.generator_args() \ No newline at end of file diff --git a/op_builder/transducer_joint.py b/op_builder/transducer_joint.py index c17f60f7b..72c8bdc86 100644 --- a/op_builder/transducer_joint.py +++ b/op_builder/transducer_joint.py @@ -20,14 +20,14 @@ def sources(self): def include_paths(self): return ['contrib/csrc/', #it uses philox.cuh from contrib/csrc/multihead_attn - 'contrib/csrc/multihead_attn'] - + 'contrib/csrc/multihead_attn'] + def cxx_args(self): args = super().cxx_args() - return args + self.version_dependent_macros() + self.generator_args() + return args + self.version_dependent_macros() + self.generator_args() def nvcc_args(self): - nvcc_flags = ['-O3'] + self.version_dependent_macros() + self.generator_args() + nvcc_flags = ['-O3'] + self.version_dependent_macros() + self.generator_args() if not self.is_rocm_pytorch(): nvcc_flags += self.nvcc_threads_args() return nvcc_flags \ No newline at end of file diff --git a/op_builder/transducer_loss.py b/op_builder/transducer_loss.py index 53ae4eaac..db03e6698 100644 --- a/op_builder/transducer_loss.py +++ b/op_builder/transducer_loss.py @@ -19,13 +19,13 @@ def sources(self): def include_paths(self): return ['contrib/csrc/' ] - + def cxx_args(self): args = super().cxx_args() return args + self.version_dependent_macros() def nvcc_args(self): - nvcc_flags = ['-O3'] + self.version_dependent_macros() + nvcc_flags = ['-O3'] + self.version_dependent_macros() if not self.is_rocm_pytorch(): nvcc_flags += self.nvcc_threads_args() return nvcc_flags \ No newline at end of file diff --git a/scripts/jit_module.py b/scripts/jit_module.py index 7c4ef48c4..1349c725c 100644 --- a/scripts/jit_module.py +++ b/scripts/jit_module.py @@ -12,7 +12,7 @@ def __init__(self): self.compatability_folder = "compatibility" def get_module_name(self, builder_file_name): - #open builder file and read the NAME attribute + #open builder file and read the NAME attribute with open(os.path.join(self.op_builder_folder, f"{builder_file_name}.py"), "r") as f: contents = f.read() for line in contents.split("\n"): @@ -36,7 +36,7 @@ def create_builder_class_name(self, module_name): new_name += part.capitalize() return f"{new_name}Builder" - + def create_build_var(self, module_name): return f"APEX_BUILD_{module_name.upper()}" @@ -137,7 +137,7 @@ def create_builder(self, module_name): builder_class_name = self.create_builder_class_name(module_name) build_var = self.create_build_var(module_name) - + if len(sources) == 0: sources_list = [] sources_list_string = "[]" @@ -179,7 +179,7 @@ def create_builder(self, module_name): f.write(f" # Please mention the full path of the source files\n") f.write(f" # e.g. ['csrc/fused_dense_base.cpp', 'csrc/fused_dense_cuda.cu']\n") f.write(f" def sources(self):\n") - f.write(f" return {sources_list_string}\n") + f.write(f" return {sources_list_string}\n") f.write(f"\n") f.write(f" # Required to override. Return the list of include directories\n") f.write(f" # Please mention the full path of the include directories\n") diff --git a/setup.py b/setup.py index febfe94a9..0c749d2e3 100644 --- a/setup.py +++ b/setup.py @@ -9,10 +9,10 @@ import torch from torch.utils.cpp_extension import ( - BuildExtension, - CppExtension, - CUDAExtension, - CUDA_HOME, + BuildExtension, + CppExtension, + CUDAExtension, + CUDA_HOME, ROCM_HOME, load, ) @@ -202,7 +202,7 @@ def is_op_build_included(op_name): install_ops[op_name] = True ext_modules.append(builder.builder()) -print(f'Install Ops={install_ops}') +print(f'Install Ops={install_ops}') # Write out version/git info. git_hash_cmd = shlex.split("bash -c \"git rev-parse --short HEAD\"") diff --git a/tests/L0/run_amp/test_basic_casts.py b/tests/L0/run_amp/test_basic_casts.py index 7ec254e42..0e478cbda 100644 --- a/tests/L0/run_amp/test_basic_casts.py +++ b/tests/L0/run_amp/test_basic_casts.py @@ -74,7 +74,7 @@ def setUp(self): def tearDown(self): self.handle._deactivate() common_reset(self) - + def test_linear_is_half(self): self._test_linear(ALWAYS_HALF) diff --git a/tests/L0/run_amp/test_cache.py b/tests/L0/run_amp/test_cache.py index c5b33ade0..c548b781d 100644 --- a/tests/L0/run_amp/test_cache.py +++ b/tests/L0/run_amp/test_cache.py @@ -74,23 +74,23 @@ def train_eval_train_test(self, module, t, opt_level): _amp_state.allow_incoming_model_not_fp32 = True model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level, verbosity=0) _amp_state.allow_incoming_model_not_fp32 = False - + def training_step(): for param in model.parameters(): param.grad = None - + loss = model(self.x).sum() _amp_state.loss_scalers[0]._loss_scale = 4.0 with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() - + self.assertEqual(len([p.grad for p in model.parameters() if p.grad is not None]), 1) self.assertEqual(model.weight.grad.type(), model.weight.type()) - + reference_grad = get_reference_grad(self.x, model.weight, model.ops) - + # Currently there's no difference in the allclose calls, so no need for branching, - # but I'm keeping this in case we want different tolerances for fp16 and fp32 checks. + # but I'm keeping this in case we want different tolerances for fp16 and fp32 checks. if model.weight.grad.type() == "torch.cuda.HalfTensor": self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad)) elif model.weight.grad.type() == "torch.cuda.BFloat16Tensor": @@ -101,19 +101,19 @@ def training_step(): raise RuntimeError("model.weight.grad.type = {}".format(model.weight.grad.type())) model.weight.data -= 1. - + # Simulates first epoch training_step() - + # Simulates eval with torch.no_grad(): loss = model(self.x).sum() - + # Simulates resuming training after eval training_step() _amp_state.handle._deactivate() - + # I could easily have these as a set of for loops in a single test, # instead of going for granularity. def test_whitelist_module_fp16_weight(self): diff --git a/tests/L0/run_amp/test_checkpointing.py b/tests/L0/run_amp/test_checkpointing.py index ff7ee884d..107fecf0e 100644 --- a/tests/L0/run_amp/test_checkpointing.py +++ b/tests/L0/run_amp/test_checkpointing.py @@ -69,7 +69,7 @@ def compare_models(self, modelA, modelB, test_setup=''): msg='Parameters in state_dices not equal.' + 'key: {}\nparam: {}\nrestored: {}\ndiff: {} for {}'.format( key, paramA, paramB, paramA - paramB, test_setup)) - + def test_restoring(self): nb_epochs = 10 nb_epochs_restore = nb_epochs // 2 @@ -225,11 +225,11 @@ def test_loss_scale_decrease(self): self.assertEqual(scaler['loss_scale'], init_ls / 2**factor) unskipped_target = 0 self.assertEqual(scaler['unskipped'], unskipped_target) - + if opt_level != "O0": _amp_state.handle._deactivate() - + def test_state_dict(self): for opt_level in self.test_opt_levels: diff --git a/tests/L0/run_amp/test_rnn.py b/tests/L0/run_amp/test_rnn.py index 02fb301d3..7ca51f479 100644 --- a/tests/L0/run_amp/test_rnn.py +++ b/tests/L0/run_amp/test_rnn.py @@ -40,7 +40,7 @@ def run_cell_test(self, cell, state_tuple=False): outputs[-1].float().sum().backward() for i, x in enumerate(xs): self.assertEqual(x.grad.dtype, x.dtype) - + def test_rnn_cell_is_half(self): cell = nn.RNNCell(self.h, self.h) self.run_cell_test(cell) diff --git a/tests/L0/run_fused_dense/test_gelu.py b/tests/L0/run_fused_dense/test_gelu.py index 913fec7ab..73878811b 100644 --- a/tests/L0/run_fused_dense/test_gelu.py +++ b/tests/L0/run_fused_dense/test_gelu.py @@ -1,4 +1,4 @@ -from apex import fused_dense +from apex import fused_dense import torch import torch.nn.functional as F import unittest diff --git a/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py b/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py index 61b64849a..0cbe22081 100644 --- a/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py +++ b/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py @@ -134,7 +134,7 @@ def _test_fused_rms_norm( ) def test_layer_norm_regular(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient): self._test_fused_layer_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient) - + @common_utils.parametrize( "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient", list(product((16, 65536), (True, False), (True,), (False,), (torch.float,), (True, False))) @@ -148,7 +148,7 @@ def test_layer_norm_elemwise(self, batch_size, contiguous, elementwise_affine, m ) def test_layer_norm_mixed(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient): self._test_fused_layer_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient) - + @common_utils.parametrize( "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient", list(product((16,), (True, False), (True,), (False,), (torch.half,), (True, False))) @@ -156,13 +156,13 @@ def test_layer_norm_mixed(self, batch_size, contiguous, elementwise_affine, mixe def test_layer_norm_half(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient): self._test_fused_layer_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient, fwd_thresholds=dict(rtol=1e-3, atol=1e-3), bwd_thresholds=dict(rtol=1e-3, atol=1e-3)) - + @common_utils.parametrize( "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient", list(product((16,), (True, False), (True,), (False,), (torch.bfloat16,), (True, False))) ) def test_layer_norm_bfloat16(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient): - self._test_fused_layer_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient, + self._test_fused_layer_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient, fwd_thresholds=dict(rtol=1.6e-2, atol=3e-4), bwd_thresholds=dict(rtol=1.6e-2, atol=3e-3)) # rms norm tests @@ -188,7 +188,7 @@ def test_rms_norm_elemwise(self, batch_size, contiguous, elementwise_affine, mix def test_rms_norm_mixed(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient): self._test_fused_rms_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient, bwd_thresholds=dict(rtol=2e-3, atol=2e-4)) - + @common_utils.parametrize( "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient", list(product((16,), (True, False), (True,), (False,), (torch.half,), (True, False))) @@ -196,13 +196,13 @@ def test_rms_norm_mixed(self, batch_size, contiguous, elementwise_affine, mixed_ def test_rms_norm_half(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient): self._test_fused_rms_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient, bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3)) - + @common_utils.parametrize( "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient", list(product((16,), (True, False), (True,), (False,), (torch.bfloat16,), (True, False))) ) def test_rms_norm_bfloat16(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient): - self._test_fused_rms_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient, + self._test_fused_rms_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient, fwd_thresholds=dict(rtol=1.6e-2, atol=3e-4), bwd_thresholds=dict(rtol=1.6e-2, atol=3e-2)) @common_utils.parametrize( @@ -226,9 +226,9 @@ def test_autocast_fused_layer_norm(self, dtype, elementwise_affine, memory_effic with torch.amp.autocast('cuda', dtype=dtype): actual = fused(fused_x) tols = {'rtol': None, 'atol': None} if dtype == torch.half else bf16_fwd_thresholds - # original tests used torch.testing.assert_allclose, which disables dtype checking by default. + # original tests used torch.testing.assert_allclose, which disables dtype checking by default. # link to issue here: https://github.com/pytorch/pytorch/issues/61844 - torch.testing.assert_close(actual, expected, **tols, check_dtype=False) + torch.testing.assert_close(actual, expected, **tols, check_dtype=False) g_native = torch.rand_like(expected) with torch.no_grad(): @@ -253,10 +253,10 @@ def test_autocast_fused_rms_norm(self, dtype, elementwise_affine, memory_efficie batch_size = 16 normalized_shape = [32, 16] native = FusedRMSNorm( - normalized_shape=normalized_shape, elementwise_affine=elementwise_affine, memory_efficient=memory_efficient, + normalized_shape=normalized_shape, elementwise_affine=elementwise_affine, memory_efficient=memory_efficient, ).to(dtype=dtype) fused = FusedRMSNorm( - normalized_shape=normalized_shape, elementwise_affine=elementwise_affine, memory_efficient=memory_efficient, + normalized_shape=normalized_shape, elementwise_affine=elementwise_affine, memory_efficient=memory_efficient, ).cuda() native_x, fused_x = _prep_inputs(batch_size, normalized_shape, dtype) @@ -304,7 +304,7 @@ def test_rms_export(self): native_x, fused_x = _prep_inputs(batch_size, normalized_shape, torch.float32) self._verify_export(fused, fused_x) self._verify_export(fused_m, fused_x) - + def test_layer_norm_export(self): batch_size = 16 normalized_shape = [32, 16] @@ -365,7 +365,7 @@ def test_compile_fused_rms_norm(self, elementwise_affine): actual.backward(g_compiled) torch.testing.assert_close(eager_x.grad, compiled_x.grad) - + instantiate_device_type_tests(TestFusedLayerNorm, globals(), only_for=("cuda",)) if __name__ == "__main__": common_utils.run_tests() \ No newline at end of file diff --git a/tests/L0/run_optimizers/test_adam.py b/tests/L0/run_optimizers/test_adam.py index 9fd00cbea..8c2b512f7 100644 --- a/tests/L0/run_optimizers/test_adam.py +++ b/tests/L0/run_optimizers/test_adam.py @@ -46,7 +46,7 @@ def forward(self, x): y = self.relu4(y) y = self.fc3(y) y = self.relu5(y) - return y + return y @unittest.skipIf(not HAS_APEX, "`apex` is not found.") @@ -83,7 +83,7 @@ def testGradScaler(self): scaler.scale(loss).backward() scaler.step(self.optimizer) scaler.update() - + # DUT with torch.amp.autocast('cuda', enabled=True): y = self.model_(x) @@ -105,7 +105,7 @@ def testGradScaler(self): optimizer_.zero_grad() self.model_.load_state_dict(copy.deepcopy(self.model.state_dict())) - + def testGradScalerCapturable(self): params_ = [p for p in self.model_.parameters() if p.requires_grad] optimizer_ = apex.optimizers.FusedAdam(params_, lr=self.lr, capturable=True) @@ -126,7 +126,7 @@ def testGradScalerCapturable(self): scaler.scale(loss).backward() scaler.step(self.optimizer) scaler.update() - + # DUT with torch.amp.autocast('cuda', enabled=True): y = self.model_(x) @@ -212,7 +212,7 @@ def testNative(self): loss.backward() self.optimizer.step() - + # DUT y = self.model_(x) loss_ = ((gt_ - y) ** 2).mean() @@ -230,7 +230,7 @@ def testNative(self): # Init for next iteration self.optimizer.zero_grad() optimizer_.zero_grad() - + self.model_.load_state_dict(copy.deepcopy(self.model.state_dict())) @largeTensorTest('60GB', 'cuda') diff --git a/tests/L0/run_optimizers/test_fused_novograd.py b/tests/L0/run_optimizers/test_fused_novograd.py index fa94e7102..a3b5757e7 100755 --- a/tests/L0/run_optimizers/test_fused_novograd.py +++ b/tests/L0/run_optimizers/test_fused_novograd.py @@ -111,7 +111,7 @@ def step(self, closure=None): exp_avg.mul_(beta1).add_(grad) p.data.add_(exp_avg, alpha=-group['lr']) - + return loss @@ -124,10 +124,10 @@ def __init__(self, *args, **kwargs): # are expected to behave the same. self.options = {'lr':1e-3, 'betas':(0.95, 0), 'eps':1e-8, 'weight_decay':0, 'grad_averaging':False, 'amsgrad':False} - + self.tst_options = {'lr':1e-3, 'betas':(0.95, 0), 'eps':1e-8, - 'weight_decay':0, 'grad_averaging':False, 'amsgrad':False, - 'bias_correction':False, 'reg_inside_moment':True, + 'weight_decay':0, 'grad_averaging':False, 'amsgrad':False, + 'bias_correction':False, 'reg_inside_moment':True, 'norm_type':2, 'init_zero':False, 'set_grad_none':True} self.ref_optim = Novograd @@ -146,7 +146,7 @@ def test_multi_device(self): with torch.cuda.device(current_dev): torch.cuda.synchronize() self.gen_single_type_test(param_type=torch.float, device=tensor_dev) - + def test_multi_params(self): sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]] diff --git a/tests/L0/run_test.py b/tests/L0/run_test.py index ed84fe956..6164fe757 100644 --- a/tests/L0/run_test.py +++ b/tests/L0/run_test.py @@ -26,7 +26,7 @@ "run_fused_layer_norm", "run_mlp", "run_fused_dense", - "run_transformer", + "run_transformer", ] #the tests that are run by default diff --git a/tests/L0/run_transformer/test_layers.py b/tests/L0/run_transformer/test_layers.py index 9f3066907..12d7fdc9a 100644 --- a/tests/L0/run_transformer/test_layers.py +++ b/tests/L0/run_transformer/test_layers.py @@ -399,7 +399,7 @@ def _row_parallel_linear_test_impl( dim=0, )[parallel_state.get_tensor_model_parallel_rank()], atol=1e-4, - rtol=1e-3 + rtol=1e-3 ) parallel_state.destroy_model_parallel() diff --git a/tests/L1/common/main_amp.py b/tests/L1/common/main_amp.py index 106a0f637..5eec14782 100644 --- a/tests/L1/common/main_amp.py +++ b/tests/L1/common/main_amp.py @@ -91,7 +91,7 @@ def fast_collate(batch): nump_array = np.rollaxis(nump_array, 2) tensor[i] += torch.from_numpy(nump_array) - + return tensor, targets best_prec1 = 0 @@ -99,7 +99,7 @@ def fast_collate(batch): # Let multi_tensor_applier be the canary in the coalmine # that verifies if the backend is what we think it is -assert multi_tensor_applier.available == args.has_ext +assert multi_tensor_applier.available == args.has_ext print("opt_level = {}".format(args.opt_level)) print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32)) @@ -149,7 +149,7 @@ def main(): model = model.cuda() # Scale learning rate based on global batch size - args.lr = args.lr*float(args.batch_size*args.world_size)/256. + args.lr = args.lr*float(args.batch_size*args.world_size)/256. if args.fused_adam: optimizer = optimizers.FusedAdam(model.parameters()) else: @@ -166,7 +166,7 @@ def main(): ) if args.distributed: - # By default, apex.parallel.DistributedDataParallel overlaps communication with + # By default, apex.parallel.DistributedDataParallel overlaps communication with # computation in the backward pass. # model = DDP(model) # delay_allreduce delays all communication to the end of the backward pass. @@ -288,7 +288,7 @@ def preload(self): # else: self.next_input = self.next_input.float() self.next_input = self.next_input.sub_(self.mean).div_(self.std) - + def next(self): torch.cuda.current_stream().wait_stream(self.stream) input = self.next_input diff --git a/tests/distributed/DDP/ddp_race_condition_test.py b/tests/distributed/DDP/ddp_race_condition_test.py index 761a33595..a1f9466b4 100644 --- a/tests/distributed/DDP/ddp_race_condition_test.py +++ b/tests/distributed/DDP/ddp_race_condition_test.py @@ -51,7 +51,7 @@ def forward(self, input): # torch.cuda.nvtx.range_push("backward") loss.backward() # torch.cuda.nvtx.range_pop() - + # torch.cuda.nvtx.range_push("synchronize() + info") # torch.cuda.synchronize() print("i = {}".format(i)) diff --git a/tests/distributed/run_rocm_distributed.sh b/tests/distributed/run_rocm_distributed.sh index 322137bbd..b6a3ccd6d 100755 --- a/tests/distributed/run_rocm_distributed.sh +++ b/tests/distributed/run_rocm_distributed.sh @@ -40,7 +40,7 @@ python -m torch.distributed.launch --nproc_per_node=2 synced_batchnorm/two_gpu_t echo "Running syncbn python only tests" python synced_batchnorm/python_single_gpu_unit_test.py echo "Running syncbn batchnorm1d tests" -python synced_batchnorm/test_batchnorm1d.py +python synced_batchnorm/test_batchnorm1d.py #beware, you need a system with at least 4 gpus to test group_size&1 | tee ../../$LOG_FILE cd ../../ diff --git a/tests/jit_build/scripts/run.sh b/tests/jit_build/scripts/run.sh index aeb41fadd..5aa9ab94d 100644 --- a/tests/jit_build/scripts/run.sh +++ b/tests/jit_build/scripts/run.sh @@ -4,13 +4,13 @@ JIT_CONDITION="$2" echo $(pwd) WORKSPACE_DIR=/myworkspace -mkdir -p $WORKSPACE_DIR +mkdir -p $WORKSPACE_DIR -cd $WORKSPACE_DIR -git clone https://github.com/rocm/apex.git --recursive -cd apex +cd $WORKSPACE_DIR +git clone https://github.com/rocm/apex.git --recursive +cd apex git checkout Refactor_build -git submodule update --init --recursive +git submodule update --init --recursive sh tests/jit_build/build.sh "condition" $JIT_CONDITION diff --git a/tests/test_extension_import.py b/tests/test_extension_import.py index e5fc8ebfd..649830d54 100644 --- a/tests/test_extension_import.py +++ b/tests/test_extension_import.py @@ -38,7 +38,7 @@ def get_extensions_list_from_setup(self): """ This method reads setup.py and gets the list of extensions from the setup.py file """ - + #get setup.py file contents setup_path = os.path.join(self.parent_folder_path, "setup.py") @@ -58,14 +58,14 @@ def get_extensions_list_from_setup(self): if found == 1: continue #print ("extension", line, line_index) - + if "name"in line: name_line = line.strip() else: #get the next line line_index += 1 name_line = setup_contents[line_index].strip() - + #extract the name part if "name" in name_line: if "'" in name_line: @@ -74,7 +74,7 @@ def get_extensions_list_from_setup(self): name = name_line[name_line.find("name") + 6 : name_line.rfind('"')] extensions.append(name) - line_index += 1 + line_index += 1 return extensions @@ -132,7 +132,7 @@ def get_environment(self): else: env['LD_LIBRARY_PATH'] = ':'.join(extra_paths) return env - + def check_extension_import(self, extension_name, env): """ @@ -140,10 +140,10 @@ def check_extension_import(self, extension_name, env): Returns True if import successful, False if ImportError occurs """ try: - + # Run Python subprocess to test the import result = subprocess.run([ - sys.executable, '-c', + sys.executable, '-c', 'import ' + extension_name ], capture_output=True, text=True, timeout=30, env=env) print ("result.stdout", result.stdout, result.stderr) @@ -152,7 +152,7 @@ def check_extension_import(self, extension_name, env): return False, result.stderr else: return True, "" - + except subprocess.TimeoutExpired: print(f"Import test timed out for {extension_name}") return False, "Timeout" @@ -172,8 +172,8 @@ def check_jit_extension_import(self, extension_name, env): try: # Run Python subprocess to test the import result = subprocess.run([ - sys.executable, '-c', - 'from apex.op_builder import ' + builder_name + + sys.executable, '-c', + 'from apex.op_builder import ' + builder_name + '\n' + builder_name + "().load()" ], capture_output=True, text=True, timeout=timeout, env=env) print ("result.stdout", result.stdout, result.stderr) @@ -182,7 +182,7 @@ def check_jit_extension_import(self, extension_name, env): return False, result.stderr else: return True, "" - + except subprocess.TimeoutExpired: print(f"Import test timed out for {extension_name}") return False, "Timeout" From 7b4cafc3bc1b6e53f781479fb7c6b9828db5abba Mon Sep 17 00:00:00 2001 From: Jithun Nair Date: Mon, 20 Apr 2026 14:47:18 +0000 Subject: [PATCH 2/2] Reverting non-whitespace changes --- op_builder/builder.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/op_builder/builder.py b/op_builder/builder.py index 3f3c32b15..2a573d5cf 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -586,7 +586,6 @@ def jit_load(self, verbose=True): if self.is_rocm_pytorch(): cxx_args.append("-D__HIP_PLATFORM_AMD__=1") - cxx_args.append("-DUSE_ROCM") os.environ["PYTORCH_ROCM_ARCH"] = self.get_rocm_gpu_arch() cxx_args.append('-DROCM_WAVEFRONT_SIZE=%s' % self.get_rocm_wavefront_size()) @@ -782,7 +781,6 @@ def nvcc_args(self): args += [ '-std=c++17', '-U__HIP_NO_HALF_OPERATORS__', '-U__HIP_NO_HALF_CONVERSIONS__', '-U__HIP_NO_HALF2_OPERATORS__', - '-DUSE_ROCM', '-DROCM_VERSION_MAJOR=%s' % ROCM_MAJOR, '-DROCM_VERSION_MINOR=%s' % ROCM_MINOR ] @@ -926,4 +924,4 @@ def cxx_args(self): CUDA_ENABLE, ] - return args \ No newline at end of file + return args