From bde08835b5911fefddeac1c0e57bd7121529b3bb Mon Sep 17 00:00:00 2001
From: Tim Moon
Date: Sat, 18 Apr 2026 01:19:31 +0000
Subject: [PATCH 1/7] [PyTorch] Add bulk_allocate utility and use it in
 quantized tensor allocators

Introduces transformer_engine/pytorch/csrc/extensions/allocate.cpp with a
general-purpose bulk_allocate function: given parallel lists of shapes,
dtypes, and per-tensor byte alignments, it computes a packed layout, performs
a single CUDA allocation, and returns at::from_blob views whose deleters keep
the backing buffer alive.

The three internal bulk_allocate_*_tensors helpers in cast.cpp are refactored
to call bulk_allocate instead of each owning a copy of the make_torch_view
lambda and the offset-computation loops (~120 lines removed). The new
function is also exposed via pybind11 so Python can allocate packed CUDA
buffers directly without going through a quantizer.

Co-Authored-By: Claude Sonnet 4.6
Signed-off-by: Tim Moon
---
 transformer_engine/pytorch/csrc/extensions.h  |   9 +
 .../pytorch/csrc/extensions/allocate.cpp      |  64 ++++
 .../pytorch/csrc/extensions/cast.cpp          | 302 ++++++------
 .../pytorch/csrc/extensions/pybind.cpp        |   6 +
 4 files changed, 169 insertions(+), 212 deletions(-)
 create mode 100644 transformer_engine/pytorch/csrc/extensions/allocate.cpp
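As a sketch of the intended Python-side usage (the module import and the
shapes below are illustrative assumptions, not taken from TE call sites):

    import torch
    import transformer_engine_torch as tex

    # Pack two uint8 data buffers and their FP32 scales into one CUDA
    # allocation; each view starts at its requested byte alignment.
    shapes = [[1024, 1024], [1024, 1024], [8, 128], [8, 128]]
    dtypes = [torch.uint8, torch.uint8, torch.float32, torch.float32]
    alignments = [256, 256, 16, 16]
    data0, data1, scale0, scale1 = tex.bulk_allocate(shapes, dtypes, alignments)
    assert data0.data_ptr() % 256 == 0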
+ ************************************************************************/ + +#include +#include + +#include "../extensions.h" + +namespace transformer_engine { +namespace pytorch { + +std::vector bulk_allocate(const std::vector> &shapes, + const std::vector &dtypes, + const std::vector &alignments) { + const size_t n = shapes.size(); + NVTE_CHECK(dtypes.size() == n, "Got ", shapes.size(), " shapes and ", dtypes.size(), " dtypes."); + NVTE_CHECK(alignments.size() == n, "Got ", shapes.size(), " shapes and ", + alignments.size(), " alignments."); + if (n == 0) return {}; + + // Compute per-tensor sizes and offsets + size_t total_bytes = 0; + std::vector byte_sizes(n); + std::vector offsets(n); + for (size_t i = 0; i < n; ++i) { + total_bytes = roundup(total_bytes, alignments[i]); + offsets[i] = total_bytes; + byte_sizes[i] = product(shapes[i]) * at::elementSize(dtypes[i]); + total_bytes += byte_sizes[i]; + } + + // Single backing allocation + auto buffer = std::make_shared( + at::empty({static_cast(total_bytes)}, at::device(at::kCUDA).dtype(torch::kUInt8))); + uint8_t *data_ptr = buffer->data_ptr(); + + // Create views into the buffer + std::vector out; + out.reserve(n); + std::vector shape_int64; + for (size_t i = 0; i < n; ++i) { + shape_int64.assign(shapes[i].begin(), shapes[i].end()); + if (byte_sizes[i] == 0) { + // Work around problems with constructing an empty tensor with + // from_blob. Passing a null pointer fails because it checks + // that the pointer is on GPU. Passing a non-null pointer can + // cause bugs in TE kernels. + out.emplace_back(at::empty(shape_int64, at::device(at::kCUDA).dtype(dtypes[i]))); + } else { + out.emplace_back(at::from_blob( + data_ptr + offsets[i], + shape_int64, + [buffer](void *) {}, // Deleter keeps buffer alive + at::device(at::kCUDA).dtype(dtypes[i]))); + } + } + return out; +} + +} // namespace pytorch +} // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 5fb162c72d..24924b4987 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -511,60 +511,30 @@ std::tuple, std::vector> bulk_allocate_fp const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode(); const auto is_2D_scaled = scaling_mode == NVTE_BLOCK_SCALING_2D; const auto fp8_dtype = quantizer_cpp_list[0]->dtype; - constexpr size_t fp8_elem_size = 1; - constexpr size_t scale_elem_size = 4; - - // Helper function to construct tensor view - // Note: Deleter holds a shared_ptr for the buffer, so the buffer - // will survive until all views are deleted. 
diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp
index 5fb162c72d..24924b4987 100644
--- a/transformer_engine/pytorch/csrc/extensions/cast.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp
@@ -511,60 +511,30 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp
   const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode();
   const auto is_2D_scaled = scaling_mode == NVTE_BLOCK_SCALING_2D;
   const auto fp8_dtype = quantizer_cpp_list[0]->dtype;
-  constexpr size_t fp8_elem_size = 1;
-  constexpr size_t scale_elem_size = 4;
-
-  // Helper function to construct tensor view
-  // Note: Deleter holds a shared_ptr for the buffer, so the buffer
-  // will survive until all views are deleted.
-  auto make_torch_view = [](std::shared_ptr<at::Tensor> &buffer, const std::vector<size_t> &shape,
-                            size_t offset, at::ScalarType dtype) -> at::Tensor {
-    std::vector<int64_t> shape_int64(shape.begin(), shape.end());
-    bool is_empty_shape = product(shape) == 0;
-    if (buffer->data_ptr() == nullptr || is_empty_shape) {
-      return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype));
-    }
-    return at::from_blob(
-        buffer->data_ptr<uint8_t>() + offset, shape_int64,
-        [buffer](void *) {},  // deleter holds shared_ptr
-        at::device(at::kCUDA).dtype(dtype));
-  };
 
   // Allocate row-wise data
   std::vector<at::Tensor> rowwise_data_list, rowwise_scale_list;
   std::vector<std::vector<size_t>> rowwise_data_shapes, rowwise_scale_shapes;
   if (rowwise_usage) {
-    // Tensor sizes
     for (size_t i = 0; i < num_tensors; ++i) {
       rowwise_data_shapes.emplace_back(shape_list[i]);
       rowwise_scale_shapes.emplace_back(
           quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false));
     }
 
-    // Offsets in full buffer
-    size_t buffer_size = 0;
-    std::vector<size_t> data_offsets, scale_offsets;
-    for (size_t i = 0; i < num_tensors; ++i) {
-      buffer_size = roundup(buffer_size, 256);  // align to 256B
-      data_offsets.push_back(buffer_size);
-      buffer_size += product(rowwise_data_shapes[i]) * fp8_elem_size;
-    }
-    for (size_t i = 0; i < num_tensors; ++i) {
-      buffer_size = roundup(buffer_size, 16);  // align to 16B
-      scale_offsets.push_back(buffer_size);
-      buffer_size += product(rowwise_scale_shapes[i]) * scale_elem_size;
-    }
-
-    // Allocate full buffer
-    auto buffer = std::make_shared<at::Tensor>(
-        at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
+    // Bulk-allocate data and scale tensors
+    std::vector<std::vector<size_t>> shapes = rowwise_data_shapes;
+    std::vector<at::ScalarType> dtypes(num_tensors, torch::kUInt8);
+    std::vector<size_t> alignments(num_tensors, 256);
+    shapes.insert(shapes.end(), rowwise_scale_shapes.begin(), rowwise_scale_shapes.end());
+    dtypes.insert(dtypes.end(), num_tensors, torch::kFloat32);
+    alignments.insert(alignments.end(), num_tensors, 16);
+    auto tensors = bulk_allocate(shapes, dtypes, alignments);
 
-    // Construct tensor views
+    // Split data and scale tensors
     for (size_t i = 0; i < num_tensors; ++i) {
-      rowwise_data_list.emplace_back(
-          make_torch_view(buffer, rowwise_data_shapes[i], data_offsets[i], torch::kUInt8));
-      rowwise_scale_list.emplace_back(
-          make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kFloat32));
+      rowwise_data_list.emplace_back(std::move(tensors[i]));
+      rowwise_scale_list.emplace_back(std::move(tensors[num_tensors + i]));
     }
   }
 
@@ -572,7 +542,6 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp
   std::vector<at::Tensor> columnwise_data_list, columnwise_scale_list;
   std::vector<std::vector<size_t>> columnwise_data_shapes, columnwise_scale_shapes;
   if (columnwise_usage) {
-    // Tensor sizes
     for (size_t i = 0; i < num_tensors; ++i) {
       columnwise_data_shapes.emplace_back();
       auto &shape = columnwise_data_shapes.back();
@@ -584,30 +553,19 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp
         quantizer_cpp_list[i]->get_scale_shape(shape_list[i], true));
     }
 
-    // Offsets in full buffer
-    size_t buffer_size = 0;
-    std::vector<size_t> data_offsets, scale_offsets;
-    for (size_t i = 0; i < num_tensors; ++i) {
-      buffer_size = roundup(buffer_size, 256);  // align to 256B
-      data_offsets.push_back(buffer_size);
-      buffer_size += product(columnwise_data_shapes[i]) * fp8_elem_size;
-    }
-    for (size_t i = 0; i < num_tensors; ++i) {
-      buffer_size = roundup(buffer_size, 16);  // align to 16B
-      scale_offsets.push_back(buffer_size);
-      buffer_size += product(columnwise_scale_shapes[i]) * scale_elem_size;
-    }
-
-    // Allocate full buffer
-    auto buffer = std::make_shared<at::Tensor>(
-        at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
+    // Bulk-allocate data and scale tensors
+    std::vector<std::vector<size_t>> shapes = columnwise_data_shapes;
+    std::vector<at::ScalarType> dtypes(num_tensors, torch::kUInt8);
+    std::vector<size_t> alignments(num_tensors, 256);
+    shapes.insert(shapes.end(), columnwise_scale_shapes.begin(), columnwise_scale_shapes.end());
+    dtypes.insert(dtypes.end(), num_tensors, torch::kFloat32);
+    alignments.insert(alignments.end(), num_tensors, 16);
+    auto tensors = bulk_allocate(shapes, dtypes, alignments);
 
-    // Construct tensor views
+    // Split data and scale tensors
     for (size_t i = 0; i < num_tensors; ++i) {
-      columnwise_data_list.emplace_back(
-          make_torch_view(buffer, columnwise_data_shapes[i], data_offsets[i], torch::kUInt8));
-      columnwise_scale_list.emplace_back(
-          make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kFloat32));
+      columnwise_data_list.push_back(tensors[i]);
+      columnwise_scale_list.push_back(tensors[num_tensors + i]);
     }
   }
 
@@ -664,60 +622,29 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
   const auto fp8_dtype = quantizer_cpp_list[0]->dtype;
   const bool with_gemm_swizzled_scales = quantizer_cpp_list[0]->optimize_for_gemm;
 
-  constexpr size_t fp8_elem_size = 1;
-  constexpr size_t scale_elem_size = 1;
-
-  // Helper function to construct tensor view
-  // Note: Deleter holds a shared_ptr for the buffer, so the buffer
-  // will survive until all views are deleted.
-  auto make_torch_view = [](std::shared_ptr<at::Tensor> &buffer, const std::vector<size_t> &shape,
-                            size_t offset, at::ScalarType dtype) -> at::Tensor {
-    std::vector<int64_t> shape_int64(shape.begin(), shape.end());
-    bool is_empty_shape = product(shape) == 0;
-    if (buffer->data_ptr() == nullptr || is_empty_shape) {
-      return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype));
-    }
-    return at::from_blob(
-        buffer->data_ptr<uint8_t>() + offset, shape_int64,
-        [buffer](void *) {},  // deleter holds shared_ptr
-        at::device(at::kCUDA).dtype(dtype));
-  };
-
   // Allocate row-wise data
   std::vector<at::Tensor> rowwise_data_list, rowwise_scale_list;
   std::vector<std::vector<size_t>> rowwise_data_shapes, rowwise_scale_shapes;
   if (rowwise_usage) {
-    // Tensor sizes
     for (size_t i = 0; i < num_tensors; ++i) {
       rowwise_data_shapes.emplace_back(shape_list[i]);
       rowwise_scale_shapes.emplace_back(
           quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false));
     }
 
-    // Offsets in full buffer
-    size_t buffer_size = 0;
-    std::vector<size_t> data_offsets, scale_offsets;
-    for (size_t i = 0; i < num_tensors; ++i) {
-      buffer_size = roundup(buffer_size, 256);  // align to 256B
-      data_offsets.push_back(buffer_size);
-      buffer_size += product(rowwise_data_shapes[i]) * fp8_elem_size;
-    }
-    for (size_t i = 0; i < num_tensors; ++i) {
-      buffer_size = roundup(buffer_size, 16);  // align to 16B
-      scale_offsets.push_back(buffer_size);
-      buffer_size += product(rowwise_scale_shapes[i]) * scale_elem_size;
-    }
+    // Bulk-allocate data and scale tensors
+    std::vector<std::vector<size_t>> shapes = rowwise_data_shapes;
+    std::vector<at::ScalarType> dtypes(num_tensors, torch::kUInt8);
+    std::vector<size_t> alignments(num_tensors, 256);
+    shapes.insert(shapes.end(), rowwise_scale_shapes.begin(), rowwise_scale_shapes.end());
+    dtypes.insert(dtypes.end(), num_tensors, torch::kUInt8);
+    alignments.insert(alignments.end(), num_tensors, 16);
+    auto tensors = bulk_allocate(shapes, dtypes, alignments);
 
-    // Allocate full buffer
-    auto buffer = std::make_shared<at::Tensor>(
-        at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
-
-    // Construct tensor views
+    // Split data and scale tensors
     for (size_t i = 0; i < num_tensors; ++i) {
-      rowwise_data_list.emplace_back(
-          make_torch_view(buffer, rowwise_data_shapes[i], data_offsets[i], torch::kUInt8));
-      rowwise_scale_list.emplace_back(
-          make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8));
+      rowwise_data_list.emplace_back(std::move(tensors[i]));
+      rowwise_scale_list.emplace_back(std::move(tensors[num_tensors + i]));
     }
   }
 
@@ -725,7 +652,6 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
   std::vector<at::Tensor> columnwise_data_list, columnwise_scale_list;
   std::vector<std::vector<size_t>> columnwise_data_shapes, columnwise_scale_shapes;
   if (columnwise_usage) {
-    // Tensor sizes
     for (size_t i = 0; i < num_tensors; ++i) {
       // For MXFP8, the columnwise data doesn't need transpose
       // because of TN, NT, NN layout support in SM100
@@ -734,30 +660,19 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
         quantizer_cpp_list[i]->get_scale_shape(shape_list[i], true));
     }
 
-    // Offsets in full buffer
-    size_t buffer_size = 0;
-    std::vector<size_t> data_offsets, scale_offsets;
-    for (size_t i = 0; i < num_tensors; ++i) {
-      buffer_size = roundup(buffer_size, 256);  // align to 256B
-      data_offsets.push_back(buffer_size);
-      buffer_size += product(columnwise_data_shapes[i]) * fp8_elem_size;
-    }
-    for (size_t i = 0; i < num_tensors; ++i) {
-      buffer_size = roundup(buffer_size, 16);  // align to 16B
-      scale_offsets.push_back(buffer_size);
-      buffer_size += product(columnwise_scale_shapes[i]) * scale_elem_size;
-    }
+    // Bulk-allocate data and scale tensors
+    std::vector<std::vector<size_t>> shapes = columnwise_data_shapes;
+    std::vector<at::ScalarType> dtypes(num_tensors, torch::kUInt8);
+    std::vector<size_t> alignments(num_tensors, 256);
+    shapes.insert(shapes.end(), columnwise_scale_shapes.begin(), columnwise_scale_shapes.end());
+    dtypes.insert(dtypes.end(), num_tensors, torch::kUInt8);
+    alignments.insert(alignments.end(), num_tensors, 16);
+    auto tensors = bulk_allocate(shapes, dtypes, alignments);
 
-    // Allocate full buffer
-    auto buffer = std::make_shared<at::Tensor>(
-        at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
-
-    // Construct tensor views
+    // Split data and scale tensors
     for (size_t i = 0; i < num_tensors; ++i) {
-      columnwise_data_list.emplace_back(
-          make_torch_view(buffer, columnwise_data_shapes[i], data_offsets[i], torch::kUInt8));
-      columnwise_scale_list.emplace_back(
-          make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kUInt8));
+      columnwise_data_list.push_back(tensors[i]);
+      columnwise_scale_list.push_back(tensors[num_tensors + i]);
     }
   }
 
@@ -818,23 +733,6 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> bulk_alloc
   const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode();
   const auto fp4_dtype = quantizer_cpp_list[0]->dtype;
   const bool with_gemm_swizzled_scales = false;  /// TODO (tmoon) Enable based on optimize_for_gemm
-  constexpr size_t scale_elem_size = 1;
-
-  // Helper function to construct tensor view
-  // Note: Deleter holds a shared_ptr for the buffer, so the buffer
-  // will survive until all views are deleted.
-  auto make_torch_view = [](std::shared_ptr<at::Tensor> &buffer, const std::vector<size_t> &shape,
-                            size_t offset, at::ScalarType dtype) -> at::Tensor {
-    std::vector<int64_t> shape_int64(shape.begin(), shape.end());
-    bool is_empty_shape = product(shape) == 0;
-    if (buffer->data_ptr() == nullptr || is_empty_shape) {
-      return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype));
-    }
-    return at::from_blob(
-        buffer->data_ptr<uint8_t>() + offset, shape_int64,
-        [buffer](void *) {},  // deleter holds shared_ptr
-        at::device(at::kCUDA).dtype(dtype));
-  };
 
   // Lambda function for converting std::vector shape to NVFP4 shape (last dim divided by 2)
   auto to_fp4_shape = [](const std::vector<size_t> &shape) {
@@ -849,54 +747,44 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> bulk_alloc
   std::vector<at::Tensor> rowwise_data_list, rowwise_scale_list, amax_rowwise_list;
   std::vector<std::vector<size_t>> rowwise_data_shapes, rowwise_scale_shapes;
   if (rowwise_usage) {
-    // Tensor sizes
     for (size_t i = 0; i < num_tensors; ++i) {
       rowwise_data_shapes.emplace_back(shape_list[i]);
       rowwise_scale_shapes.emplace_back(
           quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false));
     }
 
-    // Offsets in full buffer
-    size_t buffer_size = 0;
-    std::vector<size_t> data_offsets, scale_offsets, amax_offsets;
+    // Check whether data and scales can be packed in a contiguous
+    // buffer. Amaxes are not contiguous since they are aligned to
+    // 16B.
     for (size_t i = 0; i < num_tensors; ++i) {
-      // FP4 data is aligned to 256B
-      const auto offset = roundup(buffer_size, 256);
-      if (offset != buffer_size) {
+      if (product(rowwise_data_shapes[i]) / 2 % 256 != 0) {
         contiguous_data_and_scale = false;
       }
-      data_offsets.push_back(offset);
-      buffer_size = offset + (product(rowwise_data_shapes[i]) + 1) / 2;
-    }
-    for (size_t i = 0; i < num_tensors; ++i) {
-      // Scales are aligned to 16B
-      const auto offset = roundup(buffer_size, 16);
-      if (offset != buffer_size) {
+      if (product(rowwise_scale_shapes[i]) % 16 != 0) {
         contiguous_data_and_scale = false;
       }
-      scale_offsets.push_back(offset);
-      buffer_size = offset + product(rowwise_scale_shapes[i]) * scale_elem_size;
     }
+
+    // Bulk-allocate data, scale, and amax tensors
+    std::vector<std::vector<size_t>> shapes;
     for (size_t i = 0; i < num_tensors; ++i) {
-      // Amaxes (FP32) are aligned to 16B
-      // Note: Multi-quantize kernel does not require contiguous amaxes.
-      const auto offset = roundup(buffer_size, 16);
-      amax_offsets.push_back(offset);
-      buffer_size = offset + 4;
+      shapes.emplace_back(to_fp4_shape(rowwise_data_shapes[i]));
     }
-
-    // Allocate full buffer
-    auto buffer = std::make_shared<at::Tensor>(
-        at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
-
-    // Construct tensor views
+    std::vector<at::ScalarType> dtypes(num_tensors, torch::kUInt8);
+    std::vector<size_t> alignments(num_tensors, 256);
+    shapes.insert(shapes.end(), rowwise_scale_shapes.begin(), rowwise_scale_shapes.end());
+    dtypes.insert(dtypes.end(), num_tensors, torch::kUInt8);
+    alignments.insert(alignments.end(), num_tensors, 16);
+    shapes.insert(shapes.end(), num_tensors, std::vector<size_t>{1});
+    dtypes.insert(dtypes.end(), num_tensors, torch::kFloat32);
+    alignments.insert(alignments.end(), num_tensors, 16);
+    auto tensors = bulk_allocate(shapes, dtypes, alignments);
+
+    // Split data, scale, and amax tensors
    for (size_t i = 0; i < num_tensors; ++i) {
-      rowwise_data_list.emplace_back(make_torch_view(buffer, to_fp4_shape(rowwise_data_shapes[i]),
-                                                     data_offsets[i], torch::kUInt8));
-      rowwise_scale_list.emplace_back(
-          make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8));
-      amax_rowwise_list.emplace_back(
-          make_torch_view(buffer, std::vector<size_t>{1}, amax_offsets[i], torch::kFloat32));
+      rowwise_data_list.push_back(tensors[i]);
+      rowwise_scale_list.push_back(tensors[num_tensors + i]);
+      amax_rowwise_list.push_back(tensors[2 * num_tensors + i]);
     }
   }
 
@@ -904,7 +792,6 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> bulk_alloc
   std::vector<at::Tensor> columnwise_data_list, columnwise_scale_list, amax_columnwise_list;
   std::vector<std::vector<size_t>> columnwise_data_shapes, columnwise_scale_shapes;
   if (columnwise_usage) {
-    // Tensor sizes
     for (size_t i = 0; i < num_tensors; ++i) {
       // push the transposed shape into NVFP4 columnwise shape
       // NVFP4 on SM100 is TN only
@@ -918,47 +805,38 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> bulk_alloc
         quantizer_cpp_list[i]->get_scale_shape(shape_list[i], true));
     }
 
-    // Offsets in full buffer
-    size_t buffer_size = 0;
-    std::vector<size_t> data_offsets, scale_offsets, amax_offsets;
+    // Check whether data and scales can be packed in a contiguous
+    // buffer. Amaxes are not contiguous since they are aligned to
+    // 16B.
     for (size_t i = 0; i < num_tensors; ++i) {
-      // FP4 data is aligned to 256B
-      const auto offset = roundup(buffer_size, 256);
-      if (offset != buffer_size) {
+      if (product(columnwise_data_shapes[i]) / 2 % 256 != 0) {
         contiguous_data_and_scale = false;
       }
-      data_offsets.push_back(offset);
-      buffer_size = offset + (product(columnwise_data_shapes[i]) + 1) / 2;
-    }
-    for (size_t i = 0; i < num_tensors; ++i) {
-      // Scales are aligned to 16B
-      const auto offset = roundup(buffer_size, 16);
-      if (offset != buffer_size) {
+      if (product(columnwise_scale_shapes[i]) % 16 != 0) {
         contiguous_data_and_scale = false;
       }
-      scale_offsets.push_back(offset);
-      buffer_size = offset + product(columnwise_scale_shapes[i]) * scale_elem_size;
     }
+
+    // Bulk-allocate data, scale, and amax tensors
+    std::vector<std::vector<size_t>> shapes;
     for (size_t i = 0; i < num_tensors; ++i) {
-      // Amaxes (FP32) are aligned to 16B
-      // Note: Multi-quantize kernel does not require contiguous amaxes.
-      const auto offset = roundup(buffer_size, 16);
-      amax_offsets.push_back(offset);
-      buffer_size = offset + 4;
+      shapes.emplace_back(to_fp4_shape(columnwise_data_shapes[i]));
     }
-
-    // Allocate full buffer
-    auto buffer = std::make_shared<at::Tensor>(
-        at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
-
-    // Construct tensor views
+    std::vector<at::ScalarType> dtypes(num_tensors, torch::kUInt8);
+    std::vector<size_t> alignments(num_tensors, 256);
+    shapes.insert(shapes.end(), columnwise_scale_shapes.begin(), columnwise_scale_shapes.end());
+    dtypes.insert(dtypes.end(), num_tensors, torch::kUInt8);
+    alignments.insert(alignments.end(), num_tensors, 16);
+    shapes.insert(shapes.end(), num_tensors, std::vector<size_t>{1});
+    dtypes.insert(dtypes.end(), num_tensors, torch::kFloat32);
+    alignments.insert(alignments.end(), num_tensors, 16);
+    auto tensors = bulk_allocate(shapes, dtypes, alignments);
+
+    // Split data, scale, and amax tensors
     for (size_t i = 0; i < num_tensors; ++i) {
-      columnwise_data_list.emplace_back(make_torch_view(
-          buffer, to_fp4_shape(columnwise_data_shapes[i]), data_offsets[i], torch::kUInt8));
-      columnwise_scale_list.emplace_back(
-          make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kUInt8));
-      amax_columnwise_list.emplace_back(
-          make_torch_view(buffer, std::vector<size_t>{1}, amax_offsets[i], torch::kFloat32));
+      columnwise_data_list.push_back(tensors[i]);
+      columnwise_scale_list.push_back(tensors[num_tensors + i]);
+      amax_columnwise_list.push_back(tensors[2 * num_tensors + i]);
     }
   }
 
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index 27d26d3dab..e51b3b0dfc 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -352,6 +352,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Partial cast from master weights for fp8 block scaling", py::arg("inp"), py::arg("out"),
         py::arg("scale"), py::arg("h"), py::arg("w"), py::arg("start_offset"),
         py::arg("block_len"), py::arg("out_dtype"), py::call_guard<py::gil_scoped_release>());
+
   // NVFP4 2D
   m.def("nvfp4_2d_compute_partial_amax",
         &transformer_engine::pytorch::nvfp4_2d_compute_partial_amax,
@@ -395,6 +396,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "In-place swizzle of grouped tensor scales for GEMM", py::arg("tensor"),
         py::arg("rowwise"), py::arg("columnwise"));
 
+  // Tensor allocation
+  m.def("bulk_allocate", &transformer_engine::pytorch::bulk_allocate,
+        "Allocate tensors backed by a single contiguous buffer",
+        py::arg("shapes"), py::arg("dtypes"), py::arg("alignments"));
+
   // attention kernels
   m.def("fa_prepare_fwd", &transformer_engine::pytorch::fa_prepare_fwd,
         "Prepare QKV for Flash Attention", py::call_guard<py::gil_scoped_release>());

From 81597a6b62556bd0cc208ef1d544809a57275bfb Mon Sep 17 00:00:00 2001
From: Tim Moon
Date: Sat, 18 Apr 2026 02:17:31 +0000
Subject: [PATCH 2/7] Bulk-allocate wgrads in grouped linear impls

Signed-off-by: Tim Moon
---
 .../pytorch/csrc/extensions/allocate.cpp            |  8 +++++---
 transformer_engine/pytorch/module/grouped_linear.py | 10 ++++++----
 .../pytorch/ops/basic/grouped_linear.py             | 13 ++++++-------
 .../pytorch/ops/fused/backward_grouped_mlp.py       |  7 +++++--
 4 files changed, 22 insertions(+), 16 deletions(-)

diff --git a/transformer_engine/pytorch/csrc/extensions/allocate.cpp b/transformer_engine/pytorch/csrc/extensions/allocate.cpp
index 9b588abbdd..66a6be61e1 100644
--- a/transformer_engine/pytorch/csrc/extensions/allocate.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/allocate.cpp
@@ -26,7 +26,9 @@ std::vector<at::Tensor> bulk_allocate(const std::vector<std::vector<size_t>> &sh
   std::vector<size_t> byte_sizes(n);
   std::vector<size_t> offsets(n);
   for (size_t i = 0; i < n; ++i) {
-    total_bytes = roundup(total_bytes, alignments[i]);
+    if (alignments[i] > 0) {
+      total_bytes = roundup(total_bytes, alignments[i]);
+    }
     offsets[i] = total_bytes;
     byte_sizes[i] = product(shapes[i]) * at::elementSize(dtypes[i]);
     total_bytes += byte_sizes[i];
@@ -44,8 +46,8 @@ std::vector<at::Tensor> bulk_allocate(const std::vector<std::vector<size_t>> &sh
   for (size_t i = 0; i < n; ++i) {
     shape_int64.assign(shapes[i].begin(), shapes[i].end());
     if (byte_sizes[i] == 0) {
-      // Work around problems with constructing an empty tensor with
-      // from_blob. Passing a null pointer fails because it checks
+      // Work around problems with from_blob when constructing an
+      // empty tensor. Passing a null pointer fails because it checks
       // that the pointer is on GPU. Passing a non-null pointer can
       // cause bugs in TE kernels.
       out.emplace_back(at::empty(shape_int64, at::device(at::kCUDA).dtype(dtypes[i])));
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index 720a274119..b85317ea75 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -484,10 +484,12 @@ def backward(
         if ctx.fuse_wgrad_accumulation:
             wgrad_list = main_grads
         else:
-            wgrad_list = [
-                torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device)
-                for w in weights
-            ]
+            weight_shape = list(weights[0].size())
+            wgrad_list = tex.bulk_allocate(
+                [weight_shape] * ctx.num_gemms,
+                [ctx.activation_dtype] * ctx.num_gemms,
+                [0] * ctx.num_gemms,  # alignment
+            )
 
         if ctx.save_original_input:
             inp = inputmats[0]
diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py
index a1d40a30ec..a5ccbe49f3 100644
--- a/transformer_engine/pytorch/ops/basic/grouped_linear.py
+++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py
@@ -955,13 +955,12 @@ def fuser_backward(
                     self.weight0, "overwrite_main_grad", False
                 )
             else:
-                weight_shape = (self.out_features, self.in_features)
-                for group_idx in range(num_groups):
-                    grad_weights[group_idx] = torch.empty(
-                        weight_shape,
-                        dtype=ctx.dtype,
-                        device=device,
-                    )
+                weight_shape = [self.out_features, self.in_features]
+                grad_weights = tex.bulk_allocate(
+                    [weight_shape] * num_groups,
+                    [ctx.dtype] * num_groups,
+                    [0] * num_groups,  # alignment
+                )
         else:
             accumulate_into_main_grad = False
diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py
index 3eb57c3563..2c729ebe8a 100644
--- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py
+++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py
@@ -199,8 +199,11 @@ def _compute_grad_params(
                 w_list[idx] = wp.main_grad
             accumulate_into_main_grad = not getattr(fc_op.weight0, "overwrite_main_grad", False)
         else:
-            for idx in range(num_groups):
-                w_list[idx] = torch.empty(weight_shape, dtype=dtype, device=device)
+            w_list = tex.bulk_allocate(
+                [list(weight_shape)] * num_groups,
+                [dtype] * num_groups,
+                [0] * num_groups,  # alignment
+            )
         wgrad_output = w_list
 
     if ctx.weight_requires_grad:
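In this patch an alignment of 0 requests tight packing: the new
`alignments[i] > 0` guard skips the roundup, so the per-GEMM wgrad views sit
back-to-back in a single allocation instead of coming from num_gemms separate
torch.empty calls. A hypothetical sketch of the call pattern (the group count
and shape are placeholders, not values from TE):

    num_gemms = 4  # placeholder
    wgrad_list = tex.bulk_allocate(
        [[4096, 4096]] * num_gemms,    # one wgrad shape per GEMM (illustrative)
        [torch.bfloat16] * num_gemms,
        [0] * num_gemms,               # 0 = no padding between tensors
    )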
From 84096c40ea2e584c0fcd79fef8b2d7e43fec57c5 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sat, 18 Apr 2026 02:33:17 +0000
Subject: [PATCH 3/7] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 transformer_engine/pytorch/csrc/extensions/allocate.cpp | 8 +++-----
 transformer_engine/pytorch/csrc/extensions/pybind.cpp   | 4 ++--
 2 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/transformer_engine/pytorch/csrc/extensions/allocate.cpp b/transformer_engine/pytorch/csrc/extensions/allocate.cpp
index 66a6be61e1..375a4a6332 100644
--- a/transformer_engine/pytorch/csrc/extensions/allocate.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/allocate.cpp
@@ -17,8 +17,8 @@ std::vector<at::Tensor> bulk_allocate(const std::vector<std::vector<size_t>> &sh
                                       const std::vector<size_t> &alignments) {
   const size_t n = shapes.size();
   NVTE_CHECK(dtypes.size() == n, "Got ", shapes.size(), " shapes and ", dtypes.size(), " dtypes.");
-  NVTE_CHECK(alignments.size() == n, "Got ", shapes.size(), " shapes and ",
-             alignments.size(), " alignments.");
+  NVTE_CHECK(alignments.size() == n, "Got ", shapes.size(), " shapes and ", alignments.size(),
+             " alignments.");
   if (n == 0) return {};
 
   // Compute per-tensor sizes and offsets
@@ -53,9 +53,7 @@ std::vector<at::Tensor> bulk_allocate(const std::vector<std::vector<size_t>> &sh
       out.emplace_back(at::empty(shape_int64, at::device(at::kCUDA).dtype(dtypes[i])));
     } else {
       out.emplace_back(at::from_blob(
-          data_ptr + offsets[i],
-          shape_int64,
-          [buffer](void *) {},  // Deleter keeps buffer alive
+          data_ptr + offsets[i], shape_int64, [buffer](void *) {},  // Deleter keeps buffer alive
           at::device(at::kCUDA).dtype(dtypes[i])));
     }
   }
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index e51b3b0dfc..2efa08e381 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -398,8 +398,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 
   // Tensor allocation
   m.def("bulk_allocate", &transformer_engine::pytorch::bulk_allocate,
-        "Allocate tensors backed by a single contiguous buffer",
-        py::arg("shapes"), py::arg("dtypes"), py::arg("alignments"));
+        "Allocate tensors backed by a single contiguous buffer", py::arg("shapes"),
+        py::arg("dtypes"), py::arg("alignments"));
 
   // attention kernels
   m.def("fa_prepare_fwd", &transformer_engine::pytorch::fa_prepare_fwd,

From d4ba30b8876a45efa3e31d139266f0f27fdc7a38 Mon Sep 17 00:00:00 2001
From: Tim Moon
Date: Wed, 22 Apr 2026 23:14:22 +0000
Subject: [PATCH 4/7] Apply review suggestions

Make device and alignment into optional arguments. Handle the case where the
base data_ptr is unaligned. Align grouped linear wgrad buffers to 256B.

Signed-off-by: Tim Moon
---
 transformer_engine/pytorch/csrc/extensions.h        |  3 +-
 .../pytorch/csrc/extensions/allocate.cpp            | 55 ++++++++++++------
 .../pytorch/csrc/extensions/cast.cpp                | 12 ++--
 .../pytorch/csrc/extensions/pybind.cpp              |  5 +-
 transformer_engine/pytorch/module/grouped_linear.py |  3 +-
 .../pytorch/ops/basic/grouped_linear.py             |  3 +-
 6 files changed, 54 insertions(+), 27 deletions(-)

diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index c1890a26c6..54a9772c6b 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -304,7 +304,8 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
 // Allocates tensors all backed by a single contiguous buffer.
 std::vector<at::Tensor> bulk_allocate(const std::vector<std::vector<size_t>> &shapes,
                                       const std::vector<at::ScalarType> &dtypes,
-                                      const std::vector<size_t> &alignments);
+                                      std::optional<at::Device> device = std::nullopt,
+                                      std::optional<std::vector<size_t>> alignments = std::nullopt);
 
 /***************************************************************************************************
  * Cast
diff --git a/transformer_engine/pytorch/csrc/extensions/allocate.cpp b/transformer_engine/pytorch/csrc/extensions/allocate.cpp
index 375a4a6332..def0187974 100644
--- a/transformer_engine/pytorch/csrc/extensions/allocate.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/allocate.cpp
@@ -14,30 +14,51 @@ namespace pytorch {
 
 std::vector<at::Tensor> bulk_allocate(const std::vector<std::vector<size_t>> &shapes,
                                       const std::vector<at::ScalarType> &dtypes,
-                                      const std::vector<size_t> &alignments) {
+                                      std::optional<at::Device> device,
+                                      std::optional<std::vector<size_t>> alignments) {
+  // Check shapes and dtypes
   const size_t n = shapes.size();
   NVTE_CHECK(dtypes.size() == n, "Got ", shapes.size(), " shapes and ", dtypes.size(), " dtypes.");
-  NVTE_CHECK(alignments.size() == n, "Got ", shapes.size(), " shapes and ", alignments.size(),
-             " alignments.");
+  NVTE_CHECK(!alignments || alignments->size() == n,
+             "Got ", shapes.size(), " shapes and ", alignments->size(), " alignments.");
+
+  // Return immediately if no tensors are needed
   if (n == 0) return {};
 
-  // Compute per-tensor sizes and offsets
-  size_t total_bytes = 0;
+  // Set defaults for optional arguments
+  if (!device) {
+    device = at::Device(at::kCUDA);
+  }
+  if (!alignments) {
+    alignments = std::vector<size_t>{};
+    alignments->reserve(n);
+    for (const auto &dtype : dtypes) {
+      alignments->push_back(c10::elementSize(dtype));
+    }
+  }
+
+  // Compute offsets in base buffer
   std::vector<size_t> byte_sizes(n);
   std::vector<size_t> offsets(n);
+  size_t base_byte_size = 0;
+  size_t base_alignment = 1;
   for (size_t i = 0; i < n; ++i) {
-    if (alignments[i] > 0) {
-      total_bytes = roundup(total_bytes, alignments[i]);
-    }
-    offsets[i] = total_bytes;
     byte_sizes[i] = product(shapes[i]) * at::elementSize(dtypes[i]);
-    total_bytes += byte_sizes[i];
+    offsets[i] = roundup(base_byte_size, (*alignments)[i]);
+    base_byte_size = offsets[i] + byte_sizes[i];
+    base_alignment = std::max(base_alignment, (*alignments)[i]);
+  }
+  if (base_alignment > 1) {
+    // Pad in case data pointer is not aligned
+    base_byte_size += base_alignment;
   }
 
-  // Single backing allocation
-  auto buffer = std::make_shared<at::Tensor>(
-      at::empty({static_cast<int64_t>(total_bytes)}, at::device(at::kCUDA).dtype(torch::kUInt8)));
-  uint8_t *data_ptr = buffer->data_ptr<uint8_t>();
+  // Allocate base buffer
+  auto base_buffer = std::make_shared<at::Tensor>(
+      at::empty({static_cast<int64_t>(base_byte_size)}, at::device(*device).dtype(torch::kUInt8)));
+  uint8_t *base_ptr = base_buffer->data_ptr<uint8_t>();
+  base_ptr = reinterpret_cast<uint8_t *>(roundup(reinterpret_cast<uintptr_t>(base_ptr),
+                                                 base_alignment));
 
   // Create views into the buffer
   std::vector<at::Tensor> out;
@@ -50,11 +71,11 @@ std::vector<at::Tensor> bulk_allocate(const std::vector<std::vector<size_t>> &sh
       // empty tensor. Passing a null pointer fails because it checks
       // that the pointer is on GPU. Passing a non-null pointer can
       // cause bugs in TE kernels.
-      out.emplace_back(at::empty(shape_int64, at::device(at::kCUDA).dtype(dtypes[i])));
+      out.emplace_back(at::empty(shape_int64, at::device(*device).dtype(dtypes[i])));
     } else {
       out.emplace_back(at::from_blob(
-          data_ptr + offsets[i], shape_int64, [buffer](void *) {},  // Deleter keeps buffer alive
-          at::device(at::kCUDA).dtype(dtypes[i])));
+          base_ptr + offsets[i], shape_int64, [base_buffer](void *) {},  // Deleter keeps buffer alive
+          at::device(*device).dtype(dtypes[i])));
     }
   }
   return out;
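The padding above reserves base_alignment extra bytes so the packed region can
start at the next aligned address even if the allocator hands back an
unaligned base pointer. Replaying that arithmetic with illustrative numbers:

    base_alignment = 256
    base_ptr = 0x7F0000000104  # hypothetical unaligned device address
    aligned = (base_ptr + base_alignment - 1) // base_alignment * base_alignment
    assert aligned % base_alignment == 0
    assert aligned - base_ptr < base_alignment  # the extra padding always suffices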
diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp
index 24924b4987..a55c386758 100644
--- a/transformer_engine/pytorch/csrc/extensions/cast.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp
@@ -529,7 +529,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp
     shapes.insert(shapes.end(), rowwise_scale_shapes.begin(), rowwise_scale_shapes.end());
     dtypes.insert(dtypes.end(), num_tensors, torch::kFloat32);
     alignments.insert(alignments.end(), num_tensors, 16);
-    auto tensors = bulk_allocate(shapes, dtypes, alignments);
+    auto tensors = bulk_allocate(shapes, dtypes, std::nullopt, alignments);
 
     // Split data and scale tensors
     for (size_t i = 0; i < num_tensors; ++i) {
@@ -560,7 +560,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp
     shapes.insert(shapes.end(), columnwise_scale_shapes.begin(), columnwise_scale_shapes.end());
     dtypes.insert(dtypes.end(), num_tensors, torch::kFloat32);
     alignments.insert(alignments.end(), num_tensors, 16);
-    auto tensors = bulk_allocate(shapes, dtypes, alignments);
+    auto tensors = bulk_allocate(shapes, dtypes, std::nullopt, alignments);
 
     // Split data and scale tensors
     for (size_t i = 0; i < num_tensors; ++i) {
@@ -639,7 +639,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
     shapes.insert(shapes.end(), rowwise_scale_shapes.begin(), rowwise_scale_shapes.end());
     dtypes.insert(dtypes.end(), num_tensors, torch::kUInt8);
     alignments.insert(alignments.end(), num_tensors, 16);
-    auto tensors = bulk_allocate(shapes, dtypes, alignments);
+    auto tensors = bulk_allocate(shapes, dtypes, std::nullopt, alignments);
 
     // Split data and scale tensors
     for (size_t i = 0; i < num_tensors; ++i) {
@@ -667,7 +667,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
     shapes.insert(shapes.end(), columnwise_scale_shapes.begin(), columnwise_scale_shapes.end());
     dtypes.insert(dtypes.end(), num_tensors, torch::kUInt8);
     alignments.insert(alignments.end(), num_tensors, 16);
-    auto tensors = bulk_allocate(shapes, dtypes, alignments);
+    auto tensors = bulk_allocate(shapes, dtypes, std::nullopt, alignments);
 
     // Split data and scale tensors
     for (size_t i = 0; i < num_tensors; ++i) {
@@ -778,7 +778,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> bulk_alloc
     shapes.insert(shapes.end(), num_tensors, std::vector<size_t>{1});
     dtypes.insert(dtypes.end(), num_tensors, torch::kFloat32);
     alignments.insert(alignments.end(), num_tensors, 16);
-    auto tensors = bulk_allocate(shapes, dtypes, alignments);
+    auto tensors = bulk_allocate(shapes, dtypes, std::nullopt, alignments);
 
     // Split data, scale, and amax tensors
     for (size_t i = 0; i < num_tensors; ++i) {
@@ -830,7 +830,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> bulk_alloc
     shapes.insert(shapes.end(), num_tensors, std::vector<size_t>{1});
     dtypes.insert(dtypes.end(), num_tensors, torch::kFloat32);
     alignments.insert(alignments.end(), num_tensors, 16);
-    auto tensors = bulk_allocate(shapes, dtypes, alignments);
+    auto tensors = bulk_allocate(shapes, dtypes, std::nullopt, alignments);
 
     // Split data, scale, and amax tensors
     for (size_t i = 0; i < num_tensors; ++i) {
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index 2efa08e381..19aa44fe19 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -399,7 +399,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // Tensor allocation
   m.def("bulk_allocate", &transformer_engine::pytorch::bulk_allocate,
         "Allocate tensors backed by a single contiguous buffer", py::arg("shapes"),
-        py::arg("dtypes"), py::arg("alignments"));
+        py::arg("dtypes"),
+        py::arg("device") = py::none(),
+        py::arg("alignments") = py::none(),
+        py::call_guard<py::gil_scoped_release>());
 
   // attention kernels
   m.def("fa_prepare_fwd", &transformer_engine::pytorch::fa_prepare_fwd,
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index b85317ea75..f56549cf6b 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -488,7 +488,8 @@ def backward(
             wgrad_list = tex.bulk_allocate(
                 [weight_shape] * ctx.num_gemms,
                 [ctx.activation_dtype] * ctx.num_gemms,
-                [0] * ctx.num_gemms,  # alignment
+                ctx.device,
+                [256] * ctx.num_gemms,  # alignment
             )
 
         if ctx.save_original_input:
diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py
index a5ccbe49f3..8783a68f9d 100644
--- a/transformer_engine/pytorch/ops/basic/grouped_linear.py
+++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py
@@ -959,7 +959,8 @@ def fuser_backward(
                 grad_weights = tex.bulk_allocate(
                     [weight_shape] * num_groups,
                     [ctx.dtype] * num_groups,
-                    [0] * num_groups,  # alignment
+                    device,
+                    [256] * num_groups,  # alignment
                 )
         else:
             accumulate_into_main_grad = False

From f73c4f8653f5e5b0d2104385a596e138e1dccc8b Mon Sep 17 00:00:00 2001
From: Tim Moon
Date: Wed, 22 Apr 2026 23:27:50 +0000
Subject: [PATCH 5/7] Nits from Claude

Signed-off-by: Tim Moon
---
 transformer_engine/pytorch/csrc/extensions/allocate.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/transformer_engine/pytorch/csrc/extensions/allocate.cpp b/transformer_engine/pytorch/csrc/extensions/allocate.cpp
index def0187974..2a60807dbb 100644
--- a/transformer_engine/pytorch/csrc/extensions/allocate.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/allocate.cpp
@@ -60,7 +60,7 @@ std::vector<at::Tensor> bulk_allocate(const std::vector<std::vector<size_t>> &sh
   base_ptr = reinterpret_cast<uint8_t *>(roundup(reinterpret_cast<uintptr_t>(base_ptr),
                                                  base_alignment));
 
-  // Create views into the buffer
+  // Create views into base buffer
   std::vector<at::Tensor> out;
   out.reserve(n);
   std::vector<int64_t> shape_int64;
@@ -73,8 +73,9 @@ std::vector<at::Tensor> bulk_allocate(const std::vector<std::vector<size_t>> &sh
       // cause bugs in TE kernels.
       out.emplace_back(at::empty(shape_int64, at::device(*device).dtype(dtypes[i])));
     } else {
+      // Construct tensor with custom deleter to keep base buffer alive
       out.emplace_back(at::from_blob(
-          base_ptr + offsets[i], shape_int64, [base_buffer](void *) {},  // Deleter keeps buffer alive
+          base_ptr + offsets[i], shape_int64, [base_buffer](void *) {},
           at::device(*device).dtype(dtypes[i])));
     }
   }

From e25ba48ab6458753b73650e06b3ba30950dfbc60 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 22 Apr 2026 23:29:19 +0000
Subject: [PATCH 6/7] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 transformer_engine/pytorch/csrc/extensions/allocate.cpp | 8 ++++----
 transformer_engine/pytorch/csrc/extensions/pybind.cpp   | 4 +---
 2 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/transformer_engine/pytorch/csrc/extensions/allocate.cpp b/transformer_engine/pytorch/csrc/extensions/allocate.cpp
index 2a60807dbb..8050dca045 100644
--- a/transformer_engine/pytorch/csrc/extensions/allocate.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/allocate.cpp
@@ -19,8 +19,8 @@ std::vector<at::Tensor> bulk_allocate(const std::vector<std::vector<size_t>> &sh
   // Check shapes and dtypes
   const size_t n = shapes.size();
   NVTE_CHECK(dtypes.size() == n, "Got ", shapes.size(), " shapes and ", dtypes.size(), " dtypes.");
-  NVTE_CHECK(!alignments || alignments->size() == n,
-             "Got ", shapes.size(), " shapes and ", alignments->size(), " alignments.");
+  NVTE_CHECK(!alignments || alignments->size() == n, "Got ", shapes.size(), " shapes and ",
+             alignments->size(), " alignments.");
 
   // Return immediately if no tensors are needed
   if (n == 0) return {};
@@ -57,8 +57,8 @@ std::vector<at::Tensor> bulk_allocate(const std::vector<std::vector<size_t>> &sh
   auto base_buffer = std::make_shared<at::Tensor>(
       at::empty({static_cast<int64_t>(base_byte_size)}, at::device(*device).dtype(torch::kUInt8)));
   uint8_t *base_ptr = base_buffer->data_ptr<uint8_t>();
-  base_ptr = reinterpret_cast<uint8_t *>(roundup(reinterpret_cast<uintptr_t>(base_ptr),
-                                                 base_alignment));
+  base_ptr =
+      reinterpret_cast<uint8_t *>(roundup(reinterpret_cast<uintptr_t>(base_ptr), base_alignment));
 
   // Create views into base buffer
   std::vector<at::Tensor> out;
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index 133522a7ef..a813f3119d 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -408,9 +408,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // Tensor allocation
   m.def("bulk_allocate", &transformer_engine::pytorch::bulk_allocate,
         "Allocate tensors backed by a single contiguous buffer", py::arg("shapes"),
-        py::arg("dtypes"),
-        py::arg("device") = py::none(),
-        py::arg("alignments") = py::none(),
+        py::arg("dtypes"), py::arg("device") = py::none(), py::arg("alignments") = py::none(),
         py::call_guard<py::gil_scoped_release>());
 
   // attention kernels

From 16806b40b49fa1bf8e980a0c01c502760056d436 Mon Sep 17 00:00:00 2001
From: Tim Moon
Date: Wed, 22 Apr 2026 23:48:40 +0000
Subject: [PATCH 7/7] Fix incorrect call to `bulk_allocate`

Signed-off-by: Tim Moon
---
 transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py
index e62589946d..5e6ae502a5 100644
--- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py
+++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py
@@ -204,7 +204,8 @@ def _compute_grad_params(
             w_list = tex.bulk_allocate(
                 [list(weight_shape)] * num_groups,
                 [dtype] * num_groups,
-                [0] * num_groups,  # alignment
+                device,
+                [256] * num_groups,  # alignment
             )
         wgrad_output = w_list