From dd644d79036e9849dca0ea9c871d6ef28b64451c Mon Sep 17 00:00:00 2001 From: gongchensu Date: Wed, 6 May 2026 07:55:44 +0000 Subject: [PATCH] feat(cuda): add cat support for CUDA-like backends Add a generic CUDA cat kernel plus NVIDIA/Iluvatar/MetaX/Moore specializations, and cover empty-output regression in tests. --- src/cuda/cat/kernel.cuh | 40 +++++++++ src/cuda/cat/kernel.h | 148 +++++++++++++++++++++++++++++++++ src/cuda/iluvatar/cat/kernel.h | 21 +++++ src/cuda/metax/cat/kernel.h | 21 +++++ src/cuda/moore/cat/kernel.h | 26 ++++++ src/cuda/nvidia/cat/kernel.h | 21 +++++ tests/test_cat.py | 4 + 7 files changed, 281 insertions(+) create mode 100644 src/cuda/cat/kernel.cuh create mode 100644 src/cuda/cat/kernel.h create mode 100644 src/cuda/iluvatar/cat/kernel.h create mode 100644 src/cuda/metax/cat/kernel.h create mode 100644 src/cuda/moore/cat/kernel.h create mode 100644 src/cuda/nvidia/cat/kernel.h diff --git a/src/cuda/cat/kernel.cuh b/src/cuda/cat/kernel.cuh new file mode 100644 index 00000000..8d310b5b --- /dev/null +++ b/src/cuda/cat/kernel.cuh @@ -0,0 +1,40 @@ +#ifndef INFINI_OPS_CUDA_CAT_KERNEL_CUH_ +#define INFINI_OPS_CUDA_CAT_KERNEL_CUH_ + +#include "cuda/kernel_commons.cuh" + +namespace infini::ops { + +template +__global__ void CatKernel(T* __restrict__ out, + const T* const* __restrict__ inputs, + const size_t* __restrict__ input_dim_sizes, + const size_t* __restrict__ input_dim_offsets, + size_t input_count, size_t out_dim_size, size_t inner, + size_t output_size) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (idx < output_size) { + size_t outer_idx = idx / (out_dim_size * inner); + size_t rem = idx % (out_dim_size * inner); + size_t dim_idx = rem / inner; + size_t inner_idx = rem % inner; + + size_t input_idx = 0; + while (input_idx + 1 < input_count && + dim_idx >= input_dim_offsets[input_idx + 1]) { + ++input_idx; + } + + size_t local_dim_idx = dim_idx - input_dim_offsets[input_idx]; + size_t src_idx = + (outer_idx * input_dim_sizes[input_idx] + local_dim_idx) * inner + + inner_idx; + + out[idx] = inputs[input_idx][src_idx]; + } +} + +} // namespace infini::ops + +#endif diff --git a/src/cuda/cat/kernel.h b/src/cuda/cat/kernel.h new file mode 100644 index 00000000..8fd19852 --- /dev/null +++ b/src/cuda/cat/kernel.h @@ -0,0 +1,148 @@ +#ifndef INFINI_OPS_CUDA_CAT_KERNEL_H_ +#define INFINI_OPS_CUDA_CAT_KERNEL_H_ + +#include +#include +#include +#include +#include + +#include "base/cat.h" +#include "common/generic_utils.h" +#include "cuda/cat/kernel.cuh" +#include "cuda/kernel_commons.cuh" +#include "cuda/runtime_utils.h" + +namespace infini::ops { + +template +class CudaCat : public Cat { + public: + CudaCat(const Tensor first_input, std::vector rest_inputs, + int64_t dim, Tensor out) + : Cat{first_input, rest_inputs, dim, out}, + out_shape_{out.shape()}, + out_dtype_{out.dtype()}, + output_size_{out.numel()} { + assert(out.IsContiguous() && + "CudaCat currently requires contiguous output"); + assert(first_input.IsContiguous() && + "CudaCat currently requires contiguous inputs"); + assert(first_input.dtype() == out_dtype_); + + input_dim_sizes_.reserve(input_count_); + input_dim_offsets_.reserve(input_count_); + + input_dim_offsets_.push_back(0); + input_dim_sizes_.push_back(first_input.shape()[dim_]); + for (const auto& t : rest_inputs) { + assert(t.IsContiguous() && + "CudaCat currently requires contiguous inputs"); + assert(t.dtype() == out_dtype_); + input_dim_offsets_.push_back(input_dim_offsets_.back() + + input_dim_sizes_.back()); + input_dim_sizes_.push_back(t.shape()[dim_]); + } + + inner_ = 1; + for (size_t i = static_cast(dim_) + 1; i < out_shape_.size(); ++i) { + inner_ *= out_shape_[i]; + } + out_dim_size_ = out_shape_[dim_]; + + size_t count_bytes = input_count_ * sizeof(*d_input_ptrs_); + size_t dim_size_bytes = input_count_ * sizeof(*d_input_dim_sizes_); + size_t dim_offset_bytes = input_count_ * sizeof(*d_input_dim_offsets_); + const size_t metadata_size = + count_bytes + dim_size_bytes + dim_offset_bytes; + std::vector metadata(metadata_size); + + Backend::Malloc((void**)&d_metadata_, metadata_size); + + size_t offset = 0; + d_input_ptrs_ = reinterpret_cast(d_metadata_ + offset); + offset += count_bytes; + + d_input_dim_sizes_ = reinterpret_cast(d_metadata_ + offset); + std::memcpy(metadata.data() + offset, input_dim_sizes_.data(), + dim_size_bytes); + offset += dim_size_bytes; + + d_input_dim_offsets_ = + reinterpret_cast(d_metadata_ + offset); + std::memcpy(metadata.data() + offset, input_dim_offsets_.data(), + dim_offset_bytes); + + Backend::Memcpy(d_metadata_ + count_bytes, metadata.data() + count_bytes, + dim_size_bytes + dim_offset_bytes, + Backend::MemcpyHostToDevice); + } + + ~CudaCat() { Backend::Free(d_metadata_); } + + void operator()(const Tensor first_input, std::vector rest_inputs, + int64_t /*dim*/, Tensor out) const override { + if (output_size_ == 0) { + return; + } + + std::vector input_ptrs; + input_ptrs.reserve(input_count_); + input_ptrs.push_back(first_input.data()); + for (const auto& t : rest_inputs) { + input_ptrs.push_back(t.data()); + } + + Backend::Memcpy(d_input_ptrs_, input_ptrs.data(), + input_count_ * sizeof(*d_input_ptrs_), + Backend::MemcpyHostToDevice); + + int block_size = RuntimeUtils::GetOptimalBlockSize(); + DispatchFunc( + {static_cast(out_dtype_), block_size}, + [&](auto list_tag) { + using T = TypeMapType(list_tag)>; + constexpr int kBlockSize = ListGet<1>(list_tag); + + auto cuda_stream = + static_cast(stream_ ? stream_ : 0); + dim3 blockDims( + std::min(static_cast(block_size), output_size_)); + dim3 gridDims(utils::CeilDiv(output_size_, blockDims.x)); + + CatKernel<<>>( + reinterpret_cast(out.data()), + reinterpret_cast(d_input_ptrs_), + d_input_dim_sizes_, d_input_dim_offsets_, input_count_, + out_dim_size_, inner_, output_size_); + }, + "CudaCat::operator()"); + } + + private: + std::byte* d_metadata_{nullptr}; + + const void** d_input_ptrs_{nullptr}; + + Tensor::Size* d_input_dim_sizes_{nullptr}; + + Tensor::Size* d_input_dim_offsets_{nullptr}; + + Tensor::Shape out_shape_; + + DataType out_dtype_; + + Tensor::Size output_size_{0}; + + Tensor::Size inner_{0}; + + Tensor::Size out_dim_size_{0}; + + std::vector input_dim_sizes_; + + std::vector input_dim_offsets_; +}; + +} // namespace infini::ops + +#endif diff --git a/src/cuda/iluvatar/cat/kernel.h b/src/cuda/iluvatar/cat/kernel.h new file mode 100644 index 00000000..345aa3d1 --- /dev/null +++ b/src/cuda/iluvatar/cat/kernel.h @@ -0,0 +1,21 @@ +#ifndef INFINI_OPS_ILUVATAR_CAT_KERNEL_H_ +#define INFINI_OPS_ILUVATAR_CAT_KERNEL_H_ + +#include + +#include "cuda/cat/kernel.h" +#include "cuda/iluvatar/caster.cuh" +#include "cuda/iluvatar/runtime_.h" + +namespace infini::ops { + +template <> +class Operator + : public CudaCat> { + public: + using CudaCat>::CudaCat; +}; + +} // namespace infini::ops + +#endif diff --git a/src/cuda/metax/cat/kernel.h b/src/cuda/metax/cat/kernel.h new file mode 100644 index 00000000..0f0d29fc --- /dev/null +++ b/src/cuda/metax/cat/kernel.h @@ -0,0 +1,21 @@ +#ifndef INFINI_OPS_METAX_CAT_KERNEL_H_ +#define INFINI_OPS_METAX_CAT_KERNEL_H_ + +#include + +#include "cuda/cat/kernel.h" +#include "cuda/metax/caster.cuh" +#include "cuda/metax/runtime_.h" + +namespace infini::ops { + +template <> +class Operator + : public CudaCat> { + public: + using CudaCat>::CudaCat; +}; + +} // namespace infini::ops + +#endif diff --git a/src/cuda/moore/cat/kernel.h b/src/cuda/moore/cat/kernel.h new file mode 100644 index 00000000..b6850d7c --- /dev/null +++ b/src/cuda/moore/cat/kernel.h @@ -0,0 +1,26 @@ +#ifndef INFINI_OPS_MOORE_CAT_KERNEL_H_ +#define INFINI_OPS_MOORE_CAT_KERNEL_H_ + +#include + +// clang-format off +#include "cuda/moore/polyfills.cuh" +// clang-format on + +#include "cuda/cat/kernel.h" +#include "cuda/moore/caster.cuh" +#include "cuda/moore/polyfills.cuh" +#include "cuda/moore/runtime_.h" + +namespace infini::ops { + +template <> +class Operator + : public CudaCat> { + public: + using CudaCat>::CudaCat; +}; + +} // namespace infini::ops + +#endif diff --git a/src/cuda/nvidia/cat/kernel.h b/src/cuda/nvidia/cat/kernel.h new file mode 100644 index 00000000..90eb557e --- /dev/null +++ b/src/cuda/nvidia/cat/kernel.h @@ -0,0 +1,21 @@ +#ifndef INFINI_OPS_NVIDIA_CAT_KERNEL_H_ +#define INFINI_OPS_NVIDIA_CAT_KERNEL_H_ + +#include + +#include "cuda/cat/kernel.h" +#include "cuda/nvidia/caster.cuh" +#include "cuda/nvidia/runtime_.h" + +namespace infini::ops { + +template <> +class Operator + : public CudaCat> { + public: + using CudaCat>::CudaCat; +}; + +} // namespace infini::ops + +#endif diff --git a/tests/test_cat.py b/tests/test_cat.py index 85428b53..773cf1b4 100644 --- a/tests/test_cat.py +++ b/tests/test_cat.py @@ -23,6 +23,10 @@ (((2, 4, 32), (2, 4, 64)), 2, (2, 4, 96)), # 4 inputs, dim=1 (((1, 1024), (1, 1024), (1, 1024), (1, 1024)), 1, (1, 4096)), + # 2 inputs, empty cat dim + (((2, 0, 32), (2, 0, 32)), 1, (2, 0, 32)), + # 2 inputs, empty non-cat dim + (((0, 32), (0, 64)), 1, (0, 96)), ), ) @pytest.mark.parametrize(