diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index 4a2ea7412b..9b10a9c5a4 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -309,6 +309,16 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
                                     py::object ln_out, py::handle quantizer, DType otype,
                                     const int sm_margin, const bool zero_centered_gamma);
 
+/***************************************************************************************************
+ * Memory allocation
+ **************************************************************************************************/
+
+// Allocates tensors all backed by a single contiguous buffer.
+std::vector<at::Tensor> bulk_allocate(const std::vector<std::vector<size_t>> &shapes,
+                                      const std::vector<at::ScalarType> &dtypes,
+                                      std::optional<at::Device> device = std::nullopt,
+                                      std::optional<std::vector<size_t>> alignments = std::nullopt);
+
 /***************************************************************************************************
  * Cast
  **************************************************************************************************/
diff --git a/transformer_engine/pytorch/csrc/extensions/allocate.cpp b/transformer_engine/pytorch/csrc/extensions/allocate.cpp
new file mode 100644
index 0000000000..8050dca045
--- /dev/null
+++ b/transformer_engine/pytorch/csrc/extensions/allocate.cpp
@@ -0,0 +1,86 @@
+/*************************************************************************
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include <optional>
+#include <vector>
+
+#include "../extensions.h"
+
+namespace transformer_engine {
+namespace pytorch {
+
+std::vector<at::Tensor> bulk_allocate(const std::vector<std::vector<size_t>> &shapes,
+                                      const std::vector<at::ScalarType> &dtypes,
+                                      std::optional<at::Device> device,
+                                      std::optional<std::vector<size_t>> alignments) {
+  // Check shapes and dtypes
+  const size_t n = shapes.size();
+  NVTE_CHECK(dtypes.size() == n, "Got ", shapes.size(), " shapes and ", dtypes.size(), " dtypes.");
+  NVTE_CHECK(!alignments || alignments->size() == n, "Got ", shapes.size(), " shapes and ",
+             alignments->size(), " alignments.");
+
+  // Return immediately if no tensors are needed
+  if (n == 0) return {};
+
+  // Set defaults for optional arguments
+  if (!device) {
+    device = at::Device(at::kCUDA);
+  }
+  if (!alignments) {
+    alignments = std::vector<size_t>{};
+    alignments->reserve(n);
+    for (const auto &dtype : dtypes) {
+      alignments->push_back(c10::elementSize(dtype));
+    }
+  }
+
+  // Compute offsets in base buffer
+  std::vector<size_t> byte_sizes(n);
+  std::vector<size_t> offsets(n);
+  size_t base_byte_size = 0;
+  size_t base_alignment = 1;
+  for (size_t i = 0; i < n; ++i) {
+    byte_sizes[i] = product(shapes[i]) * at::elementSize(dtypes[i]);
+    offsets[i] = roundup(base_byte_size, (*alignments)[i]);
+    base_byte_size = offsets[i] + byte_sizes[i];
+    base_alignment = std::max(base_alignment, (*alignments)[i]);
+  }
+  if (base_alignment > 1) {
+    // Pad in case data pointer is not aligned
+    base_byte_size += base_alignment;
+  }
+
+  // Allocate base buffer
+  auto base_buffer = std::make_shared<at::Tensor>(
+      at::empty({static_cast<int64_t>(base_byte_size)}, at::device(*device).dtype(torch::kUInt8)));
+  uint8_t *base_ptr = base_buffer->data_ptr<uint8_t>();
+  base_ptr =
+      reinterpret_cast<uint8_t *>(roundup(reinterpret_cast<uintptr_t>(base_ptr), base_alignment));
+
+  // Create views into base buffer
+  std::vector<at::Tensor> out;
+  out.reserve(n);
+  std::vector<int64_t> shape_int64;
+  for (size_t i = 0; i < n; ++i) {
+    shape_int64.assign(shapes[i].begin(), shapes[i].end());
+    if (byte_sizes[i] == 0) {
+      // Work around problems with from_blob when constructing an
+      // empty tensor. Passing a null pointer fails because it checks
+      // that the pointer is on GPU. Passing a non-null pointer can
+      // cause bugs in TE kernels.
+      out.emplace_back(at::empty(shape_int64, at::device(*device).dtype(dtypes[i])));
+    } else {
+      // Construct tensor with custom deleter to keep base buffer alive
+      out.emplace_back(at::from_blob(
+          base_ptr + offsets[i], shape_int64, [base_buffer](void *) {},
+          at::device(*device).dtype(dtypes[i])));
+    }
+  }
+  return out;
+}
+
+}  // namespace pytorch
+}  // namespace transformer_engine
diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp
index 5fb162c72d..a55c386758 100644
--- a/transformer_engine/pytorch/csrc/extensions/cast.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp
@@ -511,60 +511,30 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp
   const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode();
   const auto is_2D_scaled = scaling_mode == NVTE_BLOCK_SCALING_2D;
   const auto fp8_dtype = quantizer_cpp_list[0]->dtype;
-  constexpr size_t fp8_elem_size = 1;
-  constexpr size_t scale_elem_size = 4;
-
-  // Helper function to construct tensor view
-  // Note: Deleter holds a shared_ptr for the buffer, so the buffer
-  // will survive until all views are deleted.
-  auto make_torch_view = [](std::shared_ptr<at::Tensor> &buffer, const std::vector<size_t> &shape,
-                            size_t offset, at::ScalarType dtype) -> at::Tensor {
-    std::vector<int64_t> shape_int64(shape.begin(), shape.end());
-    bool is_empty_shape = product(shape) == 0;
-    if (buffer->data_ptr() == nullptr || is_empty_shape) {
-      return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype));
-    }
-    return at::from_blob(
-        buffer->data_ptr<uint8_t>() + offset, shape_int64,
-        [buffer](void *) {},  // deleter holds shared_ptr
-        at::device(at::kCUDA).dtype(dtype));
-  };
 
   // Allocate row-wise data
   std::vector<at::Tensor> rowwise_data_list, rowwise_scale_list;
   std::vector<std::vector<size_t>> rowwise_data_shapes, rowwise_scale_shapes;
   if (rowwise_usage) {
-    // Tensor sizes
     for (size_t i = 0; i < num_tensors; ++i) {
      rowwise_data_shapes.emplace_back(shape_list[i]);
      rowwise_scale_shapes.emplace_back(
          quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false));
     }
 
-    // Offsets in full buffer
-    size_t buffer_size = 0;
-    std::vector<size_t> data_offsets, scale_offsets;
-    for (size_t i = 0; i < num_tensors; ++i) {
-      buffer_size = roundup(buffer_size, 256);  // align to 256B
-      data_offsets.push_back(buffer_size);
-      buffer_size += product(rowwise_data_shapes[i]) * fp8_elem_size;
-    }
-    for (size_t i = 0; i < num_tensors; ++i) {
-      buffer_size = roundup(buffer_size, 16);  // align to 16B
-      scale_offsets.push_back(buffer_size);
-      buffer_size += product(rowwise_scale_shapes[i]) * scale_elem_size;
-    }
-
-    // Allocate full buffer
-    auto buffer = std::make_shared<at::Tensor>(
-        at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
+    // Bulk-allocate data and scale tensors
+    std::vector<std::vector<size_t>> shapes = rowwise_data_shapes;
+    std::vector<at::ScalarType> dtypes(num_tensors, torch::kUInt8);
+    std::vector<size_t> alignments(num_tensors, 256);
+    shapes.insert(shapes.end(), rowwise_scale_shapes.begin(), rowwise_scale_shapes.end());
+    dtypes.insert(dtypes.end(), num_tensors, torch::kFloat32);
+    alignments.insert(alignments.end(), num_tensors, 16);
+    auto tensors = bulk_allocate(shapes, dtypes, std::nullopt, alignments);
 
-    // Construct tensor views
+    // Split data and scale tensors
     for (size_t i = 0; i < num_tensors; ++i) {
-      rowwise_data_list.emplace_back(
-          make_torch_view(buffer, rowwise_data_shapes[i], data_offsets[i], torch::kUInt8));
-      rowwise_scale_list.emplace_back(
-          make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kFloat32));
+      rowwise_data_list.emplace_back(std::move(tensors[i]));
+      rowwise_scale_list.emplace_back(std::move(tensors[num_tensors + i]));
     }
   }
 
@@ -572,7 +542,6 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp
   std::vector<at::Tensor> columnwise_data_list, columnwise_scale_list;
   std::vector<std::vector<size_t>> columnwise_data_shapes, columnwise_scale_shapes;
   if (columnwise_usage) {
-    // Tensor sizes
     for (size_t i = 0; i < num_tensors; ++i) {
       columnwise_data_shapes.emplace_back();
       auto &shape = columnwise_data_shapes.back();
@@ -584,30 +553,19 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp
           quantizer_cpp_list[i]->get_scale_shape(shape_list[i], true));
     }
 
-    // Offsets in full buffer
-    size_t buffer_size = 0;
-    std::vector<size_t> data_offsets, scale_offsets;
-    for (size_t i = 0; i < num_tensors; ++i) {
-      buffer_size = roundup(buffer_size, 256);  // align to 256B
-      data_offsets.push_back(buffer_size);
-      buffer_size += product(columnwise_data_shapes[i]) * fp8_elem_size;
-    }
-    for (size_t i = 0; i < num_tensors; ++i) {
-      buffer_size = roundup(buffer_size, 16);  // align to 16B
-      scale_offsets.push_back(buffer_size);
-      buffer_size += product(columnwise_scale_shapes[i]) * scale_elem_size;
-    }
-
-    // Allocate full buffer
-    auto buffer = std::make_shared<at::Tensor>(
-        at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
+    // Bulk-allocate data and scale tensors
+    std::vector<std::vector<size_t>> shapes = columnwise_data_shapes;
+    std::vector<at::ScalarType> dtypes(num_tensors, torch::kUInt8);
+    std::vector<size_t> alignments(num_tensors, 256);
+    shapes.insert(shapes.end(), columnwise_scale_shapes.begin(), columnwise_scale_shapes.end());
+    dtypes.insert(dtypes.end(), num_tensors, torch::kFloat32);
+    alignments.insert(alignments.end(), num_tensors, 16);
+    auto tensors = bulk_allocate(shapes, dtypes, std::nullopt, alignments);
 
-    // Construct tensor views
+    // Split data and scale tensors
     for (size_t i = 0; i < num_tensors; ++i) {
-      columnwise_data_list.emplace_back(
-          make_torch_view(buffer, columnwise_data_shapes[i], data_offsets[i], torch::kUInt8));
-      columnwise_scale_list.emplace_back(
-          make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kFloat32));
+      columnwise_data_list.push_back(tensors[i]);
+      columnwise_scale_list.push_back(tensors[num_tensors + i]);
     }
   }
 
@@ -664,60 +622,29 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
   const auto fp8_dtype = quantizer_cpp_list[0]->dtype;
   const bool with_gemm_swizzled_scales = quantizer_cpp_list[0]->optimize_for_gemm;
 
-  constexpr size_t fp8_elem_size = 1;
-  constexpr size_t scale_elem_size = 1;
-
-  // Helper function to construct tensor view
-  // Note: Deleter holds a shared_ptr for the buffer, so the buffer
-  // will survive until all views are deleted.
-  auto make_torch_view = [](std::shared_ptr<at::Tensor> &buffer, const std::vector<size_t> &shape,
-                            size_t offset, at::ScalarType dtype) -> at::Tensor {
-    std::vector<int64_t> shape_int64(shape.begin(), shape.end());
-    bool is_empty_shape = product(shape) == 0;
-    if (buffer->data_ptr() == nullptr || is_empty_shape) {
-      return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype));
-    }
-    return at::from_blob(
-        buffer->data_ptr<uint8_t>() + offset, shape_int64,
-        [buffer](void *) {},  // deleter holds shared_ptr
-        at::device(at::kCUDA).dtype(dtype));
-  };
-
   // Allocate row-wise data
   std::vector<at::Tensor> rowwise_data_list, rowwise_scale_list;
   std::vector<std::vector<size_t>> rowwise_data_shapes, rowwise_scale_shapes;
   if (rowwise_usage) {
-    // Tensor sizes
     for (size_t i = 0; i < num_tensors; ++i) {
       rowwise_data_shapes.emplace_back(shape_list[i]);
       rowwise_scale_shapes.emplace_back(
           quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false));
     }
 
-    // Offsets in full buffer
-    size_t buffer_size = 0;
-    std::vector<size_t> data_offsets, scale_offsets;
-    for (size_t i = 0; i < num_tensors; ++i) {
-      buffer_size = roundup(buffer_size, 256);  // align to 256B
-      data_offsets.push_back(buffer_size);
-      buffer_size += product(rowwise_data_shapes[i]) * fp8_elem_size;
-    }
-    for (size_t i = 0; i < num_tensors; ++i) {
-      buffer_size = roundup(buffer_size, 16);  // align to 16B
-      scale_offsets.push_back(buffer_size);
-      buffer_size += product(rowwise_scale_shapes[i]) * scale_elem_size;
-    }
+    // Bulk-allocate data and scale tensors
+    std::vector<std::vector<size_t>> shapes = rowwise_data_shapes;
+    std::vector<at::ScalarType> dtypes(num_tensors, torch::kUInt8);
+    std::vector<size_t> alignments(num_tensors, 256);
+    shapes.insert(shapes.end(), rowwise_scale_shapes.begin(), rowwise_scale_shapes.end());
+    dtypes.insert(dtypes.end(), num_tensors, torch::kUInt8);
+    alignments.insert(alignments.end(), num_tensors, 16);
+    auto tensors = bulk_allocate(shapes, dtypes, std::nullopt, alignments);
 
-    // Allocate full buffer
-    auto buffer = std::make_shared<at::Tensor>(
-        at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
-
-    // Construct tensor views
+    // Split data and scale tensors
     for (size_t i = 0; i < num_tensors; ++i) {
-      rowwise_data_list.emplace_back(
-          make_torch_view(buffer, rowwise_data_shapes[i], data_offsets[i], torch::kUInt8));
-      rowwise_scale_list.emplace_back(
-          make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8));
+      rowwise_data_list.emplace_back(std::move(tensors[i]));
+      rowwise_scale_list.emplace_back(std::move(tensors[num_tensors + i]));
     }
   }
 
@@ -725,7 +652,6 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
   std::vector<at::Tensor> columnwise_data_list, columnwise_scale_list;
   std::vector<std::vector<size_t>> columnwise_data_shapes, columnwise_scale_shapes;
   if (columnwise_usage) {
-    // Tensor sizes
     for (size_t i = 0; i < num_tensors; ++i) {
       // For MXFP8, the columnwise data doesn't need transpose
       // because of TN, NT, NN layout support in SM100
@@ -734,30 +660,19 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
           quantizer_cpp_list[i]->get_scale_shape(shape_list[i], true));
     }
 
-    // Offsets in full buffer
-    size_t buffer_size = 0;
-    std::vector<size_t> data_offsets, scale_offsets;
-    for (size_t i = 0; i < num_tensors; ++i) {
-      buffer_size = roundup(buffer_size, 256);  // align to 256B
-      data_offsets.push_back(buffer_size);
-      buffer_size += product(columnwise_data_shapes[i]) * fp8_elem_size;
-    }
-    for (size_t i = 0; i < num_tensors; ++i) {
-      buffer_size = roundup(buffer_size, 16);  // align to 16B
-      scale_offsets.push_back(buffer_size);
-      buffer_size += product(columnwise_scale_shapes[i]) * scale_elem_size;
-    }
+    // Bulk-allocate data and scale tensors
+    std::vector<std::vector<size_t>> shapes = columnwise_data_shapes;
+    std::vector<at::ScalarType> dtypes(num_tensors, torch::kUInt8);
+    std::vector<size_t> alignments(num_tensors, 256);
+    shapes.insert(shapes.end(), columnwise_scale_shapes.begin(), columnwise_scale_shapes.end());
+    dtypes.insert(dtypes.end(), num_tensors, torch::kUInt8);
+    alignments.insert(alignments.end(), num_tensors, 16);
+    auto tensors = bulk_allocate(shapes, dtypes, std::nullopt, alignments);
 
-    // Allocate full buffer
-    auto buffer = std::make_shared<at::Tensor>(
-        at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
-
-    // Construct tensor views
+    // Split data and scale tensors
     for (size_t i = 0; i < num_tensors; ++i) {
-      columnwise_data_list.emplace_back(
-          make_torch_view(buffer, columnwise_data_shapes[i], data_offsets[i], torch::kUInt8));
-      columnwise_scale_list.emplace_back(
-          make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kUInt8));
+      columnwise_data_list.push_back(tensors[i]);
+      columnwise_scale_list.push_back(tensors[num_tensors + i]);
     }
   }
 
@@ -818,23 +733,6 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> bulk_alloc
   const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode();
   const auto fp4_dtype = quantizer_cpp_list[0]->dtype;
   const bool with_gemm_swizzled_scales = false;  /// TODO (tmoon) Enable based on optimize_for_gemm;
-  constexpr size_t scale_elem_size = 1;
-
-  // Helper function to construct tensor view
-  // Note: Deleter holds a shared_ptr for the buffer, so the buffer
-  // will survive until all views are deleted.
-  auto make_torch_view = [](std::shared_ptr<at::Tensor> &buffer, const std::vector<size_t> &shape,
-                            size_t offset, at::ScalarType dtype) -> at::Tensor {
-    std::vector<int64_t> shape_int64(shape.begin(), shape.end());
-    bool is_empty_shape = product(shape) == 0;
-    if (buffer->data_ptr() == nullptr || is_empty_shape) {
-      return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype));
-    }
-    return at::from_blob(
-        buffer->data_ptr<uint8_t>() + offset, shape_int64,
-        [buffer](void *) {},  // deleter holds shared_ptr
-        at::device(at::kCUDA).dtype(dtype));
-  };
 
   // Lambda function for converting std::vector shape to NVFP4 shape (last dim divided by 2)
   auto to_fp4_shape = [](const std::vector<size_t> &shape) {
@@ -849,54 +747,44 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> bulk_alloc
   std::vector<at::Tensor> rowwise_data_list, rowwise_scale_list, amax_rowwise_list;
   std::vector<std::vector<size_t>> rowwise_data_shapes, rowwise_scale_shapes;
   if (rowwise_usage) {
-    // Tensor sizes
     for (size_t i = 0; i < num_tensors; ++i) {
       rowwise_data_shapes.emplace_back(shape_list[i]);
       rowwise_scale_shapes.emplace_back(
          quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false));
     }
 
-    // Offsets in full buffer
-    size_t buffer_size = 0;
-    std::vector<size_t> data_offsets, scale_offsets, amax_offsets;
+    // Check whether data and scales can be packed in contiguous
+    // buffer. Amaxes are not contiguous since they are aligned to
+    // 16B.
     for (size_t i = 0; i < num_tensors; ++i) {
-      // FP4 data is aligned to 256B
-      const auto offset = roundup(buffer_size, 256);
-      if (offset != buffer_size) {
+      if (product(rowwise_data_shapes[i]) / 2 % 256 != 0) {
         contiguous_data_and_scale = false;
       }
-      data_offsets.push_back(offset);
-      buffer_size = offset + (product(rowwise_data_shapes[i]) + 1) / 2;
-    }
-    for (size_t i = 0; i < num_tensors; ++i) {
-      // Scales are aligned to 16B
-      const auto offset = roundup(buffer_size, 16);
-      if (offset != buffer_size) {
+      if (product(rowwise_scale_shapes[i]) % 16 != 0) {
         contiguous_data_and_scale = false;
       }
-      scale_offsets.push_back(offset);
-      buffer_size = offset + product(rowwise_scale_shapes[i]) * scale_elem_size;
     }
+
+    // Bulk-allocate data, scale, and amax tensors
+    std::vector<std::vector<size_t>> shapes;
     for (size_t i = 0; i < num_tensors; ++i) {
-      // Amaxes (FP32) are aligned to 16B
-      // Note: Multi-quantize kernel does not require contiguous amaxes.
-      const auto offset = roundup(buffer_size, 16);
-      amax_offsets.push_back(offset);
-      buffer_size = offset + 4;
+      shapes.emplace_back(to_fp4_shape(rowwise_data_shapes[i]));
     }
-
-    // Allocate full buffer
-    auto buffer = std::make_shared<at::Tensor>(
-        at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
-
-    // Construct tensor views
+    std::vector<at::ScalarType> dtypes(num_tensors, torch::kUInt8);
+    std::vector<size_t> alignments(num_tensors, 256);
+    shapes.insert(shapes.end(), rowwise_scale_shapes.begin(), rowwise_scale_shapes.end());
+    dtypes.insert(dtypes.end(), num_tensors, torch::kUInt8);
+    alignments.insert(alignments.end(), num_tensors, 16);
+    shapes.insert(shapes.end(), num_tensors, std::vector<size_t>{1});
+    dtypes.insert(dtypes.end(), num_tensors, torch::kFloat32);
+    alignments.insert(alignments.end(), num_tensors, 16);
+    auto tensors = bulk_allocate(shapes, dtypes, std::nullopt, alignments);
+
+    // Split data, scale, and amax tensors
     for (size_t i = 0; i < num_tensors; ++i) {
-      rowwise_data_list.emplace_back(make_torch_view(buffer, to_fp4_shape(rowwise_data_shapes[i]),
-                                                     data_offsets[i], torch::kUInt8));
-      rowwise_scale_list.emplace_back(
-          make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8));
-      amax_rowwise_list.emplace_back(
-          make_torch_view(buffer, std::vector<size_t>{1}, amax_offsets[i], torch::kFloat32));
+      rowwise_data_list.push_back(tensors[i]);
+      rowwise_scale_list.push_back(tensors[num_tensors + i]);
+      amax_rowwise_list.push_back(tensors[2 * num_tensors + i]);
     }
   }
 
@@ -904,7 +792,6 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> bulk_alloc
   std::vector<at::Tensor> columnwise_data_list, columnwise_scale_list, amax_columnwise_list;
   std::vector<std::vector<size_t>> columnwise_data_shapes, columnwise_scale_shapes;
   if (columnwise_usage) {
-    // Tensor sizes
     for (size_t i = 0; i < num_tensors; ++i) {
       // push the transposed shape into NVFP4 columnwise shape
      // NVFP4 on SM100 is TN only
@@ -918,47 +805,38 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> bulk_alloc
           quantizer_cpp_list[i]->get_scale_shape(shape_list[i], true));
     }
 
-    // Offsets in full buffer
-    size_t buffer_size = 0;
-    std::vector<size_t> data_offsets, scale_offsets, amax_offsets;
+    // Check whether data and scales can be packed in contiguous
+    // buffer. Amaxes are not contiguous since they are aligned to
+    // 16B.
     for (size_t i = 0; i < num_tensors; ++i) {
-      // FP4 data is aligned to 256B
-      const auto offset = roundup(buffer_size, 256);
-      if (offset != buffer_size) {
+      if (product(columnwise_data_shapes[i]) / 2 % 256 != 0) {
         contiguous_data_and_scale = false;
       }
-      data_offsets.push_back(offset);
-      buffer_size = offset + (product(columnwise_data_shapes[i]) + 1) / 2;
-    }
-    for (size_t i = 0; i < num_tensors; ++i) {
-      // Scales are aligned to 16B
-      const auto offset = roundup(buffer_size, 16);
-      if (offset != buffer_size) {
+      if (product(columnwise_scale_shapes[i]) % 16 != 0) {
         contiguous_data_and_scale = false;
       }
-      scale_offsets.push_back(offset);
-      buffer_size = offset + product(columnwise_scale_shapes[i]) * scale_elem_size;
     }
+
+    // Bulk-allocate data, scale, and amax tensors
+    std::vector<std::vector<size_t>> shapes;
     for (size_t i = 0; i < num_tensors; ++i) {
-      // Amaxes (FP32) are aligned to 16B
-      // Note: Multi-quantize kernel does not require contiguous amaxes.
-      const auto offset = roundup(buffer_size, 16);
-      amax_offsets.push_back(offset);
-      buffer_size = offset + 4;
+      shapes.emplace_back(to_fp4_shape(columnwise_data_shapes[i]));
     }
-
-    // Allocate full buffer
-    auto buffer = std::make_shared<at::Tensor>(
-        at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
-
-    // Construct tensor views
+    std::vector<at::ScalarType> dtypes(num_tensors, torch::kUInt8);
+    std::vector<size_t> alignments(num_tensors, 256);
+    shapes.insert(shapes.end(), columnwise_scale_shapes.begin(), columnwise_scale_shapes.end());
+    dtypes.insert(dtypes.end(), num_tensors, torch::kUInt8);
+    alignments.insert(alignments.end(), num_tensors, 16);
+    shapes.insert(shapes.end(), num_tensors, std::vector<size_t>{1});
+    dtypes.insert(dtypes.end(), num_tensors, torch::kFloat32);
+    alignments.insert(alignments.end(), num_tensors, 16);
+    auto tensors = bulk_allocate(shapes, dtypes, std::nullopt, alignments);
+
+    // Split data, scale, and amax tensors
     for (size_t i = 0; i < num_tensors; ++i) {
-      columnwise_data_list.emplace_back(make_torch_view(
-          buffer, to_fp4_shape(columnwise_data_shapes[i]), data_offsets[i], torch::kUInt8));
-      columnwise_scale_list.emplace_back(
-          make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kUInt8));
-      amax_columnwise_list.emplace_back(
-          make_torch_view(buffer, std::vector<size_t>{1}, amax_offsets[i], torch::kFloat32));
+      columnwise_data_list.push_back(tensors[i]);
+      columnwise_scale_list.push_back(tensors[num_tensors + i]);
+      amax_columnwise_list.push_back(tensors[2 * num_tensors + i]);
     }
   }
 
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index eb7576d905..a813f3119d 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -352,6 +352,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Partial cast from master weights for fp8 block scaling", py::arg("inp"), py::arg("out"),
         py::arg("scale"), py::arg("h"), py::arg("w"), py::arg("start_offset"),
         py::arg("block_len"), py::arg("out_dtype"), py::call_guard<py::gil_scoped_release>());
+
   // NVFP4 2D
   m.def("nvfp4_2d_compute_partial_amax",
         &transformer_engine::pytorch::nvfp4_2d_compute_partial_amax,
@@ -404,6 +405,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "In-place swizzle of grouped tensor scales for GEMM", py::arg("tensor"),
         py::arg("rowwise"), py::arg("columnwise"));
 
+  // Tensor allocation
+  m.def("bulk_allocate", &transformer_engine::pytorch::bulk_allocate,
+        "Allocate tensors backed by a single contiguous buffer", py::arg("shapes"),
+        py::arg("dtypes"), py::arg("device") = py::none(), py::arg("alignments") = py::none(),
+        py::call_guard<py::gil_scoped_release>());
+
   // attention kernels
   m.def("fa_prepare_fwd", &transformer_engine::pytorch::fa_prepare_fwd,
         "Prepare QKV for Flash Attention", py::call_guard<py::gil_scoped_release>());
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index 720a274119..f56549cf6b 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -484,10 +484,13 @@ def backward(
         if ctx.fuse_wgrad_accumulation:
             wgrad_list = main_grads
         else:
-            wgrad_list = [
-                torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device)
-                for w in weights
-            ]
+            weight_shape = list(weights[0].size())
+            wgrad_list = tex.bulk_allocate(
+                [weight_shape] * ctx.num_gemms,
+                [ctx.activation_dtype] * ctx.num_gemms,
+                ctx.device,
+                [256] * ctx.num_gemms,  # alignment
+            )
 
         if ctx.save_original_input:
             inp = inputmats[0]
diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py
index fe5997a71e..9a66527139 100644
--- a/transformer_engine/pytorch/ops/basic/grouped_linear.py
+++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py
@@ -956,13 +956,13 @@ def fuser_backward(
                     self.weight0, "overwrite_main_grad", False
                 )
             else:
-                weight_shape = (self.out_features, self.in_features)
-                for group_idx in range(num_groups):
-                    grad_weights[group_idx] = torch.empty(
-                        weight_shape,
-                        dtype=ctx.dtype,
-                        device=device,
-                    )
+                weight_shape = [self.out_features, self.in_features]
+                grad_weights = tex.bulk_allocate(
+                    [weight_shape] * num_groups,
+                    [ctx.dtype] * num_groups,
+                    device,
+                    [256] * num_groups,  # alignment
+                )
         else:
             accumulate_into_main_grad = False
 
diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py
index aca49e9866..5e6ae502a5 100644
--- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py
+++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py
@@ -201,8 +201,12 @@ def _compute_grad_params(
                 w_list[idx] = wp.main_grad
             accumulate_into_main_grad = not getattr(fc_op.weight0, "overwrite_main_grad", False)
         else:
-            for idx in range(num_groups):
-                w_list[idx] = torch.empty(weight_shape, dtype=dtype, device=device)
+            w_list = tex.bulk_allocate(
+                [list(weight_shape)] * num_groups,
+                [dtype] * num_groups,
+                device,
+                [256] * num_groups,  # alignment
+            )
         wgrad_output = w_list
 
     if ctx.weight_requires_grad: