From e308e90ddc418d91d1449869182b184d3dfd3b28 Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Fri, 3 Apr 2026 12:17:30 +0000 Subject: [PATCH 01/23] feat: add data type related interface and implementation as well as OpenMPI's type map. - Add `constexpr_map.h` which provides a compile-time map structure. - Add `include/data_type.h` which includes `infiniDataType_t` that is exposed to the public. - Add `src/data_type_impl.h` which includes `DataType` and related constructs that are used internally. - Add `src/ompi/type_map.h` which contains mappings that OpenMPI needs, specifically data type mapping at this moment. --- include/data_type.h | 33 +++++++++++++++++++ src/constexpr_map.h | 32 +++++++++++++++++++ src/data_type_impl.h | 76 ++++++++++++++++++++++++++++++++++++++++++++ src/ompi/type_map.h | 31 ++++++++++++++++++ 4 files changed, 172 insertions(+) create mode 100644 include/data_type.h create mode 100644 src/constexpr_map.h create mode 100644 src/data_type_impl.h create mode 100644 src/ompi/type_map.h diff --git a/include/data_type.h b/include/data_type.h new file mode 100644 index 0000000..f39c409 --- /dev/null +++ b/include/data_type.h @@ -0,0 +1,33 @@ +#ifndef INFINI_CCL_DATA_TYPE_H_ +#define INFINI_CCL_DATA_TYPE_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +typedef enum { + infiniChar = 0, + infiniInt8 = 0, + infiniInt16 = 1, + infiniInt = 2, + infiniInt32 = 2, + infiniInt64 = 3, + infiniUInt8 = 4, + infiniUInt16 = 5, + infiniUInt32 = 6, + infiniUInt64 = 7, + infiniHalf = 8, + infiniFloat16 = 8, + infiniBFloat16 = 9, + infiniFloat = 10, + infiniFloat32 = 10, + infiniDouble = 11, + infiniFloat64 = 11, + infiniNumTypes = 12, +} infiniDataType_t; + +#ifdef __cplusplus +} +#endif + +#endif // INFINI_CCL_DATA_TYPE_H_ diff --git a/src/constexpr_map.h b/src/constexpr_map.h new file mode 100644 index 0000000..251574c --- /dev/null +++ b/src/constexpr_map.h @@ -0,0 +1,32 @@ +#ifndef INFINI_CCL_CONSTEXPR_MAP_H_ +#define INFINI_CCL_CONSTEXPR_MAP_H_ + +#include 
+#include +#include +#include + +namespace infini::ccl { + +template struct ConstexprMap { + constexpr ConstexprMap(std::array, size> data) + : data_(data) {} + + constexpr Value at(Key key) const { + for (const auto &pr : data_) { + if (pr.first == key) + return pr.second; + } + // TODO(lzm): change to logging. + assert(false && "the key is not found in the `ConstexprMap`"); + // Reached only when `assert` is compiled out (`NDEBUG`); keeps the failure fatal. + std::abort(); + } + +private: + std::array, size> data_; +}; + +} // namespace infini::ccl + +#endif diff --git a/src/data_type_impl.h b/src/data_type_impl.h new file mode 100644 index 0000000..ed119bd --- /dev/null +++ b/src/data_type_impl.h @@ -0,0 +1,76 @@ +#ifndef INFINI_CCL_DATA_TYPE_IMPL_H_ +#define INFINI_CCL_DATA_TYPE_IMPL_H_ + +#include +#include + +#include "constexpr_map.h" +#include "data_type.h" + +namespace infini::ccl { + +using DataType = ::infiniDataType_t; + +constexpr DataType kChar = infiniChar; +constexpr DataType kInt8 = infiniInt8; +constexpr DataType kInt16 = infiniInt16; +constexpr DataType kInt32 = infiniInt32; +constexpr DataType kInt64 = infiniInt64; +constexpr DataType kUInt8 = infiniUInt8; +constexpr DataType kUInt16 = infiniUInt16; +constexpr DataType kUInt32 = infiniUInt32; +constexpr DataType kUInt64 = infiniUInt64; +constexpr DataType kFloat16 = infiniFloat16; +constexpr DataType kBFloat16 = infiniBFloat16; +constexpr DataType kFloat32 = infiniFloat32; +constexpr DataType kFloat64 = infiniFloat64; +constexpr DataType kNumTypes = infiniNumTypes; + +constexpr ConstexprMap kDataTypeToSize{{{ + {kInt8, 1}, + {kInt16, 2}, + {kInt32, 4}, + {kInt64, 8}, + {kUInt8, 1}, + {kUInt16, 2}, + {kUInt32, 4}, + {kUInt64, 8}, + {kFloat16, 2}, + {kBFloat16, 2}, + {kFloat32, 4}, + {kFloat64, 8}, +}}}; + +constexpr ConstexprMap kDataTypeToDesc{{{ + {kInt8, "int8"}, + {kInt16, "int16"}, + {kInt32, "int32"}, + {kInt64, "int64"}, + {kUInt8, "uint8"}, + {kUInt16, "uint16"}, + {kUInt32, "uint32"}, + {kUInt64, "uint64"}, + 
{kFloat16, "float16"}, + {kBFloat16, "bfloat16"}, + {kFloat32, "float32"}, + {kFloat64, "float64"}, +}}}; + +constexpr ConstexprMap kStringToDataType{{{ + {"int8", kInt8}, + {"int16", kInt16}, + {"int32", kInt32}, + {"int64", kInt64}, + {"uint8", kUInt8}, + {"uint16", kUInt16}, + {"uint32", kUInt32}, + {"uint64", kUInt64}, + {"float16", kFloat16}, + {"bfloat16", kBFloat16}, + {"float32", kFloat32}, + {"float64", kFloat64}, +}}}; + +} // namespace infini::ccl + +#endif // INFINI_CCL_DATA_TYPE_IMPL_H_ diff --git a/src/ompi/type_map.h b/src/ompi/type_map.h new file mode 100644 index 0000000..c8170c2 --- /dev/null +++ b/src/ompi/type_map.h @@ -0,0 +1,31 @@ +#ifndef INFINI_CCL_OMPI_TYPE_MAPPING_H_ +#define INFINI_CCL_OMPI_TYPE_MAPPING_H_ + +#include + +#include "data_type_impl.h" + +namespace infini::ccl { + +static const ConstexprMap kOmpiTypeMap{{{ + {kInt8, MPI_INT8_T}, + {kInt16, MPI_INT16_T}, + {kInt32, MPI_INT32_T}, + {kInt64, MPI_INT64_T}, + {kUInt8, MPI_UINT8_T}, + {kUInt16, MPI_UINT16_T}, + {kUInt32, MPI_UINT32_T}, + {kUInt64, MPI_UINT64_T}, + {kFloat32, MPI_FLOAT}, + {kFloat64, MPI_DOUBLE}, + // NOTE(review): `MPI_BYTE` is a 1-byte type, while `kDataTypeToSize` lists + // `kFloat16`/`kBFloat16` as 2 bytes - presumably callers scale the element + // count for these two; confirm. + {kFloat16, MPI_BYTE}, + {kBFloat16, MPI_BYTE}, +}}}; + +// Returns the MPI datatype that `kOmpiTypeMap` registers for `dtype`. +inline MPI_Datatype DataTypeToOmpiType(DataType dtype) { + return kOmpiTypeMap.at(dtype); +} + +} // namespace infini::ccl + +#endif From 83c03340f5632a13aaa9a34115c23304fffbcf3c Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Fri, 3 Apr 2026 12:22:41 +0000 Subject: [PATCH 02/23] feat: add the external and internal interfaces for return code/status. - Add `include/return_status.h` which contains the public interface for return codes/status codes. - Add `src/return_status_impl.h` which contains the private/internal interface for return codes/status codes. 
--- include/return_status.h | 24 ++++++++++++++++++++++++ src/return_status_impl.h | 22 ++++++++++++++++++++++ 2 files changed, 46 insertions(+) create mode 100644 include/return_status.h create mode 100644 src/return_status_impl.h diff --git a/include/return_status.h b/include/return_status.h new file mode 100644 index 0000000..4e6bf6b --- /dev/null +++ b/include/return_status.h @@ -0,0 +1,24 @@ +#ifndef INFINI_CCL_RETURN_STATUS_H_ +#define INFINI_CCL_RETURN_STATUS_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +typedef enum { + infiniSuccess = 0, + infiniUnhandledError = 1, + infiniSystemError = 2, + infiniInternalError = 3, + infiniInvalidArgument = 4, + infiniInvalidUsage = 5, + infiniRemoteError = 6, + infiniInProgress = 7, + infiniNumResults = 8 +} infiniResult_t; + +#ifdef __cplusplus +} +#endif + +#endif // INFINI_CCL_RETURN_STATUS_H_ diff --git a/src/return_status_impl.h b/src/return_status_impl.h new file mode 100644 index 0000000..db7d051 --- /dev/null +++ b/src/return_status_impl.h @@ -0,0 +1,22 @@ +#ifndef INFINI_CCL_RETURN_STATUS_IMPL_H_ +#define INFINI_CCL_RETURN_STATUS_IMPL_H_ + +#include "return_status.h" + +namespace infini::ccl { + +using ReturnStatus = ::infiniResult_t; + +constexpr ReturnStatus kSuccess = infiniSuccess; +constexpr ReturnStatus kUnhandledError = infiniUnhandledError; +constexpr ReturnStatus kSystemError = infiniSystemError; +constexpr ReturnStatus kInternalError = infiniInternalError; +constexpr ReturnStatus kInvalidArgument = infiniInvalidArgument; +constexpr ReturnStatus kInvalidUsage = infiniInvalidUsage; +constexpr ReturnStatus kRemoteError = infiniRemoteError; +constexpr ReturnStatus kInProgress = infiniInProgress; +constexpr ReturnStatus kNumResults = infiniNumResults; + +} // namespace infini::ccl + +#endif // INFINI_CCL_RETURN_STATUS_IMPL_H_ From 2f539ad740a91c4d9b66de5753f3a29f1694a33c Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Fri, 3 Apr 2026 12:26:17 +0000 Subject: [PATCH 03/23] feat: add the definitions of 
device, dispatcher, and compile-time traits. - Add `src/device.h` which contains the definitions and utils about devices. - Add `src/traits.h` which contains the compile-time traits. - Add `src/dispatcher.h` which contains the implementation of C++17-compatible dispatcher. - These are mainly migrated from `InfiniTensor/InfiniOps`. --- src/device.h | 119 ++++++++++++++++++ src/dispatcher.h | 315 +++++++++++++++++++++++++++++++++++++++++++++++ src/traits.h | 160 ++++++++++++++++++++++++ 3 files changed, 594 insertions(+) create mode 100644 src/device.h create mode 100644 src/dispatcher.h create mode 100644 src/traits.h diff --git a/src/device.h b/src/device.h new file mode 100644 index 0000000..c726a6f --- /dev/null +++ b/src/device.h @@ -0,0 +1,119 @@ +#ifndef INFINI_CCL_DEVICE_H_ +#define INFINI_CCL_DEVICE_H_ + +#include + +#include "constexpr_map.h" +#include "traits.h" + +namespace infini::ccl { + +class Device { +public: + enum class Type { + kCpu = 0, + kNvidia = 1, + kCambricon = 2, + kAscend = 3, + kMetax = 4, + kMoore = 5, + kIluvatar = 6, + kKunlun = 7, + kHygon = 8, + kQy = 9, + kCount + }; + + Device() = default; + + Device(const Type &type, const int &index = 0) : type_{type}, index_{index} {} + + static const Type TypeFromString(const std::string &name) { + return kDescToDevice.at(name); + } + + static const std::string_view StringFromType(const Type &type) { + return kDeviceToDesc.at(type); + } + + const Type &type() const { return type_; } + + const int &index() const { return index_; } + + std::string ToString() const { + return std::string{StringFromType(type_)} + ":" + std::to_string(index_); + } + + bool operator==(const Device &other) const { + return type_ == other.type_ && index_ == other.index_; + } + + bool operator!=(const Device &other) const { return !(*this == other); } + +private: + Type type_{Type::kCpu}; + + static constexpr ConstexprMap(Device::Type::kCount)> + kDeviceToDesc{{{ + {Type::kCpu, "cpu"}, + {Type::kNvidia, "nvidia"}, + 
{Type::kCambricon, "cambricon"}, + {Type::kAscend, "ascend"}, + {Type::kMetax, "metax"}, + {Type::kMoore, "moore"}, + {Type::kIluvatar, "iluvatar"}, + {Type::kKunlun, "kunlun"}, + {Type::kHygon, "hygon"}, + {Type::kQy, "qy"}, + }}}; + + static constexpr ConstexprMap(Device::Type::kCount)> + kDescToDevice{{{ + {"cpu", Type::kCpu}, + {"nvidia", Type::kNvidia}, + {"cambricon", Type::kCambricon}, + {"ascend", Type::kAscend}, + {"metax", Type::kMetax}, + {"moore", Type::kMoore}, + {"iluvatar", Type::kIluvatar}, + {"kunlun", Type::kKunlun}, + {"hygon", Type::kHygon}, + {"qy", Type::kQy}, + }}}; + + int index_{0}; +}; + +// Primary template: Devices are disabled by default. Platform-specific +// headers (e.g. `cpu/device_.h`) specialize this to `std::true_type`. +template struct DeviceEnabled : std::false_type {}; + +// Defines the common categories of devices using List. +using AllDeviceTypes = + List; + +// Deferred computation of active devices. The `Filter` and `FilterList` +// evaluation are nested inside a class template so that `DeviceEnabled` +// specializations from platform `device_.h` headers are visible at +// instantiation time. Use with a dependent type parameter +// (e.g. `ActiveDevices`) to ensure deferred instantiation. 
+template struct ActiveDevicesImpl { + struct Filter { + template + std::enable_if_t::value> + operator()(ValueTag) const {} + }; + + using type = typename FilterList, AllDeviceTypes>::type; +}; + +template using ActiveDevices = typename ActiveDevicesImpl::type; + +} // namespace infini::ccl + +#endif diff --git a/src/dispatcher.h b/src/dispatcher.h new file mode 100644 index 0000000..f90ca0e --- /dev/null +++ b/src/dispatcher.h @@ -0,0 +1,315 @@ +#ifndef INFINI_CCL_DISPATCHER_H_ +#define INFINI_CCL_DISPATCHER_H_ + +#include +#include +#include +#include + +#include "data_type_impl.h" +#include "device.h" +#include "traits.h" + +namespace infini::ccl { + +// ----------------------------------------------------------------------------- +// Core Generic Runtime Dispatchers +// ----------------------------------------------------------------------------- + +namespace detail { + +// Implements the dispatch body over a resolved `List`. +template +auto DispatchFuncImpl(ValueType value, Functor &&func, + std::string_view context_str, List, + Args &&...args) { + using ReturnType = decltype(std::forward(func)( + ValueTag(head)>{}, std::forward(args)...)); + + // Path for void functions. + if constexpr (std::is_void_v) { + bool handled = ((value == static_cast(tail) + ? (std::forward(func)( + ValueTag{}, std::forward(args)...), + true) + : false) || + ... || + (value == static_cast(head) + ? (std::forward(func)( + ValueTag{}, std::forward(args)...), + true) + : false)); + + if (!handled) { + // TODO(lzm): change to logging. + std::cerr << "dispatch error (void): value " << static_cast(value) + << " not supported in the context: " << context_str << "\n"; + std::abort(); + } + } + // Path for non-void functions. + else { + std::optional result; + bool handled = ((value == static_cast(tail) + ? (result.emplace(std::forward(func)( + ValueTag{}, std::forward(args)...)), + true) + : false) || + ... || + (value == static_cast(head) + ? 
(result.emplace(std::forward(func)( + ValueTag{}, std::forward(args)...)), + true) + : false)); + + if (handled) { + return *result; + } + // TODO(lzm): change to logging. + std::cerr << "dispatch error (non-void): value " << static_cast(value) + << " not supported in the context: " << context_str << "\n"; + std::abort(); + return ReturnType{}; + } +} + +// Deduces `head`/`tail` from a `List` type via partial specialization, +// then forwards to `DispatchFuncImpl`. +template +struct DispatchFuncUnwrap; + +template +struct DispatchFuncUnwrap, + std::tuple> { + static auto call(ValueType value, Functor &&func, + std::string_view context_str, Args &&...args) { + return DispatchFuncImpl(value, std::forward(func), context_str, + List{}, std::forward(args)...); + } +}; + +// Empty-list specialization +template +struct DispatchFuncUnwrap, std::tuple> { + static auto call(ValueType value, Functor &&, std::string_view context_str, + Args &&...) { + // TODO(lzm): change to logging. + std::cerr << "dispatch error: no allowed values registered for value " + << static_cast(value) + << " in the context: " << context_str << "\n"; + std::abort(); + } +}; + +} // namespace detail + +// (Single Dispatch) Dispatches a runtime value to a compile-time functor. +template +auto DispatchFunc(ValueType value, Functor &&func, + std::string_view context_str = "", Args &&...args) { + using FilteredPack = typename Filter, List<>, + all_values...>::type; + + return detail::DispatchFuncUnwrap< + ValueType, Functor, FilteredPack, + std::tuple>::call(value, std::forward(func), + context_str, std::forward(args)...); +} + +// (Multi-Dispatch) Dispatches a vector of runtime values to a compile-time +// functor. 
+// Base Case: All Dimensions Resolved +template +auto DispatchFunc(const std::vector &values, size_t /*index*/, + Functor &&func, std::string_view /*context_str*/, + List, Args &&...args) { + return std::forward(func)(List{}, + std::forward(args)...); +} + +// Forward declaration of the recursive multi-dispatch overload. +template +auto DispatchFunc(const std::vector &values, size_t index, + Functor &&func, std::string_view context_str, List, + Args &&...args); + +// Adapter used in the recursive multi-dispatch case: given a resolved value +// `val` recurse into the next dimension. +template +struct MultiDispatchRecurseAdapter; + +template +struct MultiDispatchRecurseAdapter, Functor, items...> { + const std::vector &values; + size_t next_index; + Functor &func; + std::string_view context_str; + + template + auto operator()(ValueTag, Args &&...args) const { + return DispatchFunc(values, next_index, func, context_str, + List{}, + std::forward(args)...); + } +}; + +template +auto MultiDispatchFirstDim(const std::vector &values, size_t index, + Functor &func, std::string_view context_str, + List, List, Args &&...args) { + static_assert(sizeof...(allowed) > 0, + "`DispatchFunc` dimension list is empty"); + using EnumType = std::common_type_t; + + MultiDispatchRecurseAdapter adapter{ + values, index + 1, func, context_str}; + + return DispatchFunc( + static_cast(values.at(index)), adapter, context_str, + std::forward(args)...); +} + +// (Multi-Dispatch) Recursive Case +template +auto DispatchFunc(const std::vector &values, size_t index, + Functor &&func, std::string_view context_str, List, + Args &&...args) { + return MultiDispatchFirstDim>( + values, index, func, context_str, List{}, FirstList{}, + std::forward(args)...); +} + +// ----------------------------------------------------------------------------- +// High-Level Specialized Dispatchers +// ----------------------------------------------------------------------------- +// These provide cleaner and more 
convenient APIs for common InfiniOps types. + +namespace detail { + +// Bridges the generic value dispatch layer to the `DataType`-specific type +// dispatch layer. +template struct DataTypeAdapter { + Functor &func; + + template + auto operator()(ValueTag, Args &&...args) const { + using T = TypeMapType(dtype)>; + return func(TypeTag{}, std::forward(args)...); + } +}; + +template struct DataTypeMultiAdapter { + Functor &func; + + template + auto operator()(List, Args &&...args) const { + return func(TypeTag(dtypes)>>{}..., + std::forward(args)...); + } +}; + +template struct DeviceAdapter { + Functor &func; + + template + auto operator()(ValueTag, Args &&...args) const { + return func(ValueTag{}, std::forward(args)...); + } +}; + +template struct DeviceMultiAdapter { + Functor &func; + + template + auto operator()(List, Args &&...args) const { + return func(ValueTag{}..., std::forward(args)...); + } +}; + +} // namespace detail + +// `DataType` Dispatch +template +auto DispatchFunc(DataType dtype, Functor &&func, + std::string_view context_str = "", Args &&...args) { + detail::DataTypeAdapter> adapter{func}; + return DispatchFunc(dtype, adapter, context_str, + std::forward(args)...); +} + +// `DataType` Multi-Dispatch +template +auto DispatchFunc(std::initializer_list dtypes, Functor &&func, + std::string_view context_str = "", Args &&...args) { + std::vector v; + for (auto d : dtypes) + v.push_back(static_cast(d)); + + detail::DataTypeMultiAdapter> adapter{func}; + return DispatchFunc(v, 0, adapter, context_str, List<>{}, + std::forward(args)...); +} + +// `Device` Dispatch +template +auto DispatchFunc(Device::Type device, Functor &&func, + std::string_view context_str = "", Args &&...args) { + detail::DeviceAdapter> adapter{func}; + return DispatchFunc(allowed_devices)...>( + device, adapter, context_str, std::forward(args)...); +} + +// `Device` Multi-Dispatch +template +auto DispatchFunc(std::initializer_list devices, Functor &&func, + std::string_view 
context_str = "", Args &&...args) { + std::vector v; + for (auto d : devices) + v.push_back(static_cast(d)); + + detail::DeviceMultiAdapter> adapter{func}; + return DispatchFunc(v, 0, adapter, context_str, List<>{}, + std::forward(args)...); +} + +template +auto DispatchFuncListAliasImpl(ValueType value, Functor &&func, + std::string_view context_str, List, + Args &&...args) { + return DispatchFunc>(items)...>( + value, std::forward(func), context_str, + std::forward(args)...); +} + +// Interface for Generic `List` Aliases +template ::value>> +auto DispatchFunc(ValueType value, Functor &&func, + std::string_view context_str = "", Args &&...args) { + return DispatchFuncListAliasImpl(value, std::forward(func), + context_str, ListType{}, + std::forward(args)...); +} + +// Interface for Any `int64_t`-Convertible Types +template +auto DispatchFunc(std::initializer_list keys, Functor &&func, + std::string_view context_str = "", Args &&...args) { + std::vector v_keys(keys); + return DispatchFunc(v_keys, 0, std::forward(func), + context_str, List<>{}, + std::forward(args)...); +} + +} // namespace infini::ccl + +#endif // INFINI_CCL_DISPATCHER_H_ diff --git a/src/traits.h b/src/traits.h new file mode 100644 index 0000000..68d6bbd --- /dev/null +++ b/src/traits.h @@ -0,0 +1,160 @@ +#ifndef INFINI_CCL_TRAITS_H_ +#define INFINI_CCL_TRAITS_H_ + +#include +#include + +namespace infini::ccl { + +// --------------------- List and TypePack --------------------- +// A generic container for a sequence of compile-time values. +template struct List {}; + +// `ListGet(List{})` extracts the `i`th value from a `List` +// tag. 
+template +constexpr auto ListGetImpl(List) { + if constexpr (index == 0) + return head; + else + return ListGetImpl(List{}); +} + +template +constexpr auto ListGet(List list) { + return ListGetImpl(list); +} + +template struct TypePack {}; + +// ----------------------------------------------------------------------------- +// Tags +// ----------------------------------------------------------------------------- +// Tags are passed as regular function arguments to user functors instead of +// template parameters. This lets users write plain C++17 `[](auto tag)` lambdas +// rather than C++20 template lambdas (`[]()`). + +// `TypeTag`: carries a C++ type. Recover with `typename +// decltype(tag)::type`. +template struct TypeTag { + using type = T; +}; + +// `ValueTag`: carries a compile-time value. Recover with +// `decltype(tag)::value`. +template struct ValueTag { + using value_type = decltype(v); + static constexpr auto value = v; +}; + +// ----------------------------------------------------------------------------- +// List Queries +// ----------------------------------------------------------------------------- + +// Check at compile-time if a value exists within a construct (e.g., `List<>`). +// Example: `static_assert(ContainsValue)`; +template struct Contains; + +template +struct Contains, value> + : std::disjunction...> {}; + +template +inline constexpr bool ContainsValue = Contains::value; + +// Check at compile-time if a type `T` is present in a variadic list of types +// `Ts`. +// Example: `static_assert(IsTypeInList)`; +template +inline constexpr bool IsTypeInList = (std::is_same_v || ...); + +// Trait to detect whether `T` is a `List<...>` specialization. 
+template struct IsListType : std::false_type {}; + +template struct IsListType> : std::true_type {}; + +// ----------------------------------------------------------------------------- +// List Operations +// ----------------------------------------------------------------------------- + +// Concatenates two `List` types into a single `List`. +// Example: `ConcatType<List<1, 2>, List<3, 4>>` is `List<1, 2, 3, 4>`. +template struct Concat; + +template +struct Concat, List> { + using type = List; +}; + +template +using ConcatType = typename Concat::type; + +template struct Flatten; + +template struct Flatten> { + using type = List; +}; + +template +struct Flatten { + using type = typename Flatten, Rest...>::type; +}; + +// ----------------------------------------------------------------------------- +// Invocability Detection (SFINAE) +// ----------------------------------------------------------------------------- + +// Checks if a `Functor` can be called with a `ValueTag` and `Args...`. +template +struct IsInvocable : std::false_type {}; + +template +struct IsInvocable()( + ValueTag{}, std::declval()...))>, + Args...> : std::true_type {}; + +template +inline constexpr bool IsInvocableValue = + IsInvocable::value; + +// ----------------------------------------------------------------------------- +// Filtering Logic +// ----------------------------------------------------------------------------- + +// Recursive template to filter values based on `Functor` support at +// compile-time. +template +struct Filter; + +// Base case: All values processed. +template +struct Filter, List> { + using type = List; +}; + +// Recursive step: Test the `head` value and accumulate if supported. +template +struct Filter, List, head, tail...> { + using type = typename std::conditional_t< + IsInvocableValue && + !ContainsValue, head>, + Filter, List, tail...>, + Filter, List, tail...>>::type; +}; + +// Interface to filter a `List` type directly. 
+template +struct FilterList; + +template +struct FilterList, List> { + using type = + typename Filter, List<>, items...>::type; +}; + +} // namespace infini::ccl + +#endif // INFINI_CCL_TRAITS_H_ From 175837af568a25435d6763a8ceec4e1ae2bcd0b2 Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Thu, 9 Apr 2026 13:05:51 +0000 Subject: [PATCH 04/23] feat: add `src/backend.h` which contains the definition of `BackendType` --- src/backend.h | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 src/backend.h diff --git a/src/backend.h b/src/backend.h new file mode 100644 index 0000000..bdd7413 --- /dev/null +++ b/src/backend.h @@ -0,0 +1,21 @@ +#ifndef INFINI_CCL_BACKEND_H_ +#define INFINI_CCL_BACKEND_H_ + +#include + +namespace infini::ccl { + +enum class BackendType : int8_t { + kOmpi = 0, + kGloo = 1, + kNccl = 2, + kMccl = 3, + kRccl = 4, + kCncl = 5, + kHccl = 6, + kCount +}; + +} // namespace infini::ccl + +#endif // INFINI_CCL_BACKEND_H_ From faa40a351c8c4a21957f310107fcb2de818f78fa Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Thu, 9 Apr 2026 13:07:15 +0000 Subject: [PATCH 05/23] feat: add runtime and platform specialization of `DeviceEnabled` for NVIDIA and MetaX - add runtime in `src/runtime.h` and its specializations under `cuda/`, `nvidia/` and `metax/` - add `device_.h` under `nvidia/` and `metax/` which contain their platform specializations for `DeviceEnabled` --- src/cuda/runtime_.h | 24 ++++++++++++++++++++ src/metax/device_.h | 12 ++++++++++ src/metax/runtime_.h | 37 ++++++++++++++++++++++++++++++ src/nvidia/device_.h | 12 ++++++++++ src/nvidia/runtime_.h | 41 ++++++++++++++++++++++++++++++++++ src/runtime.h | 52 +++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 178 insertions(+) create mode 100644 src/cuda/runtime_.h create mode 100644 src/metax/device_.h create mode 100644 src/metax/runtime_.h create mode 100644 src/nvidia/device_.h create mode 100644 src/nvidia/runtime_.h create mode 100644 src/runtime.h diff --git 
a/src/cuda/runtime_.h b/src/cuda/runtime_.h new file mode 100644 index 0000000..551f807 --- /dev/null +++ b/src/cuda/runtime_.h @@ -0,0 +1,24 @@ +#ifndef INFINI_CCL_CUDA_RUNTIME_H_ +#define INFINI_CCL_CUDA_RUNTIME_H_ + +#include + +#include "runtime.h" + +namespace infini::ccl { + +template struct CudaRuntime : DeviceRuntime { + static constexpr bool Validate() { + DeviceRuntime::Validate(); + static_assert( + std::is_invocable_v, + "`Runtime::Memcpy` must be callable with " + "`(void*, const void*, size_t, MemcpyHostToDevice)`."); + return true; + } +}; + +} // namespace infini::ccl + +#endif diff --git a/src/metax/device_.h b/src/metax/device_.h new file mode 100644 index 0000000..4d1ddae --- /dev/null +++ b/src/metax/device_.h @@ -0,0 +1,12 @@ +#ifndef INFINI_CCL_METAX_DEVICE__H_ +#define INFINI_CCL_METAX_DEVICE__H_ + +#include "device.h" + +namespace infini::ccl { + +template <> struct DeviceEnabled : std::true_type {}; + +} // namespace infini::ccl + +#endif \ No newline at end of file diff --git a/src/metax/runtime_.h b/src/metax/runtime_.h new file mode 100644 index 0000000..fc528d0 --- /dev/null +++ b/src/metax/runtime_.h @@ -0,0 +1,37 @@ +#ifndef INFINI_CCL_METAX_RUNTIME_H_ +#define INFINI_CCL_METAX_RUNTIME_H_ + +// clang-format off +#include +// clang-format on + +#include "cuda/runtime_.h" +#include "metax/device_.h" + +namespace infini::ccl { + +template <> +struct Runtime + : CudaRuntime> { + using Stream = mcStream_t; + + static constexpr Device::Type kDeviceType = Device::Type::kMetax; + + static constexpr auto Malloc = mcMalloc; + + static constexpr auto Memcpy = mcMemcpy; + + static constexpr auto Free = mcFree; + + static constexpr auto MemcpyHostToDevice = mcMemcpyHostToDevice; + + static constexpr auto MemcpyDeviceToHost = mcMemcpyDeviceToHost; + + static constexpr auto Memset = mcMemset; +}; + +static_assert(Runtime::Validate()); + +} // namespace infini::ccl + +#endif \ No newline at end of file diff --git a/src/nvidia/device_.h 
b/src/nvidia/device_.h new file mode 100644 index 0000000..e4b5b29 --- /dev/null +++ b/src/nvidia/device_.h @@ -0,0 +1,12 @@ +#ifndef INFINI_CCL_NVIDIA_DEVICE__H_ +#define INFINI_CCL_NVIDIA_DEVICE__H_ + +#include "device.h" + +namespace infini::ccl { + +template <> struct DeviceEnabled : std::true_type {}; + +} // namespace infini::ccl + +#endif diff --git a/src/nvidia/runtime_.h b/src/nvidia/runtime_.h new file mode 100644 index 0000000..9303826 --- /dev/null +++ b/src/nvidia/runtime_.h @@ -0,0 +1,41 @@ +#ifndef INFINI_CCL_NVIDIA_RUNTIME_H_ +#define INFINI_CCL_NVIDIA_RUNTIME_H_ + +#include + +// clang-format off +#include +// clang-format on + +#include "cuda/runtime_.h" +#include "nvidia/device_.h" + +namespace infini::ccl { + +template <> +struct Runtime + : CudaRuntime> { + using Stream = cudaStream_t; + + static constexpr Device::Type kDeviceType = Device::Type::kNvidia; + + static constexpr auto Malloc = [](auto &&...args) { + return cudaMalloc(std::forward(args)...); + }; + + static constexpr auto Memcpy = cudaMemcpy; + + static constexpr auto Free = cudaFree; + + static constexpr auto MemcpyHostToDevice = cudaMemcpyHostToDevice; + + static constexpr auto MemcpyDeviceToHost = cudaMemcpyDeviceToHost; + + static constexpr auto Memset = cudaMemset; +}; + +static_assert(Runtime::Validate()); + +} // namespace infini::ccl + +#endif diff --git a/src/runtime.h b/src/runtime.h new file mode 100644 index 0000000..772d5a7 --- /dev/null +++ b/src/runtime.h @@ -0,0 +1,52 @@ +#ifndef INFINI_CCL_RUNTIME_H_ +#define INFINI_CCL_RUNTIME_H_ + +#include + +#include "device.h" + +namespace infini::ccl { + +template struct Runtime; + +/// ## Interface enforcement via CRTP. +/// +/// Inherit from the appropriate base to declare which interface level a +/// `Runtime` specialization implements. After the struct is fully defined, call +/// `static_assert(Runtime<...>::Validate())`. 
The chained `Validate()` checks +/// every required member's existence and signature at compile time, analogous +/// to how `override` catches signature mismatches for virtual functions. +/// +/// - `RuntimeBase`: `kDeviceType` only (e.g. CPU). +/// - `DeviceRuntime`: adds `Stream`, `Malloc`, and `Free` (e.g. Cambricon). + +/// Every Runtime must provide `static constexpr Device::Type kDeviceType`. +template struct RuntimeBase { + static constexpr bool Validate() { + static_assert( + std::is_same_v, + Device::Type>, + "`Runtime` must define `static constexpr Device::Type kDeviceType`."); + return true; + } +}; + +/// Runtimes with device memory must additionally provide `Stream`, `Malloc`, +/// and `Free`. +template struct DeviceRuntime : RuntimeBase { + static constexpr bool Validate() { + RuntimeBase::Validate(); + static_assert(sizeof(typename Derived::Stream) > 0, + "`Runtime` must define a `Stream` type alias."); + static_assert( + std::is_invocable_v, + "`Runtime::Malloc` must be callable with `(void**, size_t)`."); + static_assert(std::is_invocable_v, + "`Runtime::Free` must be callable with `(void*)`."); + return true; + } +}; + +} // namespace infini::ccl + +#endif From 109d36e1547d7aca2b97fe1f1a1df9353f9c79bc Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Fri, 10 Apr 2026 12:47:15 +0000 Subject: [PATCH 06/23] refactor: change `DataType` and `ReturnStatus` from aliasing to scoped enum class --- src/data_type_impl.h | 104 +++++++++++++++++++-------------------- src/return_status_impl.h | 24 ++++----- 2 files changed, 65 insertions(+), 63 deletions(-) diff --git a/src/data_type_impl.h b/src/data_type_impl.h index ed119bd..1b9e2fa 100644 --- a/src/data_type_impl.h +++ b/src/data_type_impl.h @@ -9,66 +9,66 @@ namespace infini::ccl { -using DataType = ::infiniDataType_t; - -constexpr DataType kChar = infiniChar; -constexpr DataType kInt8 = infiniInt8; -constexpr DataType kInt16 = infiniInt16; -constexpr DataType kInt32 = infiniInt32; -constexpr DataType kInt64 = 
infiniInt64; -constexpr DataType kUInt8 = infiniUInt8; -constexpr DataType kUInt16 = infiniUInt16; -constexpr DataType kUInt32 = infiniUInt32; -constexpr DataType kUInt64 = infiniUInt64; -constexpr DataType kFloat16 = infiniFloat16; -constexpr DataType kBFloat16 = infiniBFloat16; -constexpr DataType kFloat32 = infiniFloat32; -constexpr DataType kFloat64 = infiniFloat64; -constexpr DataType kNumTypes = infiniNumTypes; +enum class DataType : int8_t { + kChar = infiniChar, + kInt8 = infiniInt8, + kInt16 = infiniInt16, + kInt32 = infiniInt32, + kInt64 = infiniInt64, + kUInt8 = infiniUInt8, + kUInt16 = infiniUInt16, + kUInt32 = infiniUInt32, + kUInt64 = infiniUInt64, + kFloat16 = infiniFloat16, + kBFloat16 = infiniBFloat16, + kFloat32 = infiniFloat32, + kFloat64 = infiniFloat64, + kNumTypes = infiniNumTypes, +}; constexpr ConstexprMap kDataTypeToSize{{{ - {kInt8, 1}, - {kInt16, 2}, - {kInt32, 4}, - {kInt64, 8}, - {kUInt8, 1}, - {kUInt16, 2}, - {kUInt32, 4}, - {kUInt64, 8}, - {kFloat16, 2}, - {kBFloat16, 2}, - {kFloat32, 4}, - {kFloat64, 8}, + {DataType::kInt8, 1}, + {DataType::kInt16, 2}, + {DataType::kInt32, 4}, + {DataType::kInt64, 8}, + {DataType::kUInt8, 1}, + {DataType::kUInt16, 2}, + {DataType::kUInt32, 4}, + {DataType::kUInt64, 8}, + {DataType::kFloat16, 2}, + {DataType::kBFloat16, 2}, + {DataType::kFloat32, 4}, + {DataType::kFloat64, 8}, }}}; constexpr ConstexprMap kDataTypeToDesc{{{ - {kInt8, "int8"}, - {kInt16, "int16"}, - {kInt32, "int32"}, - {kInt64, "int64"}, - {kUInt8, "uint8"}, - {kUInt16, "uint16"}, - {kUInt32, "uint32"}, - {kUInt64, "uint64"}, - {kFloat16, "float16"}, - {kBFloat16, "bfloat16"}, - {kFloat32, "float32"}, - {kFloat64, "float64"}, + {DataType::kInt8, "int8"}, + {DataType::kInt16, "int16"}, + {DataType::kInt32, "int32"}, + {DataType::kInt64, "int64"}, + {DataType::kUInt8, "uint8"}, + {DataType::kUInt16, "uint16"}, + {DataType::kUInt32, "uint32"}, + {DataType::kUInt64, "uint64"}, + {DataType::kFloat16, "float16"}, + {DataType::kBFloat16, 
"bfloat16"}, + {DataType::kFloat32, "float32"}, + {DataType::kFloat64, "float64"}, }}}; constexpr ConstexprMap kStringToDataType{{{ - {"int8", kInt8}, - {"int16", kInt16}, - {"int32", kInt32}, - {"int64", kInt64}, - {"uint8", kUInt8}, - {"uint16", kUInt16}, - {"uint32", kUInt32}, - {"uint64", kUInt64}, - {"float16", kFloat16}, - {"bfloat16", kBFloat16}, - {"float32", kFloat32}, - {"float64", kFloat64}, + {"int8", DataType::kInt8}, + {"int16", DataType::kInt16}, + {"int32", DataType::kInt32}, + {"int64", DataType::kInt64}, + {"uint8", DataType::kUInt8}, + {"uint16", DataType::kUInt16}, + {"uint32", DataType::kUInt32}, + {"uint64", DataType::kUInt64}, + {"float16", DataType::kFloat16}, + {"bfloat16", DataType::kBFloat16}, + {"float32", DataType::kFloat32}, + {"float64", DataType::kFloat64}, }}}; } // namespace infini::ccl diff --git a/src/return_status_impl.h b/src/return_status_impl.h index db7d051..30db5e7 100644 --- a/src/return_status_impl.h +++ b/src/return_status_impl.h @@ -1,21 +1,23 @@ #ifndef INFINI_CCL_RETURN_STATUS_IMPL_H_ #define INFINI_CCL_RETURN_STATUS_IMPL_H_ +#include + #include "return_status.h" namespace infini::ccl { -using ReturnStatus = ::infiniResult_t; - -constexpr ReturnStatus kSuccess = infiniSuccess; -constexpr ReturnStatus kUnhandledError = infiniUnhandledError; -constexpr ReturnStatus kSystemError = infiniSystemError; -constexpr ReturnStatus kInternalError = infiniInternalError; -constexpr ReturnStatus kInvalidArgument = infiniInvalidArgument; -constexpr ReturnStatus kInvalidUsage = infiniInvalidUsage; -constexpr ReturnStatus kRemoteError = infiniRemoteError; -constexpr ReturnStatus kInProgress = infiniInProgress; -constexpr ReturnStatus kNumResults = infiniNumResults; +enum class ReturnStatus : int8_t { + kSuccess = infiniSuccess, + kUnhandledError = infiniUnhandledError, + kSystemError = infiniSystemError, + kInternalError = infiniInternalError, + kInvalidArgument = infiniInvalidArgument, + kInvalidUsage = infiniInvalidUsage, + 
kRemoteError = infiniRemoteError, + kInProgress = infiniInProgress, + kNumResults = infiniNumResults, +}; } // namespace infini::ccl From 5daec2071a2d88079bff8078c531e9fa1aa752af Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Fri, 10 Apr 2026 12:49:53 +0000 Subject: [PATCH 07/23] feat: add `Communicator` and `BackendCommInstance` - add `Communicator` and `BackendCommInstance` in `src/communicator.h` - add the backend-specific derived classes of `BackendCommInstance`, specifically `OmpiInstance` in `src/ompi/comm_instance.h` and `NcclInstance` in `src/nvidia/nccl/comm_instance.h` --- src/communicator.h | 69 +++++++++++++++++++++++++++++++++ src/nvidia/nccl/comm_instance.h | 17 ++++++++ src/ompi/comm_instance.h | 17 ++++++++ 3 files changed, 103 insertions(+) create mode 100644 src/communicator.h create mode 100644 src/nvidia/nccl/comm_instance.h create mode 100644 src/ompi/comm_instance.h diff --git a/src/communicator.h b/src/communicator.h new file mode 100644 index 0000000..7cc4aac --- /dev/null +++ b/src/communicator.h @@ -0,0 +1,69 @@ +#ifndef INFINI_CCL_COMMUNICATOR_H_ +#define INFINI_CCL_COMMUNICATOR_H_ + +#include + +#include "backend.h" +#include "device.h" + +namespace infini::ccl { + +struct BackendCommInstance { + virtual ~BackendCommInstance() = default; + BackendType type; +}; + +class Communicator { +public: + Communicator(Device::Type device_type, int device_id) + : device_type_(device_type), device_id_(device_id), global_rank_(-1), + global_size_(0) {} + + void set_world_info(int rank, int size) { + global_rank_ = rank; + global_size_ = size; + } + + auto intra_comm() const { return intra_comm_.get(); } + + auto inter_comm() const { return inter_comm_.get(); } + + BackendType intra_comm_backend() const { + return intra_comm_ ? intra_comm_->type : BackendType::kCount; + } + + BackendType inter_comm_backend() const { + return inter_comm_ ? 
inter_comm_->type : BackendType::kCount; + } + + int rank() const { return global_rank_; } + + int size() const { return global_size_; } + + int device_id() const { return device_id_; } + + Device::Type device_type() const { return device_type_; } + + bool HasBackend(BackendType t) const { + return (intra_comm_backend() == t) || (inter_comm_backend() == t); + } + +private: + // Slot 1: Intra-node (e.g., NCCL) + std::unique_ptr intra_comm_; + + // Slot 2: Inter-node (e.g., OpenMPI) + std::unique_ptr inter_comm_; + + int device_id_; + + int global_rank_; + + int global_size_; + + Device::Type device_type_; +}; + +} // namespace infini::ccl + +#endif // INFINI_CCL_COMMUNICATOR_H_ diff --git a/src/nvidia/nccl/comm_instance.h b/src/nvidia/nccl/comm_instance.h new file mode 100644 index 0000000..12fe2fc --- /dev/null +++ b/src/nvidia/nccl/comm_instance.h @@ -0,0 +1,17 @@ +#ifndef INFINI_CCL_NVIDIA_NCCL_COMM_INSTANCE_H_ +#define INFINI_CCL_NVIDIA_NCCL_COMM_INSTANCE_H_ + +#include + +#include "communicator.h" + +namespace infini::ccl { + +struct NcclInstance : public BackendCommInstance { + ncclComm_t handle; + NcclInstance() { type = BackendType::kNccl; } +}; + +} // namespace infini::ccl + +#endif diff --git a/src/ompi/comm_instance.h b/src/ompi/comm_instance.h new file mode 100644 index 0000000..f2ea118 --- /dev/null +++ b/src/ompi/comm_instance.h @@ -0,0 +1,17 @@ +#ifndef INFINI_CCL_OMPI_COMM_INSTANCE_H_ +#define INFINI_CCL_OMPI_COMM_INSTANCE_H_ + +#include + +#include "communicator.h" + +namespace infini::ccl { + +struct OmpiInstance : public BackendCommInstance { + MPI_Comm handle; + OmpiInstance() { type = BackendType::kOmpi; } +}; + +} // namespace infini::ccl + +#endif From 7a03f5088b809883622a8cce5f0b38070d49a424 Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Sat, 11 Apr 2026 10:03:36 +0000 Subject: [PATCH 08/23] style: add comments for the `#endif` in various files --- src/constexpr_map.h | 2 +- src/device.h | 2 +- src/metax/device_.h | 2 +- src/metax/runtime_.h 
| 6 +++--- src/ompi/comm_instance.h | 2 +- src/ompi/type_map.h | 2 +- src/runtime.h | 2 +- 7 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/constexpr_map.h b/src/constexpr_map.h index 251574c..dc5d6f6 100644 --- a/src/constexpr_map.h +++ b/src/constexpr_map.h @@ -29,4 +29,4 @@ template struct ConstexprMap { } // namespace infini::ccl -#endif +#endif // INFINI_CCL_CONSTEXPR_MAP_H_ diff --git a/src/device.h b/src/device.h index c726a6f..b128cdc 100644 --- a/src/device.h +++ b/src/device.h @@ -116,4 +116,4 @@ template using ActiveDevices = typename ActiveDevicesImpl::type; } // namespace infini::ccl -#endif +#endif // INFINI_CCL_DEVICE_H_ diff --git a/src/metax/device_.h b/src/metax/device_.h index 4d1ddae..8660f0e 100644 --- a/src/metax/device_.h +++ b/src/metax/device_.h @@ -9,4 +9,4 @@ template <> struct DeviceEnabled : std::true_type {}; } // namespace infini::ccl -#endif \ No newline at end of file +#endif // INFINI_CCL_METAX_DEVICE__H_ diff --git a/src/metax/runtime_.h b/src/metax/runtime_.h index fc528d0..393d73b 100644 --- a/src/metax/runtime_.h +++ b/src/metax/runtime_.h @@ -1,5 +1,5 @@ -#ifndef INFINI_CCL_METAX_RUNTIME_H_ -#define INFINI_CCL_METAX_RUNTIME_H_ +#ifndef INFINI_CCL_METAX_RUNTIME__H_ +#define INFINI_CCL_METAX_RUNTIME__H_ // clang-format off #include @@ -34,4 +34,4 @@ static_assert(Runtime::Validate()); } // namespace infini::ccl -#endif \ No newline at end of file +#endif // INFINI_CCL_METAX_RUNTIME__H_ diff --git a/src/ompi/comm_instance.h b/src/ompi/comm_instance.h index f2ea118..792da89 100644 --- a/src/ompi/comm_instance.h +++ b/src/ompi/comm_instance.h @@ -14,4 +14,4 @@ struct OmpiInstance : public BackendCommInstance { } // namespace infini::ccl -#endif +#endif // INFINI_CCL_OMPI_COMM_INSTANCE_H_ diff --git a/src/ompi/type_map.h b/src/ompi/type_map.h index c8170c2..a3bef29 100644 --- a/src/ompi/type_map.h +++ b/src/ompi/type_map.h @@ -28,4 +28,4 @@ inline MPI_Datatype DataTypeToOmpiType(DataType dtype) { } // namespace 
infini::ccl -#endif +#endif // INFINI_CCL_OMPI_TYPE_MAPPING_H_ diff --git a/src/runtime.h b/src/runtime.h index 772d5a7..1e603c6 100644 --- a/src/runtime.h +++ b/src/runtime.h @@ -49,4 +49,4 @@ template struct DeviceRuntime : RuntimeBase { } // namespace infini::ccl -#endif +#endif // INFINI_CCL_RUNTIME_H_ From cfaf3fa8074ffb6cbc842a377ed9b4693ef70759 Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Sat, 11 Apr 2026 11:01:28 +0000 Subject: [PATCH 09/23] feat: add priority levels for `BackendType` and `Device::Type` and `Operation` class for operation dispatching - add the generic traits for getting the "best" element in a `List` in `traits.h` - add traits for indicating enabled backends and `AllBackendTypes` alias - add Priority traits for `BackendType` and `Device::Type` in `backend.h` and `device.h`, respectively - add `src/operation.h` which contains `Operation` base class for all the operations and is responsible for dispatching different operations --- src/backend.h | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ src/device.h | 25 +++++++++++++++++++++++++ src/operation.h | 45 +++++++++++++++++++++++++++++++++++++++++++++ src/traits.h | 15 +++++++++++++++ 4 files changed, 133 insertions(+) create mode 100644 src/operation.h diff --git a/src/backend.h b/src/backend.h index bdd7413..37b5d9f 100644 --- a/src/backend.h +++ b/src/backend.h @@ -3,6 +3,8 @@ #include +#include "traits.h" + namespace infini::ccl { enum class BackendType : int8_t { @@ -16,6 +18,52 @@ enum class BackendType : int8_t { kCount }; +using AllBackendTypes = + List; + +template struct BackendEnabled : std::false_type {}; + +/** + * @brief Deferred computation of active backends for a specific operation Key. 
+ */ +template struct ActiveBackendsImpl { + struct Filter { + template + std::enable_if_t< + BackendEnabled(kBackend)>::value> + operator()(ValueTag) const {} + }; + + using type = typename FilterList, AllBackendTypes>::type; +}; + +template +using ActiveBackends = typename ActiveBackendsImpl::type; + +/** + * @brief Priority trait for backend selection. + */ +template struct BackendPriority { + static constexpr int value = 0; +}; + +template <> struct BackendPriority { + static constexpr int value = 1; +}; + +template <> struct BackendPriority { + static constexpr int value = 10; +}; + +template +constexpr BackendType ListGetBestBackend(ActiveBackends) { + static_assert(ListSize>::value > 0, + "No backends enabled for this operation."); + return ListGetMax(ActiveBackends{}); +} + } // namespace infini::ccl #endif // INFINI_CCL_BACKEND_H_ diff --git a/src/device.h b/src/device.h index b128cdc..6c2c4ef 100644 --- a/src/device.h +++ b/src/device.h @@ -114,6 +114,31 @@ template struct ActiveDevicesImpl { template using ActiveDevices = typename ActiveDevicesImpl::type; +/** + * @brief Priority trait for device selection. 
+ */ +template struct DevicePriority { + static constexpr int value = 0; +}; + +template <> struct DevicePriority { + static constexpr int value = 1; +}; + +template <> struct DevicePriority { + static constexpr int value = 5; +}; + +template <> struct DevicePriority { + static constexpr int value = 5; +}; + +template +constexpr Device::Type ListGetBestDevice(ActiveDevices) { + static_assert(ListSize>::value > 0, "No devices enabled."); + return ListGetMax(ActiveDevices{}); +} + } // namespace infini::ccl #endif // INFINI_CCL_DEVICE_H_ diff --git a/src/operation.h b/src/operation.h new file mode 100644 index 0000000..30ed1ee --- /dev/null +++ b/src/operation.h @@ -0,0 +1,45 @@ +#ifndef INFINI_CCL_OPERATION_H_ +#define INFINI_CCL_OPERATION_H_ + +#include + +#include "backend.h" +#include "device.h" +#include "dispatcher.h" +#include "traits.h" + +namespace infini::ccl { + +template +class Operation { +public: + template static auto Call(Args &&...args) { + constexpr Device::Type kBestDev = + ListGetBestDevice(ActiveDevices{}); + constexpr BackendType kBestBack = + ListGetBestBackend(ActiveBackends{}); + + return Call(kBestBack, kBestDev, std::forward(args)...); + } + + template + static auto Call(BackendType backend, Device::Type device, Args &&...args) { + return DispatchFunc, ActiveDevices>( + {static_cast(backend), static_cast(device)}, + [&](auto resolved_list) { + constexpr BackendType kBackend = + static_cast(ListGet<0>(resolved_list)); + constexpr Device::Type kDevice = + static_cast(ListGet<1>(resolved_list)); + + return Key::template Execute( + std::forward(args)...); + }, + "Operation::Call"); + } +}; + +} // namespace infini::ccl + +#endif // INFINI_CCL_OPERATION_H_ diff --git a/src/traits.h b/src/traits.h index 68d6bbd..904f58c 100644 --- a/src/traits.h +++ b/src/traits.h @@ -100,6 +100,21 @@ struct Flatten { using type = typename Flatten, Rest...>::type; }; +// Generic recursion to find the "best" element based on a `PriorityTrait`. +template