10 changes: 10 additions & 0 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -309,6 +309,16 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
                                    py::object ln_out, py::handle quantizer, DType otype,
                                    const int sm_margin, const bool zero_centered_gamma);

/***************************************************************************************************
 * Memory allocation
 **************************************************************************************************/

// Allocates tensors all backed by a single contiguous buffer.
std::vector<at::Tensor> bulk_allocate(const std::vector<std::vector<size_t>> &shapes,
                                      const std::vector<at::ScalarType> &dtypes,
                                      std::optional<c10::Device> device = std::nullopt,
                                      std::optional<std::vector<size_t>> alignments = std::nullopt);

/***************************************************************************************************
 * Cast
 **************************************************************************************************/
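For context, a minimal sketch of how the new entry point might be invoked from C++ code that can see the declaration above (the shapes, dtypes, and the function name make_workspace are made up for illustration, not taken from this PR):

#include <vector>

#include <ATen/ATen.h>

// Hypothetical caller: three tensors backed by one contiguous CUDA buffer.
// Device and alignments fall back to their defaults (the current CUDA
// device and the per-dtype element sizes, respectively).
std::vector<at::Tensor> make_workspace() {
  return transformer_engine::pytorch::bulk_allocate(
      /*shapes=*/{{1024, 1024}, {1024}, {512, 64}},
      /*dtypes=*/{at::kBFloat16, at::kFloat, at::kHalf});
}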
86 changes: 86 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/allocate.cpp
@@ -0,0 +1,86 @@
/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <memory>
#include <vector>

#include "../extensions.h"

namespace transformer_engine {
namespace pytorch {

std::vector<at::Tensor> bulk_allocate(const std::vector<std::vector<size_t>> &shapes,
                                      const std::vector<at::ScalarType> &dtypes,
                                      std::optional<c10::Device> device,
                                      std::optional<std::vector<size_t>> alignments) {
  // Check shapes and dtypes
  const size_t n = shapes.size();
  NVTE_CHECK(dtypes.size() == n, "Got ", shapes.size(), " shapes and ", dtypes.size(), " dtypes.");
  NVTE_CHECK(!alignments || alignments->size() == n, "Got ", shapes.size(), " shapes and ",
             alignments->size(), " alignments.");

  // Return immediately if no tensors are needed
  if (n == 0) return {};

  // Set defaults for optional arguments
  if (!device) {
    device = at::Device(at::kCUDA);
  }
  if (!alignments) {
    alignments = std::vector<size_t>{};
    alignments->reserve(n);
    for (const auto &dtype : dtypes) {
      alignments->push_back(c10::elementSize(dtype));
    }
  }

  // Compute offsets in base buffer
  std::vector<size_t> byte_sizes(n);
  std::vector<size_t> offsets(n);
  size_t base_byte_size = 0;
  size_t base_alignment = 1;
  for (size_t i = 0; i < n; ++i) {
    byte_sizes[i] = product(shapes[i]) * c10::elementSize(dtypes[i]);
    offsets[i] = roundup(base_byte_size, (*alignments)[i]);
    base_byte_size = offsets[i] + byte_sizes[i];
    base_alignment = std::max(base_alignment, (*alignments)[i]);
  }
  if (base_alignment > 1) {
    // Pad in case data pointer is not aligned
    base_byte_size += base_alignment;
  }
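  // Worked example (illustrative values, not from the PR): shapes {{3}, {5}}
  // with dtypes {kFloat, kByte} and default alignments {4, 1} give
  // byte_sizes {12, 5}, offsets {0, 12}, and base_byte_size 17; with
  // base_alignment 4, the buffer is padded to 21 bytes so the views stay
  // aligned even if the raw allocation is not.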

  // Allocate base buffer
  auto base_buffer = std::make_shared<at::Tensor>(
      at::empty({static_cast<int64_t>(base_byte_size)}, at::device(*device).dtype(torch::kUInt8)));
  uint8_t *base_ptr = base_buffer->data_ptr<uint8_t>();
  base_ptr =
      reinterpret_cast<uint8_t *>(roundup(reinterpret_cast<uintptr_t>(base_ptr), base_alignment));

  // Create views into base buffer
  std::vector<at::Tensor> out;
  out.reserve(n);
  std::vector<int64_t> shape_int64;
  for (size_t i = 0; i < n; ++i) {
    shape_int64.assign(shapes[i].begin(), shapes[i].end());
    if (byte_sizes[i] == 0) {
      // Work around problems with from_blob when constructing an
      // empty tensor. Passing a null pointer fails because it checks
      // that the pointer is on GPU. Passing a non-null pointer can
      // cause bugs in TE kernels.
      out.emplace_back(at::empty(shape_int64, at::device(*device).dtype(dtypes[i])));
    } else {
      // Construct tensor with custom deleter to keep base buffer alive
      out.emplace_back(at::from_blob(
          base_ptr + offsets[i], shape_int64, [base_buffer](void *) {},
          at::device(*device).dtype(dtypes[i])));
    }
  }
  return out;
}

} // namespace pytorch
} // namespace transformer_engine
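As an aside on the deleter trick above, here is a standalone sketch of the same lifetime pattern (the helper name view_keeping_alive is illustrative, not part of the PR). The no-op deleter passed to at::from_blob captures the shared_ptr by value, so the base buffer is freed only after the last view is destroyed:

#include <cstdint>
#include <memory>
#include <vector>

#include <ATen/ATen.h>

// Hypothetical helper: a typed view into `base` that keeps `base` alive.
at::Tensor view_keeping_alive(std::shared_ptr<at::Tensor> base, int64_t offset,
                              std::vector<int64_t> shape, at::ScalarType dtype) {
  uint8_t *ptr = base->data_ptr<uint8_t>() + offset;
  // The lambda captures `base` by value, so the underlying storage
  // outlives the local shared_ptr and is released with the last view.
  return at::from_blob(
      ptr, shape, [base](void *) {}, base->options().dtype(dtype));
}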