diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h
index 9e71785013..52a2371ed0 100644
--- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h
+++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h
@@ -41,7 +41,7 @@ TensorDims get_piece_dims(ParallelTensorDims const &);
 TensorDims get_tensor_dims_unsafe(ParallelTensorDims const &);
 TensorDims get_reduced_dims(ParallelTensorDims const &);
-
+TensorDims get_per_device_dims(ParallelTensorDims const &dims);
 } // namespace FlexFlow
 
 #endif

diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h
index e23ae33cbf..93be4b230e 100644
--- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h
+++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h
@@ -63,6 +63,8 @@ ParallelDim get_parallel_dim_at_idx(ParallelTensorShape const &shape,
 std::unordered_set<parallel_tensor_dim_idx_t>
     get_parallel_tensor_dim_indices(ParallelTensorShape const &shape);
 
+TensorShape get_per_device_shape(ParallelTensorShape const &s);
+
 } // namespace FlexFlow
 
 #endif

diff --git a/lib/op-attrs/src/op-attrs/ops/element_unary.cc b/lib/op-attrs/src/op-attrs/ops/element_unary.cc
index 9d02923689..ca7e417814 100644
--- a/lib/op-attrs/src/op-attrs/ops/element_unary.cc
+++ b/lib/op-attrs/src/op-attrs/ops/element_unary.cc
@@ -35,7 +35,6 @@ ParallelTensorDimDegrees get_output_parallel_dim_degrees(
     ElementUnaryAttrs const &attrs,
     ParallelTensorDimDegrees const &input_degrees) {
   ASSERT(input_degrees.sum_degree.value == 1);
-  ASSERT(input_degrees.discard_copy_degree.value == 1);
 
   return input_degrees;
 }

diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc
index 71419e4a57..7798db0643 100644
--- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc
+++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc
@@ -127,4 +127,12 @@ TensorDims get_reduced_dims(ParallelTensorDims const &dims) {
   return TensorDims{dim_sizes};
 }
 
+TensorDims get_per_device_dims(ParallelTensorDims const &dims) {
+  FFOrdered<positive_int> dim_sizes =
+      transform(dims.shard_dims, [](ShardParallelDim const &d) {
+        return positive_int{d.size.int_from_positive_int() /
+                            d.degree.int_from_positive_int()};
+      });
+  return TensorDims{dim_sizes};
+}
 } // namespace FlexFlow

diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc
index 91d3d0b1aa..f4480e3239 100644
--- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc
+++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc
@@ -150,4 +150,11 @@ std::unordered_set<parallel_tensor_dim_idx_t>
   return indices;
 }
 
+// the actual per-device allocation size: each shard dim divided by its degree
+TensorShape get_per_device_shape(ParallelTensorShape const &s) {
+  return TensorShape{
+      get_per_device_dims(s.dims),
+      s.data_type,
+  };
+}
 } // namespace FlexFlow

diff --git a/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc b/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc
index 672b160cbd..00df1fc0b9 100644
--- a/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc
+++ b/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc
@@ -61,14 +61,5 @@ TEST_SUITE(FF_TEST_SUITE) {
           make_input(
               SumDegree{degree}, DiscardCopyDegree{1_p}, 1_p, 1_p, 1_p)));
     }
-
-    SUBCASE("discard copy degree > 1") {
-      positive_int degree = 2_p;
-
-      CHECK_THROWS(get_output_shape(
-          attrs,
-          make_input(
-              SumDegree{1_p}, DiscardCopyDegree{degree}, 1_p, 1_p, 1_p)));
-    }
   }
 }

diff --git a/lib/realm-execution/CMakeLists.txt b/lib/realm-execution/CMakeLists.txt
index
25a51ada54..67c37b5823 100644
--- a/lib/realm-execution/CMakeLists.txt
+++ b/lib/realm-execution/CMakeLists.txt
@@ -1,13 +1,32 @@
-ff_add_library(
-  NAME
-    realm-execution
-  SRC_PATTERNS
-    src/*.cc
-  PUBLIC_INCLUDE
+project(realm-execution
+  LANGUAGES CXX CUDA)
+
+file(GLOB_RECURSE SRC
+  CONFIGURE_DEPENDS
+  LIST_DIRECTORIES False
+  src/*.cc
+  src/**/*.cc
+  src/cuda/*.cu
+  src/**/*.cu
+)
+
+add_library(
+  realm-execution
+  SHARED
+  ${SRC}
+)
+
+target_include_directories(
+  realm-execution
+  PUBLIC
     include/
-  PRIVATE_INCLUDE
+  PRIVATE
     src/
-  DEPS
+)
+
+target_link_libraries(
+  realm-execution
+  PUBLIC
     compiler
     kernels
     local-execution
@@ -19,4 +38,13 @@ ff_add_library(
     realm
 )
 
+define_ff_vars(realm-execution)
+
+set_target_properties(
+  realm-execution
+  PROPERTIES
+  CUDA_STANDARD 17
+)
+
 add_subdirectory(test)
+

diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h
index ab89e916c0..6bb38a0824 100644
--- a/lib/realm-execution/include/realm-execution/realm_context.h
+++ b/lib/realm-execution/include/realm-execution/realm_context.h
@@ -1,3 +1,4 @@
+
 #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_CONTEXT_H
 #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_CONTEXT_H
 
@@ -15,6 +16,11 @@
 
 namespace FlexFlow {
 
+enum class CopyDomain {
+  SRC, // use the src instance's index space as the copy domain (default)
+  DST, // use the dst instance's index space as the copy domain
+};
+
 /**
  * @brief An interface that wraps the rest of Realm and protects against certain
  * classes of bugs, such as shutdown bugs.
  */
@@ -63,17 +69,20 @@ struct RealmContext {
                           int priority = 0);
   ///\}
 
-  /** \name Data movement */
+  /** \name Data movement and reduction */
   ///\{
-  Realm::Event issue_copy(ParallelTensorShape const &src_shape,
-                          Realm::RegionInstance src_inst,
-                          ParallelTensorShape const &dst_shape,
-                          Realm::RegionInstance dst_inst,
-                          Realm::ProfilingRequestSet const &requests,
-                          Realm::Event wait_on = Realm::Event::NO_EVENT,
-                          int priority = 0);
+  Realm::Event
+      issue_copy(ParallelTensorShape const &src_shape,
+                 Realm::RegionInstance src_inst,
+                 ParallelTensorShape const &dst_shape,
+                 Realm::RegionInstance dst_inst,
+                 Realm::ProfilingRequestSet const &requests,
+                 Realm::Event wait_on = Realm::Event::NO_EVENT,
+                 int priority = 0,
+                 std::optional<Realm::ReductionOpID> redop_id = std::nullopt,
+                 bool exclusive = false,
+                 CopyDomain domain = CopyDomain::SRC);
   ///\}
-
   /** \name Instance management */
   ///\{
   std::pair<Realm::RegionInstance, Realm::Event>
@@ -88,6 +97,50 @@
    */
   Realm::Event get_outstanding_events();
 
+  /**
+   * \brief Create a Realm region instance with an offset index space.
+   *
+   * Similar to \ref create_instance, but allocates the instance with a
+   * non-zero origin rect. This is used for sharded tensors where each
+   * shard occupies a sub-region of the full logical tensor's index space.
+   *
+   * For example, given a tensor of shape [10, 16] split along dim 0
+   * with degree 2:
+   *   - Shard 0 is allocated with rect [0..4, 0..15]
+   *   - Shard 1 is allocated with rect [5..9, 0..15]
+   *
+   * This allows plain Realm copies between shards and the combined tensor
+   * to work correctly — points in each shard's index space match the
+   * corresponding points in the combined tensor's index space, so Realm
+   * copies data to the correct region without needing affine indirection.
+   *
+   * \param memory The Realm memory in which to allocate the instance.
+   * \param shape The per-device tensor shape (already divided by degree).
+   *        Determines the size of the instance.
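+   *        For the [10, 16] example above split with degree 2, each
+   *        shard's \p shape is [5, 16].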
+   * \param offsets Per-dimension offsets into the full logical tensor.
+   *        \p offsets[i] is the starting index along dimension i.
+   *        For shard k along dim d with piece_size p:
+   *        \p offsets[d] = k * p.
+   * \param prs Realm profiling request set.
+   * \param wait_on Event to wait on before creating the instance.
+   * \return A pair of the created \ref Realm::RegionInstance and a
+   *         \ref Realm::Event that fires when the instance is ready.
+   *
+   * \note The instance's index space has origin at \p offsets, not at
+   *       zero. Copies to/from this instance must use its actual index
+   *       space (via \c get_indexspace()) rather than a reconstructed
+   *       zero-based index space.
+   *
+   * \see create_instance
+   * \see perform_instance_allocation_for_value
+   */
+  std::pair<Realm::RegionInstance, Realm::Event> create_instance_with_offset(
+      Realm::Memory memory,
+      TensorShape const &shape,
+      std::vector<int> const &offsets,
+      Realm::ProfilingRequestSet const &prs,
+      Realm::Event wait_on = Realm::Event::NO_EVENT);
+
 protected:
   /**
    * \brief Compact **and clear** the outstanding event queue

diff --git a/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h b/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h
new file mode 100644
index 0000000000..388b433947
--- /dev/null
+++ b/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h
@@ -0,0 +1,210 @@
+#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDUCTION_H
+#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDUCTION_H
+#include "op-attrs/datatype.dtg.h"
+#include <realm.h>
+#include <stdexcept>
+namespace FlexFlow {
+
+/**
+ * \brief Realm Sum Reduction for Float
+ * \see https://legion.stanford.edu/tutorial/realm/reductions.html
+ */
+struct SumReductionFloat {
+  using LHS = float;
+  using RHS = float;
+
+  /** \brief Identity element for addition (0.0) */
+  static constexpr RHS identity = 0.0f;
+
+  /**
+   * \brief Apply reduction: lhs += rhs
+   * \tparam EXCLUSIVE If true, direct addition; if false, atomic CAS loop
+   * \param lhs Left-hand side accumulator (modified in place)
+   * \param rhs Value to add
+   */
+  template <bool EXCLUSIVE>
+  REALM_CUDA_HD static void apply(LHS &lhs, RHS rhs) {
+    if (EXCLUSIVE) {
+      lhs += rhs;
+    } else {
+#if defined(__CUDA_ARCH__)
+      atomicAdd(&lhs, rhs);
+#else
+      union {
+        float f;
+        int i;
+      } old_val, new_val;
+      do {
+        old_val.f = lhs;
+        new_val.f = old_val.f + rhs;
+      } while (
+          !__sync_bool_compare_and_swap((int *)&lhs, old_val.i, new_val.i));
+#endif
+    }
+  }
+
+  template <bool EXCLUSIVE>
+  __device__ static void apply_cuda(LHS &lhs, RHS rhs) {
+    apply<EXCLUSIVE>(lhs, rhs);
+  }
+
+  /**
+   * \brief Fold two RHS values: rhs1 += rhs2
+   * \tparam EXCLUSIVE If true, direct addition; if false, atomic CAS loop
+   * \param rhs1 Accumulator (modified in place)
+   * \param rhs2 Value to fold in
+   */
+  template <bool EXCLUSIVE>
+  REALM_CUDA_HD static void fold(RHS &rhs1, RHS rhs2) {
+    if (EXCLUSIVE) {
+      rhs1 += rhs2;
+    } else {
+#if defined(__CUDA_ARCH__)
+      atomicAdd(&rhs1, rhs2);
+#else
+      union {
+        float f;
+        int i;
+      } old_val, new_val;
+      do {
+        old_val.f = rhs1;
+        new_val.f = old_val.f + rhs2;
+      } while (
+          !__sync_bool_compare_and_swap((int *)&rhs1, old_val.i, new_val.i));
+#endif
+    }
+  }
+
+  template <bool EXCLUSIVE>
+  __device__ static void fold_cuda(RHS &rhs1, RHS rhs2) {
+    fold<EXCLUSIVE>(rhs1, rhs2);
+  }
+};
+
+/**
+ * \brief Realm Sum Reduction for Double
+ * \see https://legion.stanford.edu/tutorial/realm/reductions.html
+ */
+struct SumReductionDouble {
+  using LHS = double;
+  using RHS = double;
+
+  /** \brief Identity element for addition (0.0) */
+  static constexpr RHS identity = 0.0;
+
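+  // How Realm drives these members (see the reductions tutorial linked
+  // above): `fold` combines partial right-hand-side values in a reduction
+  // buffer, and `apply` folds a buffered value into the destination
+  // instance. Illustrative host-side sketch (not the Realm API itself;
+  // `dst` is a hypothetical accumulator):
+  //
+  //   double acc = SumReductionDouble::identity;               // 0.0
+  //   SumReductionDouble::fold</*EXCLUSIVE=*/true>(acc, 1.5);  // acc == 1.5
+  //   SumReductionDouble::apply</*EXCLUSIVE=*/true>(dst, acc); // dst += 1.5
+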
+  /**
+   * \brief Apply reduction: lhs += rhs
+   * \tparam EXCLUSIVE If true, direct addition; if false, atomic CAS loop
+   * \param lhs Left-hand side accumulator (modified in place)
+   * \param rhs Value to add
+   */
+  template <bool EXCLUSIVE>
+  REALM_CUDA_HD static void apply(LHS &lhs, RHS rhs) {
+    if (EXCLUSIVE) {
+      lhs += rhs;
+    } else {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
+      atomicAdd(&lhs, rhs);
+#elif defined(__CUDA_ARCH__)
+      // pre-Pascal fallback CAS loop (double atomicAdd requires sm_60+)
+      unsigned long long int *addr = (unsigned long long int *)&lhs;
+      unsigned long long int old = *addr, assumed;
+      do {
+        assumed = old;
+        old = atomicCAS(
+            addr,
+            assumed,
+            __double_as_longlong(rhs + __longlong_as_double(assumed)));
+      } while (assumed != old);
+#else
+      union {
+        double d;
+        long long i;
+      } old_val, new_val;
+      do {
+        old_val.d = lhs;
+        new_val.d = old_val.d + rhs;
+      } while (!__sync_bool_compare_and_swap(
+          (long long *)&lhs, old_val.i, new_val.i));
+#endif
+    }
+  }
+
+  template <bool EXCLUSIVE>
+  __device__ static void apply_cuda(LHS &lhs, RHS rhs) {
+    apply<EXCLUSIVE>(lhs, rhs);
+  }
+
+  /**
+   * \brief Fold two RHS values: rhs1 += rhs2
+   * \tparam EXCLUSIVE If true, direct addition; if false, atomic CAS loop
+   * \param rhs1 Accumulator (modified in place)
+   * \param rhs2 Value to fold in
+   */
+  template <bool EXCLUSIVE>
+  REALM_CUDA_HD static void fold(RHS &rhs1, RHS rhs2) {
+    if (EXCLUSIVE) {
+      rhs1 += rhs2;
+    } else {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
+      atomicAdd(&rhs1, rhs2);
+#elif defined(__CUDA_ARCH__)
+      unsigned long long int *addr = (unsigned long long int *)&rhs1;
+      unsigned long long int old = *addr, assumed;
+      do {
+        assumed = old;
+        old = atomicCAS(
+            addr,
+            assumed,
+            __double_as_longlong(rhs2 + __longlong_as_double(assumed)));
+      } while (assumed != old);
+#else
+      union {
+        double d;
+        long long i;
+      } old_val, new_val;
+      do {
+        old_val.d = rhs1;
+        new_val.d = old_val.d + rhs2;
+      } while (!__sync_bool_compare_and_swap(
+          (long long *)&rhs1, old_val.i, new_val.i));
+#endif
+    }
+  }
+
+  template <bool EXCLUSIVE>
+  __device__ static void fold_cuda(RHS &rhs1, RHS rhs2) {
+    fold<EXCLUSIVE>(rhs1, rhs2);
+  }
+};
+
+/**
+ * \brief Reduction op IDs for sum reductions
+ * \warning These IDs must not conflict with other registered reduction ops
+ */
+enum SumReductionOpIDs {
+  REDOP_SUM_FLOAT = 1,  ///< Sum reduction op ID for float
+  REDOP_SUM_DOUBLE = 2, ///< Sum reduction op ID for double
+};
+
+/**
+ * \brief Returns the Realm reduction op ID for a sum reduction over the
+ *        given datatype
+ * \param dtype The datatype to look up
+ * \return The corresponding Realm::ReductionOpID
+ * \throws std::runtime_error (on the host; asserts on device) if no sum
+ *         reduction is registered for the given datatype
+ */
+inline Realm::ReductionOpID get_sum_reduction_op_id(DataType dtype) {
+  switch (dtype) {
+    case DataType::FLOAT:
+      return REDOP_SUM_FLOAT;
+    case DataType::DOUBLE:
+      return REDOP_SUM_DOUBLE;
+    default:
+#ifndef __CUDA_ARCH__
+      throw std::runtime_error("no sum reduction registered for datatype");
+#else
+      assert(false);
+      return REDOP_SUM_FLOAT; // unreachable
+#endif
+  }
+}
+} // namespace FlexFlow
+#endif

diff --git a/lib/realm-execution/include/realm-execution/tasks/realm_task_registry.h b/lib/realm-execution/include/realm-execution/tasks/realm_task_registry.h
index a956d53643..0c0b24c826 100644
--- a/lib/realm-execution/include/realm-execution/tasks/realm_task_registry.h
+++ b/lib/realm-execution/include/realm-execution/tasks/realm_task_registry.h
@@ -27,7 +27,11 @@ namespace FlexFlow {
 * else Realm may not shut down properly.
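 * \note register_all_tasks() also calls register_reductions() (declared
 *       below) before queueing task registrations, so the Realm sum
 *       reduction operators are registered up front.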
 */
[[nodiscard]] Realm::Event register_all_tasks();
-
+/**
+ * \brief Registers the Realm sum reduction operators for the supported
+ *        data types.
+ * Defined in tasks/cuda/realm_reduction.cu — compiled with CUDA for GPU
+ * atomic support.
+ */
+void register_reductions();
 } // namespace FlexFlow
 
 #endif

diff --git a/lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc b/lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc
index 1d517a8fe4..e7d8647b12 100644
--- a/lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc
+++ b/lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc
@@ -31,6 +31,7 @@ PerDeviceOpStateBacking perform_distributed_per_device_op_state_initialization(
   std::unordered_map<DynamicNodeInvocation, DeviceSpecificPtr *>
       device_state_map;
+  std::vector<Realm::Event> completion_events;
   for (DynamicNodeInvocation const &invocation : dg.invocations) {
     Realm::Processor target_proc = ctx.map_device_coord_to_processor(
         assert_unwrap(invocation.node_attrs.device_coord));
@@ -56,6 +57,7 @@ PerDeviceOpStateBacking perform_distributed_per_device_op_state_initialization(
         precondition);
 
     if (completion_event.has_value()) {
+      completion_events.push_back(completion_event.value());
       device_state_map.insert(std::pair{invocation, device_state_ptr});
     } else {
       // Task doesn't require initialization, clean up and don't store result
     }
   }
 
-  ctx.get_outstanding_events().wait();
+  // wait for all init tasks: each task writes directly to *result_ptr before
+  // its completion event fires, so every result is in place once this returns
+  Realm::Event::merge_events(completion_events).wait();
 
   auto deref = [](DeviceSpecificPtr *const &p) { return *p; };
   std::unordered_map<DynamicNodeInvocation, DeviceSpecificPtr>

diff --git a/lib/realm-execution/src/realm-execution/instance_allocation.cc b/lib/realm-execution/src/realm-execution/instance_allocation.cc
index 4ef2919b10..740e044579 100644
--- a/lib/realm-execution/src/realm-execution/instance_allocation.cc
+++ b/lib/realm-execution/src/realm-execution/instance_allocation.cc
@@ -1,6 +1,9 @@
 #include "realm-execution/instance_allocation.h"
 #include "local-execution/tensor_allocation.h"
+#include "op-attrs/num_ptensor_shard_dims_t.dtg.h"
 #include "op-attrs/parallel_tensor_shape.h"
+#include "op-attrs/relative_ff_dim_t.h"
+#include "op-attrs/shard_parallel_dim.dtg.h"
 #include "op-attrs/tensor_shape.dtg.h"
 #include "realm-execution/realm_context.h"
 #include "realm-execution/tensor_instance_backing.h"
@@ -17,10 +20,10 @@
 #include "utils/containers/unordered_set_of.h"
 #include "utils/containers/values.h"
 #include "utils/exception.h"
+#include "utils/nonnegative_int/nonnegative_int.h"
 #include "utils/optional.h"
+#include <algorithm> // std::any_of
 
 namespace FlexFlow {
-
std::pair<Realm::RegionInstance, Realm::Event>
    perform_instance_allocation_for_value(
        MachineSpaceCoordinate const &device_coord,
@@ -28,11 +31,51 @@
        RealmContext &ctx) {
  ASSERT(value.accessor == std::nullopt);

-  TensorShape shape = get_piece_shape(value.parallel_tensor_shape.value());
+  ParallelTensorShape const par_shape = value.parallel_tensor_shape.value();
+
+  TensorShape shape = get_per_device_shape(par_shape);

  Realm::Processor proc = ctx.map_device_coord_to_processor(device_coord);
  Realm::Memory memory = ctx.get_nearest_memory(proc);
-  return ctx.create_instance(memory, shape, Realm::ProfilingRequestSet());
+
+  int ndims = static_cast<int>(num_shard_dims(par_shape).value);
+  std::vector<int> offsets(ndims, 0);
+
+  if (value.shard_coord.has_value()) {
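+    // Worked example (matching the create_instance_with_offset docs): a
+    // [10, 16] tensor split along dim 0 with degree 2 has
+    // piece_size = 10 / 2 = 5, so shard 0 gets offsets[0] = 0 * 5 = 0 and
+    // shard 1 gets offsets[0] = 1 * 5 = 5, i.e. rects [0..4, 0..15] and
+    // [5..9, 0..15].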
ParallelTensorSpaceCoordinate const &coord = value.shard_coord.value(); + + for (int i = 0; i < ndims; i++) { + relative_ff_dim_t rel_dim{i}; + + // skip if shard_components doesn't have this dim + if (!coord.shard_components.idx_is_valid(rel_dim)) { + continue; + } + + ShardParallelDim shard_dim = par_shape.dims.shard_dims.at(rel_dim); + + // skip if not actually sharded + if (shard_dim.degree == 1_p) { + continue; + } + + nonnegative_int piece_size = + shard_dim.size.nonnegative_int_from_positive_int() / + shard_dim.degree.nonnegative_int_from_positive_int(); + nonnegative_int shard_idx = coord.shard_components.at(rel_dim); + offsets[i] = static_cast(shard_idx * piece_size); + } + } + + bool has_offset = + std::any_of(offsets.begin(), offsets.end(), [](int o) { return o != 0; }); + + if (has_offset) { + return ctx.create_instance_with_offset( + memory, shape, offsets, Realm::ProfilingRequestSet()); + } else { + return ctx.create_instance(memory, shape, Realm::ProfilingRequestSet()); + } } TensorInstanceBacking perform_instance_allocation( diff --git a/lib/realm-execution/src/realm-execution/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance.cc index 0ecd02143e..06823ad089 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance.cc @@ -1,4 +1,5 @@ #include "realm-execution/pcg_instance.h" +#include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_slot_name.dtg.h" #include "pcg/optimizer_attrs.h" #include "realm-execution/dependency_set.h" @@ -6,6 +7,7 @@ #include "realm-execution/instance_allocation.h" #include "realm-execution/realm_context.h" #include "realm-execution/tasks/impl/op_task.h" +#include "realm-execution/tasks/realm_reduction.h" #include "realm-execution/tensor_instance_backing.h" #include "task-spec/dynamic_graph/copy_insertion.h" #include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" @@ -77,6 +79,10 @@ std::optional return this->logit_grad_tensor; } +static bool has_task_type(DynamicNodeAttrs const &n, DynamicTaskType t) { + return n.task_type.has_value() && n.task_type.value() == t; +} + PCGInstance create_pcg_instance( RealmContext &ctx, MappedParallelComputationGraph const &mpcg, @@ -215,6 +221,148 @@ static Realm::Event spawn_dynamic_node_invocation( precondition); }; + auto issue_sum_reduction_copy = + [&](DynamicValueAttrs const &input, + DynamicValueAttrs const &output) -> Realm::Event { + Realm::RegionInstance src_inst = + tensor_instance_backing.backing.at(input).first; + Realm::RegionInstance dst_inst = + tensor_instance_backing.backing.at(output).first; + + Realm::ReductionOpID redop_id = get_sum_reduction_op_id( + assert_unwrap(input.parallel_tensor_shape).data_type); + + return ctx.issue_copy(assert_unwrap(input.parallel_tensor_shape), + src_inst, + assert_unwrap(output.parallel_tensor_shape), + dst_inst, + Realm::ProfilingRequestSet{}, + precondition, + /*priority=*/0, + /*redop_id=*/redop_id, + /*exclusive=*/false); + }; + + // replicate backward — find GRADIENT slot, chain reductions sequentially + auto issue_replicate_bwd = [&]() { + std::optional output_grad_opt; + for (auto const &[slot, value] : invocation.inputs) { + if (slot.slot_tensor_role == DynamicTensorRole{FwbTensorType::GRADIENT}) { + output_grad_opt = value; + } + } + DynamicValueAttrs output_grad = assert_unwrap(output_grad_opt); + DynamicValueAttrs input_grad = get_only(invocation.outputs).second; + + // chain sequentially to avoid write races + Realm::Event e = precondition; + 
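+      // Each replica's RegionInstance is keyed in tensor_instance_backing by
+      // a DynamicValueAttrs whose mapping holds just that replica's
+      // (coord, machine) pair and whose shard_coord names the replica, so
+      // rewriting those two fields below recovers one replica at a time;
+      // chaining on `e` keeps the sum-reduction copies free of write races.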
+      for (auto const &[p, m] : assert_unwrap(output_grad.mapping)) {
+        DynamicValueAttrs replica_key = output_grad;
+        replica_key.mapping =
+            bidict<ParallelTensorSpaceCoordinate, MachineSpaceCoordinate>{
+                {p, m}};
+        replica_key.shard_coord = p;
+        e = issue_sum_reduction_copy(replica_key, input_grad);
+      }
+      return e;
+    };
+
+    auto issue_reduction_fwd = [&]() {
+      DynamicValueAttrs const &output = get_only(invocation.outputs).second;
+      Realm::RegionInstance dst_inst =
+          tensor_instance_backing.backing.at(output).first;
+
+      Realm::ReductionOpID redop_id = get_sum_reduction_op_id(
+          assert_unwrap(output.parallel_tensor_shape).data_type);
+
+      // chain reductions sequentially
+      Realm::Event e = precondition;
+      for (auto const &[slot, input] : invocation.inputs) {
+        Realm::RegionInstance src_inst =
+            tensor_instance_backing.backing.at(input).first;
+        e = ctx.issue_copy(assert_unwrap(input.parallel_tensor_shape),
+                           src_inst,
+                           assert_unwrap(output.parallel_tensor_shape),
+                           dst_inst,
+                           Realm::ProfilingRequestSet{},
+                           e,
+                           /*priority=*/0,
+                           /*redop_id=*/redop_id,
+                           /*exclusive=*/false);
+      }
+      return e;
+    };
+
+    auto issue_combine_fwd = [&]() {
+      DynamicValueAttrs const &output = get_only(invocation.outputs).second;
+      Realm::RegionInstance dst_inst =
+          tensor_instance_backing.backing.at(output).first;
+
+      // chain copies sequentially — each input shard copies into the output
+      Realm::Event e = precondition;
+      for (auto const &[slot, input] : invocation.inputs) {
+        Realm::RegionInstance src_inst =
+            tensor_instance_backing.backing.at(input).first;
+        e = ctx.issue_copy(assert_unwrap(input.parallel_tensor_shape),
+                           src_inst,
+                           assert_unwrap(output.parallel_tensor_shape),
+                           dst_inst,
+                           Realm::ProfilingRequestSet{},
+                           e);
+      }
+      return e;
+    };
+
+    auto issue_parallel_op_bwd_copy = [&]() {
+      // find the single GRADIENT input
+      std::optional<DynamicValueAttrs> grad_input_opt;
+      for (auto const &[slot, value] : invocation.inputs) {
+        if (slot.slot_tensor_role ==
+            DynamicTensorRole{FwbTensorType::GRADIENT}) {
+          grad_input_opt = value;
+        }
+      }
+
+      // determine the copy domain based on op type
+      PCGOperatorAttrs pcg =
+          invocation.node_attrs.op_attrs.value().get<PCGOperatorAttrs>();
+      // reduction BWD keeps the default SRC domain (src and dst are the
+      // same size)
+      CopyDomain domain = CopyDomain::SRC;
+      if (pcg.has<RepartitionAttrs>()) {
+        // repartition BWD: src=small shard, dst=full → use SRC domain
+        domain = CopyDomain::SRC;
+      } else if (pcg.has<CombineAttrs>()) {
+        // combine BWD: src=full, dst=small shard → use DST domain
+        domain = CopyDomain::DST;
+      }
+      DynamicValueAttrs grad_input = assert_unwrap(grad_input_opt);
+      DynamicValueAttrs output = get_only(invocation.outputs).second;
+      Realm::RegionInstance dst_inst =
+          tensor_instance_backing.backing.at(output).first;
+
+      // iterate over all source coords in the grad mapping, chaining the
+      // copies sequentially into the same destination
+      Realm::Event e = precondition;
+      for (auto const &[p, m] : assert_unwrap(grad_input.mapping)) {
+        DynamicValueAttrs shard_key = grad_input;
+        shard_key.mapping =
+            bidict<ParallelTensorSpaceCoordinate, MachineSpaceCoordinate>{
+                {p, m}};
+        shard_key.shard_coord = p;
+
+        Realm::RegionInstance src_inst =
+            tensor_instance_backing.backing.at(shard_key).first;
+
+        e = ctx.issue_copy(assert_unwrap(grad_input.parallel_tensor_shape),
+                           src_inst,
+                           assert_unwrap(output.parallel_tensor_shape),
+                           dst_inst,
+                           Realm::ProfilingRequestSet{},
+                           e,
+                           /*priority=*/0,
+                           /*redop_id=*/std::nullopt,
+                           /*exclusive=*/false,
+                           /*domain=*/domain);
+      }
+      return e;
+    };
+
  TrainingOperationAttrs op_attrs =
      assert_unwrap(invocation.node_attrs.op_attrs);
  return op_attrs.visit(overload{
@@ -222,6 +370,47 @@ static Realm::Event spawn_dynamic_node_invocation(
        return pcg_op_attrs.visit(overload{
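+            // Dispatch for the parallel ops handled below (everything else
+            // falls through to spawn_task()):
+            //   Replicate:   FWD plain copy; BWD sum-reduces replica grads
+            //   Repartition: FWD copy over the dst index space; BWD
+            //                per-shard copies back into the full tensor
+            //   Combine:     FWD copies each shard into the output; BWD
+            //                per-shard copies over the dst index space
+            //   Reduction:   FWD sum-reduction copies; BWD broadcast copies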
            [&](InputAttrs const &) { return Realm::Event::NO_EVENT; },
            [&](WeightAttrs const &) { return Realm::Event::NO_EVENT; },
+           [&](ReplicateAttrs const &) {
+             if (invocation.node_attrs.task_type.has_value() &&
+                 invocation.node_attrs.task_type.value() ==
+                     DynamicTaskType::BWD) {
+               return issue_replicate_bwd();
+             }
+             return issue_copy(); // forward
+           },
+           [&](RepartitionAttrs const &) {
+             if (has_task_type(invocation.node_attrs, DynamicTaskType::BWD)) {
+               // point-to-point copy after shard expansion
+               return issue_parallel_op_bwd_copy();
+             }
+             // FWD: src=[0..9], dst=[0..4] or [5..9] — use DST domain
+             DynamicValueAttrs const &input =
+                 get_only(invocation.inputs).second;
+             DynamicValueAttrs const &output =
+                 get_only(invocation.outputs).second;
+             return ctx.issue_copy(
+                 assert_unwrap(input.parallel_tensor_shape),
+                 tensor_instance_backing.backing.at(input).first,
+                 assert_unwrap(output.parallel_tensor_shape),
+                 tensor_instance_backing.backing.at(output).first,
+                 Realm::ProfilingRequestSet{},
+                 precondition,
+                 /*priority=*/0,
+                 /*redop_id=*/std::nullopt,
+                 /*exclusive=*/false,
+                 /*domain=*/CopyDomain::DST); // ← use dst index space
+           },
+           [&](CombineAttrs const &) {
+             if (has_task_type(invocation.node_attrs, DynamicTaskType::BWD)) {
+               // point-to-point copy after shard expansion
+               return issue_parallel_op_bwd_copy();
+             }
+             return issue_combine_fwd(); // forward
+           },
+           [&](ReductionAttrs const &) {
+             if (has_task_type(invocation.node_attrs, DynamicTaskType::BWD)) {
+               // broadcast copy after shard expansion
+               return issue_parallel_op_bwd_copy();
+             }
+             return issue_reduction_fwd(); // forward needs a sum reduction
+           },
            [&](auto const &) { return spawn_task(); },
        });
      },

diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc
index 790c1bd613..98ec711310 100644
--- a/lib/realm-execution/src/realm-execution/realm_context.cc
+++ b/lib/realm-execution/src/realm-execution/realm_context.cc
@@ -15,8 +15,28 @@
 #include "utils/nonnegative_int/nonnegative_int.h"
 #include "utils/one_to_many/one_to_many.h"
 #include "utils/positive_int/positive_int.h"
+#include <optional>
 
 namespace FlexFlow {
+template <int N>
+static Realm::Rect<N, int>
+    rect_from_dims_with_offset(TensorDims const &dims,
+                               std::vector<int> const &offsets) {
+  std::vector<int> values;
+  for (positive_int const &v : dims.ff_ordered) {
+    values.push_back(v.int_from_positive_int());
+  }
+  ASSERT((int)values.size() == N);
+  ASSERT((int)offsets.size() == N);
+
+  std::vector<int> lo(N), hi(N);
+  for (int i = 0; i < N; i++) {
+    lo[i] = offsets[i];
+    hi[i] = offsets[i] + values[i] - 1;
+  }
+  return Realm::Rect<N, int>{Realm::Point<N, int>{lo.data()},
+                             Realm::Point<N, int>{hi.data()}};
+}
 
 RealmContext::RealmContext(Realm::Processor processor)
     : processor(processor),
@@ -161,7 +181,10 @@ Realm::Event
        Realm::RegionInstance dst_inst,
        Realm::ProfilingRequestSet const &requests,
        Realm::Event wait_on,
-        int priority) {
+        int priority,
+        std::optional<Realm::ReductionOpID> redop_id,
+        bool exclusive,
+        CopyDomain domain) {
  TensorShape src_piece_shape = get_piece_shape(src_shape);
  TensorShape dst_piece_shape = get_piece_shape(dst_shape);
  ASSERT(src_piece_shape == dst_piece_shape); // For now, assume they match
@@ -183,36 +206,45 @@ Realm::Event
          size_of_datatype(src_piece_shape.data_type).int_from_positive_int()),
      /*subfield_offset=*/0);

+  // set the reduction op on the dst field if provided
+  if (redop_id.has_value()) {
+    dst_field.set_redop(redop_id.value(), /*is_fold=*/false, exclusive);
+  }
+
+  // select which instance's index space to use as the copy domain
+  Realm::RegionInstance const
domain_inst = + (domain == CopyDomain::DST) ? dst_inst : src_inst; + Realm::Event result; switch (src_piece_shape.dims.ff_ordered.num_dims()) { #if REALM_MAX_DIM >= 1 case 1: - result = ispace_from_dims<1>(src_piece_shape.dims) - .copy({src_field}, {dst_field}, requests, wait_on, priority); + result = domain_inst.get_indexspace<1, int>().copy( + {src_field}, {dst_field}, requests, wait_on, priority); break; #endif #if REALM_MAX_DIM >= 2 case 2: - result = ispace_from_dims<2>(src_piece_shape.dims) - .copy({src_field}, {dst_field}, requests, wait_on, priority); + result = domain_inst.get_indexspace<2, int>().copy( + {src_field}, {dst_field}, requests, wait_on, priority); break; #endif #if REALM_MAX_DIM >= 3 case 3: - result = ispace_from_dims<3>(src_piece_shape.dims) - .copy({src_field}, {dst_field}, requests, wait_on, priority); + result = domain_inst.get_indexspace<3, int>().copy( + {src_field}, {dst_field}, requests, wait_on, priority); break; #endif #if REALM_MAX_DIM >= 4 case 4: - result = ispace_from_dims<4>(src_piece_shape.dims) - .copy({src_field}, {dst_field}, requests, wait_on, priority); + result = domain_inst.get_indexspace<4, int>().copy( + {src_field}, {dst_field}, requests, wait_on, priority); break; #endif #if REALM_MAX_DIM >= 5 case 5: - result = ispace_from_dims<5>(src_piece_shape.dims) - .copy({src_field}, {dst_field}, requests, wait_on, priority); + result = domain_inst.get_indexspace<5, int>().copy( + {src_field}, {dst_field}, requests, wait_on, priority); break; #endif default: @@ -223,7 +255,6 @@ Realm::Event this->outstanding_events.push_back(result); return result; } - std::pair RealmContext::create_instance(Realm::Memory memory, TensorShape const &shape, @@ -303,6 +334,86 @@ std::pair return std::pair{inst, ready}; } +std::pair + RealmContext::create_instance_with_offset( + Realm::Memory memory, + TensorShape const &shape, + std::vector const &offsets, + Realm::ProfilingRequestSet const &prs, + Realm::Event wait_on) { + std::vector field_sizes{static_cast( + size_of_datatype(shape.data_type).int_from_positive_int())}; + Realm::RegionInstance inst; + Realm::Event ready; + switch (shape.dims.ff_ordered.num_dims()) { +#if REALM_MAX_DIM >= 1 + case 1: + ready = Realm::RegionInstance::create_instance( + inst, + memory, + rect_from_dims_with_offset<1>(shape.dims, offsets), + field_sizes, + 0 /*SOA*/, + prs, + wait_on); + break; +#endif +#if REALM_MAX_DIM >= 2 + case 2: + ready = Realm::RegionInstance::create_instance( + inst, + memory, + rect_from_dims_with_offset<2>(shape.dims, offsets), + field_sizes, + 0 /*SOA*/, + prs, + wait_on); + break; +#endif +#if REALM_MAX_DIM >= 3 + case 3: + ready = Realm::RegionInstance::create_instance( + inst, + memory, + rect_from_dims_with_offset<3>(shape.dims, offsets), + field_sizes, + 0 /*SOA*/, + prs, + wait_on); + break; +#endif +#if REALM_MAX_DIM >= 4 + case 4: + ready = Realm::RegionInstance::create_instance( + inst, + memory, + rect_from_dims_with_offset<4>(shape.dims, offsets), + field_sizes, + 0 /*SOA*/, + prs, + wait_on); + break; +#endif +#if REALM_MAX_DIM >= 5 + case 5: + ready = Realm::RegionInstance::create_instance( + inst, + memory, + rect_from_dims_with_offset<5>(shape.dims, offsets), + field_sizes, + 0 /*SOA*/, + prs, + wait_on); + break; +#endif + default: + PANIC("TensorShape dims greater than REALM_MAX_DIM: {}", + shape.dims.ff_ordered.num_dims()); + } + this->outstanding_events.push_back(ready); + return {inst, ready}; +} + Realm::Event RealmContext::get_outstanding_events() { Realm::Event result = 
      this->merge_outstanding_events();
  this->outstanding_events.push_back(result);

diff --git a/lib/realm-execution/src/realm-execution/tasks/cuda/realm_reduction.cu b/lib/realm-execution/src/realm-execution/tasks/cuda/realm_reduction.cu
new file mode 100644
index 0000000000..7755490128
--- /dev/null
+++ b/lib/realm-execution/src/realm-execution/tasks/cuda/realm_reduction.cu
@@ -0,0 +1,31 @@
+// realm_reduction.cu — host-side registration of the Realm sum reduction ops
+#include "realm-execution/tasks/realm_reduction.h"
+#include <realm.h>
+#include <realm/cuda/cuda_module.h>
+#include <cassert>
+
+namespace FlexFlow {
+
+void register_reductions() {
+  ::Realm::Runtime rt = ::Realm::Runtime::get_runtime();
+
+  // register SumReductionFloat with CUDA kernels
+  {
+    ::Realm::ReductionOpUntyped *redop =
+        ::Realm::ReductionOpUntyped::create_reduction_op<SumReductionFloat>();
+    ::Realm::Cuda::add_cuda_redop_kernels(redop);
+    bool ok = rt.register_reduction(REDOP_SUM_FLOAT, redop);
+    assert(ok && "Failed to register SumReductionFloat");
+  }
+
+  // register SumReductionDouble with CUDA kernels
+  {
+    ::Realm::ReductionOpUntyped *redop =
+        ::Realm::ReductionOpUntyped::create_reduction_op<SumReductionDouble>();
+    ::Realm::Cuda::add_cuda_redop_kernels(redop);
+    bool ok = rt.register_reduction(REDOP_SUM_DOUBLE, redop);
+    assert(ok && "Failed to register SumReductionDouble");
+  }
+}
+
+} // namespace FlexFlow

diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/per_device_op_state_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/per_device_op_state_init_task.cc
index 753fccf74b..0ea51810e4 100644
--- a/lib/realm-execution/src/realm-execution/tasks/impl/per_device_op_state_init_task.cc
+++ b/lib/realm-execution/src/realm-execution/tasks/impl/per_device_op_state_init_task.cc
@@ -66,11 +66,17 @@ void per_device_op_state_init_task_body(void const *args,
      result_state, ctx.get_current_device_idx())};
  DeviceSpecificPtr result_device_specific{
      ctx.get_current_device_idx(), result_state_ptr};
-  spawn_per_device_op_state_init_return_task(ctx,
-                                             task_args.origin_proc,
-                                             result_device_specific,
-                                             task_args.origin_result_ptr,
-                                             Realm::Event::NO_EVENT);
+
+  // write the result directly instead of spawning a return task
+  // NOTE: the direct write assumes a single-node shared address space;
+  // for multi-node, replace it with a UserEvent trigger pattern
+  *task_args.origin_result_ptr = result_device_specific;
 }

diff --git a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc
index e7a8948f8d..df004146d4 100644
--- a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc
+++ b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc
@@ -5,6 +5,7 @@
 #include "realm-execution/tasks/impl/op_task.h"
 #include "realm-execution/tasks/impl/per_device_op_state_init_return_task.h"
 #include "realm-execution/tasks/impl/per_device_op_state_init_task.h"
+#include "realm-execution/tasks/realm_reduction.h"
 #include "realm-execution/tasks/task_id_t.h"
 #include "utils/exception.h"
 
@@ -33,6 +34,7 @@ Realm::Event register_task(Realm::Processor::Kind target_kind,
 Realm::Event register_all_tasks() {
  std::vector<Realm::Event> pending_registrations;
+  register_reductions();

  std::vector<task_id_t> init_task_ids = {
      // Init tasks
      task_id_t::BATCHNORM_INIT_TASK_ID,

diff --git a/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc
b/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc index dd4b0a66ca..e55eebaabd 100644 --- a/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc +++ b/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc @@ -36,7 +36,7 @@ std::optional [](BatchNormAttrs const &) { return task_id_t::BATCHNORM_INIT_TASK_ID; }, [](BroadcastAttrs const &) { return std::nullopt; }, [](CastAttrs const &) { return std::nullopt; }, - [](CombineAttrs const &attrs) { return task_id_t::COMBINE_INIT_TASK_ID; }, + [](CombineAttrs const &attrs) { return std::nullopt; }, [](ConcatAttrs const &) { return std::nullopt; }, [](Conv2DAttrs const &) { return task_id_t::CONV2D_INIT_TASK_ID; }, [](DropoutAttrs const &) { return task_id_t::DROPOUT_INIT_TASK_ID; }, @@ -58,15 +58,9 @@ std::optional [](NoopAttrs const &) { return std::nullopt; }, [](Pool2DAttrs const &) { return task_id_t::POOL2D_INIT_TASK_ID; }, [](ReduceAttrs const &) { return task_id_t::REDUCE_INIT_TASK_ID; }, - [](ReductionAttrs const &attrs) { - return task_id_t::REDUCTION_INIT_TASK_ID; - }, - [](RepartitionAttrs const &attrs) { - return task_id_t::REPARTITION_INIT_TASK_ID; - }, - [](ReplicateAttrs const &attrs) { - return task_id_t::REPLICATE_INIT_TASK_ID; - }, + [](ReductionAttrs const &attrs) { return std::nullopt; }, + [](RepartitionAttrs const &attrs) { return std::nullopt; }, + [](ReplicateAttrs const &attrs) { return std::nullopt; }, [](ReshapeAttrs const &) { return std::nullopt; }, [](ReverseAttrs const &) { return std::nullopt; }, [](SoftmaxAttrs const &) { return task_id_t::SOFTMAX_INIT_TASK_ID; }, @@ -87,7 +81,7 @@ std::optional [](BatchNormAttrs const &) { return task_id_t::BATCHNORM_FWD_TASK_ID; }, [](BroadcastAttrs const &) { return task_id_t::BROADCAST_FWD_TASK_ID; }, [](CastAttrs const &) { return task_id_t::CAST_FWD_TASK_ID; }, - [](CombineAttrs const &attrs) { return task_id_t::COMBINE_FWD_TASK_ID; }, + [](CombineAttrs const &attrs) { return std::nullopt; }, [](ConcatAttrs const &) { return task_id_t::CONCAT_FWD_TASK_ID; }, [](Conv2DAttrs const &) { return task_id_t::CONV2D_FWD_TASK_ID; }, [](DropoutAttrs const &) { return task_id_t::DROPOUT_FWD_TASK_ID; }, @@ -109,15 +103,9 @@ std::optional [](NoopAttrs const &) { return std::nullopt; }, [](Pool2DAttrs const &) { return task_id_t::POOL2D_FWD_TASK_ID; }, [](ReduceAttrs const &) { return task_id_t::REDUCE_FWD_TASK_ID; }, - [](ReductionAttrs const &attrs) { - return task_id_t::REDUCTION_FWD_TASK_ID; - }, - [](RepartitionAttrs const &attrs) { - return task_id_t::REPARTITION_FWD_TASK_ID; - }, - [](ReplicateAttrs const &attrs) { - return task_id_t::REPLICATE_FWD_TASK_ID; - }, + [](ReductionAttrs const &attrs) { return std::nullopt; }, + [](RepartitionAttrs const &attrs) { return std::nullopt; }, + [](ReplicateAttrs const &attrs) { return std::nullopt; }, [](ReshapeAttrs const &) { return task_id_t::RESHAPE_FWD_TASK_ID; }, [](ReverseAttrs const &) { return task_id_t::REVERSE_FWD_TASK_ID; }, [](SoftmaxAttrs const &) { return task_id_t::SOFTMAX_FWD_TASK_ID; }, @@ -138,7 +126,7 @@ std::optional [](BatchNormAttrs const &) { return task_id_t::BATCHNORM_BWD_TASK_ID; }, [](BroadcastAttrs const &) { return task_id_t::BROADCAST_BWD_TASK_ID; }, [](CastAttrs const &) { return task_id_t::CAST_BWD_TASK_ID; }, - [](CombineAttrs const &attrs) { return task_id_t::COMBINE_BWD_TASK_ID; }, + [](CombineAttrs const &attrs) { return std::nullopt; }, [](ConcatAttrs const &) { return task_id_t::CONCAT_BWD_TASK_ID; }, [](Conv2DAttrs const &) { return task_id_t::CONV2D_BWD_TASK_ID; }, [](DropoutAttrs 
const &) { return task_id_t::DROPOUT_BWD_TASK_ID; }, @@ -160,15 +148,9 @@ std::optional [](NoopAttrs const &) { return std::nullopt; }, [](Pool2DAttrs const &) { return task_id_t::POOL2D_BWD_TASK_ID; }, [](ReduceAttrs const &) { return task_id_t::REDUCE_BWD_TASK_ID; }, - [](ReductionAttrs const &attrs) { - return task_id_t::REDUCTION_BWD_TASK_ID; - }, - [](RepartitionAttrs const &attrs) { - return task_id_t::REPARTITION_BWD_TASK_ID; - }, - [](ReplicateAttrs const &attrs) { - return task_id_t::REPLICATE_BWD_TASK_ID; - }, + [](ReductionAttrs const &attrs) { return std::nullopt; }, + [](RepartitionAttrs const &attrs) { return std::nullopt; }, + [](ReplicateAttrs const &attrs) { return std::nullopt; }, [](ReshapeAttrs const &) { return task_id_t::RESHAPE_BWD_TASK_ID; }, [](ReverseAttrs const &) { return task_id_t::REVERSE_BWD_TASK_ID; }, [](SoftmaxAttrs const &) { return task_id_t::SOFTMAX_BWD_TASK_ID; }, diff --git a/lib/realm-execution/test/src/realm-execution/test_op_combine.cc b/lib/realm-execution/test/src/realm-execution/test_op_combine.cc new file mode 100644 index 0000000000..47e5ea8175 --- /dev/null +++ b/lib/realm-execution/test/src/realm-execution/test_op_combine.cc @@ -0,0 +1,352 @@ +#include "internal/realm_test_utils.h" +#include "kernels/allocation.h" +#include "kernels/compare_tensor_accessors.h" +#include "kernels/copy_tensor_accessor.h" +#include "kernels/format_accessor_contents.h" +#include "kernels/tensor_accessor_reductions.h" +#include "op-attrs/operator_task_space_to_operator_task_space_mapping.h" +#include "op-attrs/ops/combine.h" +#include "op-attrs/ops/element_unary.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/repartition.h" +#include "op-attrs/ops/replicate.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_shape.dtg.h" +#include "op-attrs/tensor_slot_name.dtg.h" +#include "pcg/device_type.dtg.h" +#include "pcg/machine_space_coordinate.dtg.h" +#include "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "realm-execution/distributed_ff_handle.h" +#include "realm-execution/dynamic_tensor_accessor_from_instance.h" +#include "realm-execution/pcg_instance.h" +#include "realm-execution/realm_context.h" +#include "realm-execution/realm_manager.h" +#include "task-spec/permissions.h" +#include "test/utils/doctest/check_kv.h" +#include "utils/containers/require_only_key.h" +#include + +namespace test { + +using namespace ::FlexFlow; +namespace Realm = ::FlexFlow::Realm; + +template +static ParallelLayerAttrs make_layer_attrs(T const &op_attrs) { + return ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{op_attrs}, + /*name=*/std::nullopt, + }; +}; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("RealmBackend e2e Training Combine Op (CPU Model Parallelism)") { + std::vector fake_args = + make_fake_realm_args(/*num_cpus=*/2_p, /*num_gpus=*/0_n); + int fake_argc = fake_args.size(); + char **fake_argv = fake_args.data(); + + RealmManager manager = RealmManager{&fake_argc, &fake_argv}; + ControllerTaskResult result = + manager.start_controller([](RealmContext &ctx) { + Allocator allocator = ctx.get_current_device_allocator(); + + positive_int batch_size = 10_p; + positive_int data_dim = 16_p; + + TensorShape 
input_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; + + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + // input layer + ParallelLayerAddedResult inputs_layer = + pcg_add_input_layer(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input = + require_only_key(inputs_layer.outputs, TensorSlotName::OUTPUT); + + // repartition along dim 0 with degree 2 + // needed so combine has a degree=2 sharded tensor to combine + RepartitionAttrs repartition_attrs{ + /*repartition_dim=*/ff_dim_t{nonnegative_int{0}}, + /*repartition_degree=*/2_p, + }; + ParallelLayerAddedResult repartition_operator = + add_parallel_layer(pcg, + make_layer_attrs(repartition_attrs), + {{TensorSlotName::INPUT, t_input}}, + /*weights=*/{}); + parallel_tensor_guid_t t_repartitioned = require_only_key( + repartition_operator.outputs, TensorSlotName::OUTPUT); + + // combine along dim 0 with degree 2 + CombineAttrs combine_attrs{ + /*combine_dim=*/ff_dim_t{nonnegative_int{0}}, + /*combine_degree=*/2_p, + }; + ParallelLayerAddedResult combine_operator = + add_parallel_layer(pcg, + make_layer_attrs(combine_attrs), + {{TensorSlotName::INPUT, t_repartitioned}}, + /*weights=*/{}); + parallel_tensor_guid_t t_combined = require_only_key( + combine_operator.outputs, TensorSlotName::OUTPUT); + + // relu consumer + ParallelLayerAddedResult relu_operator = + add_parallel_layer(pcg, + make_layer_attrs(make_relu_attrs()), + {{TensorSlotName::INPUT, t_combined}}, + /*weights=*/{}); + + MachineSpaceCoordinate cpu0{0_n, 0_n, DeviceType::CPU}; + MachineSpaceCoordinate cpu1{0_n, 1_n, DeviceType::CPU}; + + // input: one shard on cpu0 (not yet repartitioned) + ParallelTensorSpaceCoordinate tensor_coord0{ + 0_n, 0_n, FFOrdered{0_n, 0_n}}; + // after repartition: two shards along dim 0 + ParallelTensorSpaceCoordinate tensor_coord_shard0{ + 0_n, 0_n, FFOrdered{0_n, 0_n}}; + ParallelTensorSpaceCoordinate tensor_coord_shard1{ + 0_n, 0_n, FFOrdered{1_n, 0_n}}; + // after combine: one shard on cpu0 + ParallelTensorSpaceCoordinate tensor_coord_combined{ + 0_n, 0_n, FFOrdered{0_n, 0_n}}; + + MappedParallelComputationGraph mpcg{ + pcg, + { + // input: one shard on cpu0 + {inputs_layer.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + // repartition: OUTPUT only — no INPUT since all replicas + // read same source coord violating bidict uniqueness + {repartition_operator.parallel_layer, + MappedOperatorTaskGroup{{ + {cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord_shard0}, + }}}, + {cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord_shard1}, + }}}, + }}}, + // combine: two inputs → one output on cpu0 + {combine_operator.parallel_layer, + MappedOperatorTaskGroup{{ + {cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord_shard0}, + }}}, + {cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord_shard1}, + }}}, + }}}, + // relu: one shard on cpu0 + {relu_operator.parallel_layer, + MappedOperatorTaskGroup{{ + {cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord_combined}, + {TensorSlotName::OUTPUT, tensor_coord_combined}, + }}}, + }}}, + }}; + + OptimizerAttrs optimizer_attrs = + OptimizerAttrs{SGDOptimizerAttrs{/*lr=*/0.001, + /*momentum=*/0.9, + /*nesterov=*/false, + /*weight_decay=*/0.001}}; + + std::unordered_map + input_tensors; + + DistributedFfHandle 
device_handle = create_distributed_ff_handle( + ctx, + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true); + + PCGInstance pcg_instance = + create_pcg_instance(ctx, + mpcg, + optimizer_attrs, + std::nullopt, + input_tensors, + ProfilingSettings{0, 0}, + device_handle, + FFIterationConfig{1_p}); + + perform_all_passes_for_pcg_instance(pcg_instance, + ProfilingSettings{0, 0}, + device_handle, + FFIterationConfig{1_p}); + }); + result.wait(); + } +} + +TEST_SUITE(FF_CUDA_TEST_SUITE) { + TEST_CASE("RealmBackend e2e Training Combine Op (GPU Model Parallelism)") { + std::vector fake_args = + make_fake_realm_args(/*num_cpus=*/1_p, /*num_gpus=*/2_n); + int fake_argc = fake_args.size(); + char **fake_argv = fake_args.data(); + + RealmManager manager = RealmManager{&fake_argc, &fake_argv}; + + ControllerTaskResult result = + manager.start_controller([](RealmContext &ctx) { + Allocator allocator = ctx.get_current_device_allocator(); + + positive_int batch_size = 10_p; + positive_int data_dim = 16_p; + + TensorShape input_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; + + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + // input layer + ParallelLayerAddedResult inputs_layer = + pcg_add_input_layer(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input = + require_only_key(inputs_layer.outputs, TensorSlotName::OUTPUT); + + // repartition along dim 0 with degree 2 + // needed so combine has a degree=2 sharded tensor to combine + RepartitionAttrs repartition_attrs{ + /*repartition_dim=*/ff_dim_t{nonnegative_int{0}}, + /*repartition_degree=*/2_p, + }; + ParallelLayerAddedResult repartition_operator = + add_parallel_layer(pcg, + make_layer_attrs(repartition_attrs), + {{TensorSlotName::INPUT, t_input}}, + /*weights=*/{}); + parallel_tensor_guid_t t_repartitioned = require_only_key( + repartition_operator.outputs, TensorSlotName::OUTPUT); + + // combine along dim 0 with degree 2 + CombineAttrs combine_attrs{ + /*combine_dim=*/ff_dim_t{nonnegative_int{0}}, + /*combine_degree=*/2_p, + }; + ParallelLayerAddedResult combine_operator = + add_parallel_layer(pcg, + make_layer_attrs(combine_attrs), + {{TensorSlotName::INPUT, t_repartitioned}}, + /*weights=*/{}); + parallel_tensor_guid_t t_combined = require_only_key( + combine_operator.outputs, TensorSlotName::OUTPUT); + + // relu consumer + ParallelLayerAddedResult relu_operator = + add_parallel_layer(pcg, + make_layer_attrs(make_relu_attrs()), + {{TensorSlotName::INPUT, t_combined}}, + /*weights=*/{}); + + MachineSpaceCoordinate gpu0{0_n, 0_n, DeviceType::GPU}; + MachineSpaceCoordinate gpu1{0_n, 1_n, DeviceType::GPU}; + + // input: one shard on gpu0 (not yet repartitioned) + ParallelTensorSpaceCoordinate tensor_coord0{ + 0_n, 0_n, FFOrdered{0_n, 0_n}}; + // after repartition: two shards along dim 0 + ParallelTensorSpaceCoordinate tensor_coord_shard0{ + 0_n, 0_n, FFOrdered{0_n, 0_n}}; + ParallelTensorSpaceCoordinate tensor_coord_shard1{ + 0_n, 0_n, FFOrdered{1_n, 0_n}}; + // after combine: one shard on gpu0 + ParallelTensorSpaceCoordinate tensor_coord_combined{ + 0_n, 0_n, FFOrdered{0_n, 0_n}}; + + MappedParallelComputationGraph mpcg{ + pcg, + { + // input: one shard on gpu0 + {inputs_layer.parallel_layer, + MappedOperatorTaskGroup{ + {{gpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + // repartition: OUTPUT only — no INPUT since all replicas + // read same source coord violating bidict uniqueness + 
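+                 // (the FWD repartition itself is issued in pcg_instance.cc
+                 // as a single copy over the dst index space via the
+                 // CopyDomain::DST RepartitionAttrs case, so only the
+                 // OUTPUT shard coords are bound here)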
{repartition_operator.parallel_layer, + MappedOperatorTaskGroup{{ + {gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord_shard0}, + }}}, + {gpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord_shard1}, + }}}, + }}}, + // combine: two inputs → one output on gpu0 + {combine_operator.parallel_layer, + MappedOperatorTaskGroup{{ + {gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord_shard0}, + }}}, + {gpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord_shard1}, + }}}, + }}}, + // relu: one shard on gpu0 + {relu_operator.parallel_layer, + MappedOperatorTaskGroup{{ + {gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord_combined}, + {TensorSlotName::OUTPUT, tensor_coord_combined}, + }}}, + }}}, + }}; + + OptimizerAttrs optimizer_attrs = + OptimizerAttrs{SGDOptimizerAttrs{/*lr=*/0.001, + /*momentum=*/0.9, + /*nesterov=*/false, + /*weight_decay=*/0.001}}; + + std::unordered_map + input_tensors; + + DistributedFfHandle device_handle = create_distributed_ff_handle( + ctx, + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true); + + PCGInstance pcg_instance = + create_pcg_instance(ctx, + mpcg, + optimizer_attrs, + std::nullopt, + input_tensors, + ProfilingSettings{0, 0}, + device_handle, + FFIterationConfig{1_p}); + + perform_all_passes_for_pcg_instance(pcg_instance, + ProfilingSettings{0, 0}, + device_handle, + FFIterationConfig{1_p}); + }); + result.wait(); + } +} +} // namespace test diff --git a/lib/realm-execution/test/src/realm-execution/test_op_reduce.cc b/lib/realm-execution/test/src/realm-execution/test_op_reduce.cc new file mode 100644 index 0000000000..f472ccb96b --- /dev/null +++ b/lib/realm-execution/test/src/realm-execution/test_op_reduce.cc @@ -0,0 +1,530 @@ +#include "internal/realm_test_utils.h" +#include "kernels/allocation.h" +#include "kernels/compare_tensor_accessors.h" +#include "kernels/copy_tensor_accessor.h" +#include "kernels/format_accessor_contents.h" +#include "kernels/tensor_accessor_reductions.h" +#include "op-attrs/operator_task_space_to_operator_task_space_mapping.h" +#include "op-attrs/ops/element_unary.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/reduction.h" +#include "op-attrs/ops/repartition.h" +#include "op-attrs/ops/replicate.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_shape.dtg.h" +#include "op-attrs/tensor_slot_name.dtg.h" +#include "pcg/device_type.dtg.h" +#include "pcg/machine_space_coordinate.dtg.h" +#include "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "realm-execution/distributed_ff_handle.h" +#include "realm-execution/dynamic_tensor_accessor_from_instance.h" +#include "realm-execution/pcg_instance.h" +#include "realm-execution/realm_context.h" +#include "realm-execution/realm_manager.h" +#include "task-spec/permissions.h" +#include "test/utils/doctest/check_kv.h" +#include "utils/containers/require_only_key.h" +#include + +namespace test { + +using namespace ::FlexFlow; +namespace Realm = ::FlexFlow::Realm; + +template +static ParallelLayerAttrs make_layer_attrs(T const &op_attrs) { + return ParallelLayerAttrs{ + 
/*op_attrs=*/PCGOperatorAttrs{op_attrs}, + /*name=*/std::nullopt, + }; +}; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("RealmBackend e2e Training Reduction Op (CPU Model Parallelism)") { + std::vector fake_args = + make_fake_realm_args(/*num_cpus=*/2_p, /*num_gpus=*/0_n); + int fake_argc = fake_args.size(); + char **fake_argv = fake_args.data(); + + RealmManager manager = RealmManager{&fake_argc, &fake_argv}; + ControllerTaskResult result = manager.start_controller([](RealmContext + &ctx) { + Allocator allocator = ctx.get_current_device_allocator(); + + positive_int batch_size = 4_p; + positive_int in_channels = 8_p; + positive_int out_channels = 4_p; + + TensorShape input_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, in_channels}}, DataType::FLOAT}; + + TensorShape weight_tensor_shape = TensorShape{ + TensorDims{FFOrdered{out_channels, in_channels}}, DataType::FLOAT}; + + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + // input layer + ParallelLayerAddedResult inputs_layer = + pcg_add_input_layer(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input = + require_only_key(inputs_layer.outputs, TensorSlotName::OUTPUT); + + // weight layer + ParallelLayerAddedResult weights_layer = + pcg_add_input_layer(pcg, weight_tensor_shape); + parallel_tensor_guid_t t_weight = + require_only_key(weights_layer.outputs, TensorSlotName::OUTPUT); + + // repartition input along feature dim (dim 1) with degree 2 + RepartitionAttrs input_repartition_attrs{ + /*repartition_dim=*/ff_dim_t{nonnegative_int{1}}, + /*repartition_degree=*/2_p, + }; + ParallelLayerAddedResult input_repartition_operator = + add_parallel_layer(pcg, + make_layer_attrs(input_repartition_attrs), + {{TensorSlotName::INPUT, t_input}}, + /*weights=*/{}); + parallel_tensor_guid_t t_input_repartitioned = require_only_key( + input_repartition_operator.outputs, TensorSlotName::OUTPUT); + + // repartition weight along feature dim (dim 1) with degree 2 + // to match the repartitioned input + RepartitionAttrs weight_repartition_attrs{ + /*repartition_dim=*/ff_dim_t{nonnegative_int{1}}, + /*repartition_degree=*/2_p, + }; + ParallelLayerAddedResult weight_repartition_operator = + add_parallel_layer(pcg, + make_layer_attrs(weight_repartition_attrs), + {{TensorSlotName::INPUT, t_weight}}, + /*weights=*/{}); + parallel_tensor_guid_t t_weight_repartitioned = require_only_key( + weight_repartition_operator.outputs, TensorSlotName::OUTPUT); + + // linear with repartitioned input and weight + // shard_dim[-1]=2 → sum_degree=2 output + ParallelLayerAddedResult linear_operator = add_parallel_layer( + pcg, + ParallelLayerAttrs{PCGOperatorAttrs{LinearAttrs{out_channels, + /*use_bias=*/false, + DataType::FLOAT, + Activation::RELU, + std::nullopt}}, + std::nullopt}, + /*inputs=*/ + { + {TensorSlotName::INPUT, t_input_repartitioned}, + }, + /*weights=*/ + { + {TensorSlotName::WEIGHT, t_weight_repartitioned}, + }); + parallel_tensor_guid_t t_linear = + require_only_key(linear_operator.outputs, TensorSlotName::OUTPUT); + + // reduction degree=2 — sums partial results + ReductionAttrs reduction_attrs{/*reduction_degree=*/2_p}; + ParallelLayerAddedResult reduction_operator = + add_parallel_layer(pcg, + make_layer_attrs(reduction_attrs), + {{TensorSlotName::INPUT, t_linear}}, + /*weights=*/{}); + parallel_tensor_guid_t t_reduced = + require_only_key(reduction_operator.outputs, TensorSlotName::OUTPUT); + + // relu consumer + ParallelLayerAddedResult relu_operator = + add_parallel_layer(pcg, + make_layer_attrs(make_relu_attrs()), 
+ {{TensorSlotName::INPUT, t_reduced}}, + /*weights=*/{}); + + MachineSpaceCoordinate cpu0{0_n, 0_n, DeviceType::CPU}; + MachineSpaceCoordinate cpu1{0_n, 1_n, DeviceType::CPU}; + + // input: unsharded on cpu0 — 2 shard dims + ParallelTensorSpaceCoordinate input_coord{0_n, 0_n, FFOrdered{0_n, 0_n}}; + + // weight: unsharded on cpu0 — 2 shard dims + ParallelTensorSpaceCoordinate weight_coord{0_n, 0_n, FFOrdered{0_n, 0_n}}; + + // after repartition: input sharded along feature dim + ParallelTensorSpaceCoordinate input_repartitioned_coord_0{ + 0_n, 0_n, FFOrdered{0_n, 0_n}}; + ParallelTensorSpaceCoordinate input_repartitioned_coord_1{ + 0_n, 0_n, FFOrdered{0_n, 1_n}}; + + // after repartition: weight sharded along feature dim + ParallelTensorSpaceCoordinate weight_repartitioned_coord_0{ + 0_n, 0_n, FFOrdered{0_n, 0_n}}; + ParallelTensorSpaceCoordinate weight_repartitioned_coord_1{ + 0_n, 0_n, FFOrdered{0_n, 1_n}}; + + // linear output: partial sums — sum_component distinguishes them + // output has 2 shard dims [{4,1},{4,1}] + ParallelTensorSpaceCoordinate linear_coord_0{ + 0_n, 0_n, FFOrdered{0_n, 0_n}}; + ParallelTensorSpaceCoordinate linear_coord_1{ + 1_n, 0_n, FFOrdered{0_n, 0_n}}; + + // reduced output: fully reduced on cpu0 + ParallelTensorSpaceCoordinate reduced_coord{ + 0_n, 0_n, FFOrdered{0_n, 0_n}}; + + MappedParallelComputationGraph mpcg{ + pcg, + { + // input: unsharded on cpu0 + {inputs_layer.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, input_coord}}}}}}}, + // weight: unsharded on cpu0 + {weights_layer.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, weight_coord}}}}}}}, + // input repartition: OUTPUT only + {input_repartition_operator.parallel_layer, + MappedOperatorTaskGroup{{ + {cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, input_repartitioned_coord_0}, + }}}, + {cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, input_repartitioned_coord_1}, + }}}, + }}}, + // weight repartition: OUTPUT only + {weight_repartition_operator.parallel_layer, + MappedOperatorTaskGroup{{ + {cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, weight_repartitioned_coord_0}, + }}}, + {cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, weight_repartitioned_coord_1}, + }}}, + }}}, + // linear: INPUT + WEIGHT + OUTPUT per device + {linear_operator.parallel_layer, + MappedOperatorTaskGroup{{ + {cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, input_repartitioned_coord_0}, + {TensorSlotName::WEIGHT, weight_repartitioned_coord_0}, + {TensorSlotName::OUTPUT, linear_coord_0}, + }}}, + {cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, input_repartitioned_coord_1}, + {TensorSlotName::WEIGHT, weight_repartitioned_coord_1}, + {TensorSlotName::OUTPUT, linear_coord_1}, + }}}, + }}}, + // reduction: INPUT only — OUTPUT coords not distinct + {reduction_operator.parallel_layer, + MappedOperatorTaskGroup{{ + {cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, linear_coord_0}, + }}}, + {cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, linear_coord_1}, + }}}, + }}}, + // relu: on cpu0 only + {relu_operator.parallel_layer, + MappedOperatorTaskGroup{{ + {cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, reduced_coord}, + {TensorSlotName::OUTPUT, reduced_coord}, + }}}, + }}}, + }}; + + OptimizerAttrs optimizer_attrs = + 
OptimizerAttrs{SGDOptimizerAttrs{/*lr=*/0.001, + /*momentum=*/0.9, + /*nesterov=*/false, + /*weight_decay=*/0.001}}; + + std::unordered_map + input_tensors; + + DistributedFfHandle device_handle = + create_distributed_ff_handle(ctx, + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true); + + PCGInstance pcg_instance = create_pcg_instance(ctx, + mpcg, + optimizer_attrs, + std::nullopt, + input_tensors, + ProfilingSettings{0, 0}, + device_handle, + FFIterationConfig{1_p}); + + perform_all_passes_for_pcg_instance(pcg_instance, + ProfilingSettings{0, 0}, + device_handle, + FFIterationConfig{1_p}); + }); + result.wait(); + } +} +TEST_SUITE(FF_CUDA_TEST_SUITE) { + TEST_CASE("RealmBackend e2e Training Reduction Op (GPU Model Parallelism)") { + std::vector fake_args = + make_fake_realm_args(/*num_cpus=*/1_p, /*num_gpus=*/2_n); + int fake_argc = fake_args.size(); + char **fake_argv = fake_args.data(); + + RealmManager manager = RealmManager{&fake_argc, &fake_argv}; + ControllerTaskResult result = manager.start_controller([](RealmContext + &ctx) { + Allocator allocator = ctx.get_current_device_allocator(); + + positive_int batch_size = 4_p; + positive_int in_channels = 8_p; + positive_int out_channels = 4_p; + + TensorShape input_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, in_channels}}, DataType::FLOAT}; + + TensorShape weight_tensor_shape = TensorShape{ + TensorDims{FFOrdered{out_channels, in_channels}}, DataType::FLOAT}; + + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + // input layer + ParallelLayerAddedResult inputs_layer = + pcg_add_input_layer(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input = + require_only_key(inputs_layer.outputs, TensorSlotName::OUTPUT); + + // weight layer + ParallelLayerAddedResult weights_layer = + pcg_add_input_layer(pcg, weight_tensor_shape); + parallel_tensor_guid_t t_weight = + require_only_key(weights_layer.outputs, TensorSlotName::OUTPUT); + + // repartition input along feature dim (dim 1) with degree 2 + RepartitionAttrs input_repartition_attrs{ + /*repartition_dim=*/ff_dim_t{nonnegative_int{1}}, + /*repartition_degree=*/2_p, + }; + ParallelLayerAddedResult input_repartition_operator = + add_parallel_layer(pcg, + make_layer_attrs(input_repartition_attrs), + {{TensorSlotName::INPUT, t_input}}, + /*weights=*/{}); + parallel_tensor_guid_t t_input_repartitioned = require_only_key( + input_repartition_operator.outputs, TensorSlotName::OUTPUT); + + // repartition weight along feature dim (dim 1) with degree 2 + // to match the repartitioned input + RepartitionAttrs weight_repartition_attrs{ + /*repartition_dim=*/ff_dim_t{nonnegative_int{1}}, + /*repartition_degree=*/2_p, + }; + ParallelLayerAddedResult weight_repartition_operator = + add_parallel_layer(pcg, + make_layer_attrs(weight_repartition_attrs), + {{TensorSlotName::INPUT, t_weight}}, + /*weights=*/{}); + parallel_tensor_guid_t t_weight_repartitioned = require_only_key( + weight_repartition_operator.outputs, TensorSlotName::OUTPUT); + + // linear with repartitioned input and weight + // shard_dim[-1]=2 → sum_degree=2 output + ParallelLayerAddedResult linear_operator = add_parallel_layer( + pcg, + ParallelLayerAttrs{PCGOperatorAttrs{LinearAttrs{out_channels, + /*use_bias=*/false, + DataType::FLOAT, + Activation::RELU, + std::nullopt}}, + std::nullopt}, + /*inputs=*/ + { + {TensorSlotName::INPUT, t_input_repartitioned}, + }, + /*weights=*/ + { + {TensorSlotName::WEIGHT, t_weight_repartitioned}, + }); + parallel_tensor_guid_t t_linear = + 
require_only_key(linear_operator.outputs, TensorSlotName::OUTPUT); + + // reduction degree=2 — sums partial results + ReductionAttrs reduction_attrs{/*reduction_degree=*/2_p}; + ParallelLayerAddedResult reduction_operator = + add_parallel_layer(pcg, + make_layer_attrs(reduction_attrs), + {{TensorSlotName::INPUT, t_linear}}, + /*weights=*/{}); + parallel_tensor_guid_t t_reduced = + require_only_key(reduction_operator.outputs, TensorSlotName::OUTPUT); + + // relu consumer + ParallelLayerAddedResult relu_operator = + add_parallel_layer(pcg, + make_layer_attrs(make_relu_attrs()), + {{TensorSlotName::INPUT, t_reduced}}, + /*weights=*/{}); + + MachineSpaceCoordinate gpu0{0_n, 0_n, DeviceType::GPU}; + MachineSpaceCoordinate gpu1{0_n, 1_n, DeviceType::GPU}; + + // input: unsharded on gpu0 — 2 shard dims + ParallelTensorSpaceCoordinate input_coord{0_n, 0_n, FFOrdered{0_n, 0_n}}; + + // weight: unsharded on gpu0 — 2 shard dims + ParallelTensorSpaceCoordinate weight_coord{0_n, 0_n, FFOrdered{0_n, 0_n}}; + + // after repartition: input sharded along feature dim + ParallelTensorSpaceCoordinate input_repartitioned_coord_0{ + 0_n, 0_n, FFOrdered{0_n, 0_n}}; + ParallelTensorSpaceCoordinate input_repartitioned_coord_1{ + 0_n, 0_n, FFOrdered{0_n, 1_n}}; + + // after repartition: weight sharded along feature dim + ParallelTensorSpaceCoordinate weight_repartitioned_coord_0{ + 0_n, 0_n, FFOrdered{0_n, 0_n}}; + ParallelTensorSpaceCoordinate weight_repartitioned_coord_1{ + 0_n, 0_n, FFOrdered{0_n, 1_n}}; + + // linear output: partial sums — sum_component distinguishes them + // output has 2 shard dims [{4,1},{4,1}] + ParallelTensorSpaceCoordinate linear_coord_0{ + 0_n, 0_n, FFOrdered{0_n, 0_n}}; + ParallelTensorSpaceCoordinate linear_coord_1{ + 1_n, 0_n, FFOrdered{0_n, 0_n}}; + + // reduced output: fully reduced on gpu0 + ParallelTensorSpaceCoordinate reduced_coord{ + 0_n, 0_n, FFOrdered{0_n, 0_n}}; + + MappedParallelComputationGraph mpcg{ + pcg, + { + // input: unsharded on gpu0 + {inputs_layer.parallel_layer, + MappedOperatorTaskGroup{ + {{gpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, input_coord}}}}}}}, + // weight: unsharded on gpu0 + {weights_layer.parallel_layer, + MappedOperatorTaskGroup{ + {{gpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, weight_coord}}}}}}}, + // input repartition: OUTPUT only + {input_repartition_operator.parallel_layer, + MappedOperatorTaskGroup{{ + {gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, input_repartitioned_coord_0}, + }}}, + {gpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, input_repartitioned_coord_1}, + }}}, + }}}, + // weight repartition: OUTPUT only + {weight_repartition_operator.parallel_layer, + MappedOperatorTaskGroup{{ + {gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, weight_repartitioned_coord_0}, + }}}, + {gpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, weight_repartitioned_coord_1}, + }}}, + }}}, + // linear: INPUT + WEIGHT + OUTPUT per device + {linear_operator.parallel_layer, + MappedOperatorTaskGroup{{ + {gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, input_repartitioned_coord_0}, + {TensorSlotName::WEIGHT, weight_repartitioned_coord_0}, + {TensorSlotName::OUTPUT, linear_coord_0}, + }}}, + {gpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, input_repartitioned_coord_1}, + {TensorSlotName::WEIGHT, weight_repartitioned_coord_1}, + {TensorSlotName::OUTPUT, linear_coord_1}, + }}}, + }}}, + // reduction: INPUT 
only — OUTPUT coords not distinct + {reduction_operator.parallel_layer, + MappedOperatorTaskGroup{{ + {gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, linear_coord_0}, + }}}, + {gpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, linear_coord_1}, + }}}, + }}}, + // relu: on gpu0 only + {relu_operator.parallel_layer, + MappedOperatorTaskGroup{{ + {gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, reduced_coord}, + {TensorSlotName::OUTPUT, reduced_coord}, + }}}, + }}}, + }}; + + OptimizerAttrs optimizer_attrs = + OptimizerAttrs{SGDOptimizerAttrs{/*lr=*/0.001, + /*momentum=*/0.9, + /*nesterov=*/false, + /*weight_decay=*/0.001}}; + + std::unordered_map + input_tensors; + + DistributedFfHandle device_handle = + create_distributed_ff_handle(ctx, + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true); + + PCGInstance pcg_instance = create_pcg_instance(ctx, + mpcg, + optimizer_attrs, + std::nullopt, + input_tensors, + ProfilingSettings{0, 0}, + device_handle, + FFIterationConfig{1_p}); + + perform_all_passes_for_pcg_instance(pcg_instance, + ProfilingSettings{0, 0}, + device_handle, + FFIterationConfig{1_p}); + }); + result.wait(); + } +} +} // namespace test diff --git a/lib/realm-execution/test/src/realm-execution/test_op_repartition.cc b/lib/realm-execution/test/src/realm-execution/test_op_repartition.cc new file mode 100644 index 0000000000..5974becae0 --- /dev/null +++ b/lib/realm-execution/test/src/realm-execution/test_op_repartition.cc @@ -0,0 +1,290 @@ +#include "internal/realm_test_utils.h" +#include "kernels/allocation.h" +#include "kernels/compare_tensor_accessors.h" +#include "kernels/copy_tensor_accessor.h" +#include "kernels/format_accessor_contents.h" +#include "kernels/tensor_accessor_reductions.h" +#include "op-attrs/operator_task_space_to_operator_task_space_mapping.h" +#include "op-attrs/ops/element_unary.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/replicate.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_shape.dtg.h" +#include "op-attrs/tensor_slot_name.dtg.h" +#include "pcg/device_type.dtg.h" +#include "pcg/machine_space_coordinate.dtg.h" +#include "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "realm-execution/distributed_ff_handle.h" +#include "realm-execution/dynamic_tensor_accessor_from_instance.h" +#include "realm-execution/pcg_instance.h" +#include "realm-execution/realm_context.h" +#include "realm-execution/realm_manager.h" +#include "task-spec/permissions.h" +#include "test/utils/doctest/check_kv.h" +#include "utils/containers/require_only_key.h" +#include + +namespace test { + +using namespace ::FlexFlow; +namespace Realm = ::FlexFlow::Realm; + +template +static ParallelLayerAttrs make_layer_attrs(T const &op_attrs) { + return ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{op_attrs}, + /*name=*/std::nullopt, + }; +}; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE( + "RealmBackend e2e Training Repartition Op (CPU Model Parallelism)") { + std::vector fake_args = + make_fake_realm_args(/*num_cpus=*/2_p, /*num_gpus=*/0_n); + int fake_argc = fake_args.size(); + char **fake_argv = fake_args.data(); + + RealmManager manager = 
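+    // RealmManager forwards argc/argv to Realm's runtime init, so the test
+    // fabricates a command line. A plausible sketch of what
+    // make_fake_realm_args yields here, assuming Realm's usual "-ll:"
+    // processor-count flags (the exact flags are an assumption):
+    //   {"<test>", "-ll:cpu", "2", "-ll:gpu", "0"}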
RealmManager{&fake_argc, &fake_argv}; + ControllerTaskResult result = + manager.start_controller([](RealmContext &ctx) { + Allocator allocator = ctx.get_current_device_allocator(); + + positive_int batch_size = 10_p; + positive_int data_dim = 16_p; + + TensorShape input_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; + + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult inputs_layer = + pcg_add_input_layer(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input = + require_only_key(inputs_layer.outputs, TensorSlotName::OUTPUT); + + // repartition along batch dimension (dim 0) with degree 2 + RepartitionAttrs repartition_attrs{ + /*repartition_dim=*/ff_dim_t{nonnegative_int{0}}, + /*repartition_degree=*/2_p, + }; + ParallelLayerAddedResult repartition_operator = + add_parallel_layer(pcg, + make_layer_attrs(repartition_attrs), + {{TensorSlotName::INPUT, t_input}}, + /*weights=*/{}); + parallel_tensor_guid_t t_repartitioned = require_only_key( + repartition_operator.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult relu_operator = + add_parallel_layer(pcg, + make_layer_attrs(make_relu_attrs()), + {{TensorSlotName::INPUT, t_repartitioned}}, + /*weights=*/{}); + + MachineSpaceCoordinate cpu0{0_n, 0_n, DeviceType::CPU}; + MachineSpaceCoordinate cpu1{0_n, 1_n, DeviceType::CPU}; + + // input: one shard on cpu0 (not yet repartitioned) + ParallelTensorSpaceCoordinate tensor_coord0{0_n, 0_n, FFOrdered{0_n}}; + // after repartition: two shards along dim 0 + ParallelTensorSpaceCoordinate tensor_coord_shard0{ + 0_n, 0_n, FFOrdered{0_n}}; + ParallelTensorSpaceCoordinate tensor_coord_shard1{ + 0_n, 0_n, FFOrdered{1_n}}; + + MappedParallelComputationGraph mpcg{ + pcg, + { + {inputs_layer.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + // repartition: OUTPUT only (no INPUT in binding) + {repartition_operator.parallel_layer, + MappedOperatorTaskGroup{{ + {cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord_shard0}, + }}}, + {cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord_shard1}, + }}}, + }}}, + {relu_operator.parallel_layer, + MappedOperatorTaskGroup{{ + {cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord_shard0}, + {TensorSlotName::OUTPUT, tensor_coord_shard0}, + }}}, + {cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord_shard1}, + {TensorSlotName::OUTPUT, tensor_coord_shard1}, + }}}, + }}}, + }}; + + OptimizerAttrs optimizer_attrs = + OptimizerAttrs{SGDOptimizerAttrs{/*lr=*/0.001, + /*momentum=*/0.9, + /*nesterov=*/false, + /*weight_decay=*/0.001}}; + + std::unordered_map + input_tensors; + + DistributedFfHandle device_handle = create_distributed_ff_handle( + ctx, + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true); + + PCGInstance pcg_instance = + create_pcg_instance(ctx, + mpcg, + optimizer_attrs, + std::nullopt, + input_tensors, + ProfilingSettings{0, 0}, + device_handle, + FFIterationConfig{1_p}); + + perform_all_passes_for_pcg_instance(pcg_instance, + ProfilingSettings{0, 0}, + device_handle, + FFIterationConfig{1_p}); + }); + result.wait(); + } +} +TEST_SUITE(FF_CUDA_TEST_SUITE) { + TEST_CASE( + "RealmBackend e2e Training Repartition Op (GPU Model Parallelism)") { + std::vector fake_args = + make_fake_realm_args(/*num_cpus=*/1_p, /*num_gpus=*/2_n); + int fake_argc = 
fake_args.size(); + char **fake_argv = fake_args.data(); + + RealmManager manager = RealmManager{&fake_argc, &fake_argv}; + ControllerTaskResult result = + manager.start_controller([](RealmContext &ctx) { + Allocator allocator = ctx.get_current_device_allocator(); + + positive_int batch_size = 10_p; + positive_int data_dim = 16_p; + + TensorShape input_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; + + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult inputs_layer = + pcg_add_input_layer(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input = + require_only_key(inputs_layer.outputs, TensorSlotName::OUTPUT); + + // repartition along batch dimension (dim 0) with degree 2 + RepartitionAttrs repartition_attrs{ + /*repartition_dim=*/ff_dim_t{nonnegative_int{0}}, + /*repartition_degree=*/2_p, + }; + ParallelLayerAddedResult repartition_operator = + add_parallel_layer(pcg, + make_layer_attrs(repartition_attrs), + {{TensorSlotName::INPUT, t_input}}, + /*weights=*/{}); + parallel_tensor_guid_t t_repartitioned = require_only_key( + repartition_operator.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult relu_operator = + add_parallel_layer(pcg, + make_layer_attrs(make_relu_attrs()), + {{TensorSlotName::INPUT, t_repartitioned}}, + /*weights=*/{}); + + MachineSpaceCoordinate gpu0{0_n, 0_n, DeviceType::GPU}; + MachineSpaceCoordinate gpu1{0_n, 1_n, DeviceType::GPU}; + + // input: one shard on gpu0 (not yet repartitioned) + ParallelTensorSpaceCoordinate tensor_coord0{0_n, 0_n, FFOrdered{0_n}}; + // after repartition: two shards along dim 0 + ParallelTensorSpaceCoordinate tensor_coord_shard0{ + 0_n, 0_n, FFOrdered{0_n}}; + ParallelTensorSpaceCoordinate tensor_coord_shard1{ + 0_n, 0_n, FFOrdered{1_n}}; + + MappedParallelComputationGraph mpcg{ + pcg, + { + {inputs_layer.parallel_layer, + MappedOperatorTaskGroup{ + {{gpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + // repartition: OUTPUT only (no INPUT in binding) + {repartition_operator.parallel_layer, + MappedOperatorTaskGroup{{ + {gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord_shard0}, + }}}, + {gpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord_shard1}, + }}}, + }}}, + {relu_operator.parallel_layer, + MappedOperatorTaskGroup{{ + {gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord_shard0}, + {TensorSlotName::OUTPUT, tensor_coord_shard0}, + }}}, + {gpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord_shard1}, + {TensorSlotName::OUTPUT, tensor_coord_shard1}, + }}}, + }}}, + }}; + + OptimizerAttrs optimizer_attrs = + OptimizerAttrs{SGDOptimizerAttrs{/*lr=*/0.001, + /*momentum=*/0.9, + /*nesterov=*/false, + /*weight_decay=*/0.001}}; + + std::unordered_map + input_tensors; + + DistributedFfHandle device_handle = create_distributed_ff_handle( + ctx, + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true); + + PCGInstance pcg_instance = + create_pcg_instance(ctx, + mpcg, + optimizer_attrs, + std::nullopt, + input_tensors, + ProfilingSettings{0, 0}, + device_handle, + FFIterationConfig{1_p}); + + perform_all_passes_for_pcg_instance(pcg_instance, + ProfilingSettings{0, 0}, + device_handle, + FFIterationConfig{1_p}); + }); + result.wait(); + } +} +} // namespace test diff --git a/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc 
b/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc new file mode 100644 index 0000000000..632f08d239 --- /dev/null +++ b/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc @@ -0,0 +1,472 @@ +#include "internal/realm_test_utils.h" +#include "kernels/allocation.h" +#include "kernels/compare_tensor_accessors.h" +#include "kernels/copy_tensor_accessor.h" +#include "kernels/format_accessor_contents.h" +#include "kernels/tensor_accessor_reductions.h" +#include "op-attrs/operator_task_space_to_operator_task_space_mapping.h" +#include "op-attrs/ops/element_unary.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/replicate.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_shape.dtg.h" +#include "op-attrs/tensor_slot_name.dtg.h" +#include "pcg/device_type.dtg.h" +#include "pcg/machine_space_coordinate.dtg.h" +#include "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "realm-execution/distributed_ff_handle.h" +#include "realm-execution/dynamic_tensor_accessor_from_instance.h" +#include "realm-execution/pcg_instance.h" +#include "realm-execution/realm_context.h" +#include "realm-execution/realm_manager.h" +#include "task-spec/permissions.h" +#include "test/utils/doctest/check_kv.h" +#include "utils/containers/require_only_key.h" +#include + +namespace test { + +using namespace ::FlexFlow; +namespace Realm = ::FlexFlow::Realm; + +template +static ParallelLayerAttrs make_layer_attrs(T const &op_attrs) { + return ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{op_attrs}, + /*name=*/std::nullopt, + }; +}; + +static bool did_loss_decrease(GenericTensorAccessorR const &first_epoch, + GenericTensorAccessorR const &last_epoch, + Allocator &allocator) { + return tensor_accessor_all( + compare_tensor_accessors_le(last_epoch, first_epoch, allocator)); +} + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("RealmBackend e2e Training Replicate Op (CPU Model Parallelism)") { + std::vector fake_args = + make_fake_realm_args(/*num_cpus=*/2_p, /*num_gpus=*/0_n); + int fake_argc = fake_args.size(); + char **fake_argv = fake_args.data(); + + RealmManager manager = RealmManager{&fake_argc, &fake_argv}; + ControllerTaskResult result = + manager.start_controller([](RealmContext &ctx) { + Allocator allocator = ctx.get_current_device_allocator(); + + positive_int batch_size = 10_p; + positive_int data_dim = 16_p; + positive_int hidden_dim = 32_p; + positive_int output_dim = 1_p; + + // 10,2 + TensorShape output_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + // 10,2 + TensorShape label_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + GenericTensorAccessorW label_tensor = + allocator.allocate_tensor(label_tensor_shape); + + // construct computation graph + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + // input tensor + // 10, 16 + TensorShape input_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; + + // parallel layer -> input tensor + ParallelLayerAddedResult inputs_layer = + pcg_add_input_layer(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input = + 
require_only_key(inputs_layer.outputs, TensorSlotName::OUTPUT); + + // parallel layer -> input tensor 2 + ParallelLayerAddedResult inputs_layer_2 = + pcg_add_input_layer(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input_2 = + require_only_key(inputs_layer_2.outputs, TensorSlotName::OUTPUT); + + // binary ADD attribute + ElementBinaryAttrs add_attrs = ElementBinaryAttrs{ + OperatorType::EW_ADD, + DataType::FLOAT, + false, + false, + }; + + // parallel layer -> perform add + ParallelLayerAddedResult add_operator_1 = + add_parallel_layer(pcg, + make_layer_attrs(add_attrs), + { + { + TensorSlotName::LHS_INPUT, + t_input, + }, + { + TensorSlotName::RHS_INPUT, + t_input_2, + }, + }, + {/* weight */}); + + parallel_tensor_guid_t t_add_1 = + require_only_key(add_operator_1.outputs, TensorSlotName::OUTPUT); + + // parallel layer -> perform replicate + const positive_int replicate_degree = 2_p; + ReplicateAttrs repl_attrs = ReplicateAttrs(replicate_degree); + ParallelLayerAddedResult repl_operator_1 = + add_parallel_layer(pcg, + make_layer_attrs(repl_attrs), + { + { + TensorSlotName::INPUT, + t_add_1, + }, + }, + /*weight=*/{}); + // output of replicate layer + parallel_tensor_guid_t t_repl_1 = + require_only_key(repl_operator_1.outputs, TensorSlotName::OUTPUT); + + // parallel layer -> perform RelU + ParallelLayerAddedResult relu_operator_1 = + add_parallel_layer(pcg, + make_layer_attrs(make_relu_attrs()), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_repl_1, + }, + }, + /*weights=*/{}); + // output of relu layer + parallel_tensor_guid_t t_relu_1 = + require_only_key(relu_operator_1.outputs, TensorSlotName::OUTPUT); + + // machine + MachineSpaceCoordinate cpu0{0_n, 0_n, DeviceType::CPU}; + MachineSpaceCoordinate cpu1{0_n, 1_n, DeviceType::CPU}; + + ParallelTensorSpaceCoordinate tensor_coord0{ + /* sum_component */ 0_n, + /* discard_copy_component */ 0_n, + /*shard_component*/ FFOrdered{0_n}}; + ParallelTensorSpaceCoordinate tensor_coord1{ + /* sum_component */ 0_n, + /* discard_copy_component */ 1_n, + /*shard_component*/ FFOrdered{0_n}}; + MappedParallelComputationGraph mpcg{ + pcg, + {{inputs_layer.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + {inputs_layer_2.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + {add_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::LHS_INPUT, tensor_coord0}, + {TensorSlotName::RHS_INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}}}}, + {repl_operator_1.parallel_layer, + MappedOperatorTaskGroup{{ + {cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}, + {cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord1}, + }}}, + }}}, + {relu_operator_1.parallel_layer, + MappedOperatorTaskGroup{{ + {cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}, + {cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord1}, + {TensorSlotName::OUTPUT, tensor_coord1}, + }}}, + }}}}, + }; + + MappedOperatorTaskGroup loss_mapping{ + {{cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::LOGIT, tensor_coord0}, + }}}}}; + + // instantiate computation graph + LossAttrs loss_attrs = LossAttrs{ + 
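+          // Note: loss_mapping above and loss_attrs here are built, but
+          // create_pcg_instance below is still passed /*loss=*/std::nullopt,
+          // so this test exercises replicate without the loss wired in.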
NonconfigurableLossAttrs{LossFunction::CATEGORICAL_CROSSENTROPY}}; + OptimizerAttrs optimizer_attrs = + OptimizerAttrs{SGDOptimizerAttrs{/*lr=*/0.001, + /*momentum=*/0.9, + /*nesterov=*/false, + /*weight_decay=*/0.001}}; + + std::unordered_map + input_tensors; + + DistributedFfHandle device_handle = create_distributed_ff_handle( + ctx, + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true); + PCGInstance pcg_instance = create_pcg_instance( + /*ctx=*/ctx, + /*mpcg=*/mpcg, + /*optimizer=*/optimizer_attrs, + /*loss=*/std::nullopt, + /*input_tensors=*/input_tensors, + /*profiling_settings=*/ProfilingSettings{0, 0}, + /*device_handle=*/device_handle, + /*iteration_config=*/FFIterationConfig{1_p}); + + // begin training loop + int num_epochs = 1; + for (int i = 0; i < num_epochs; i++) { + perform_all_passes_for_pcg_instance( + /*instance=*/pcg_instance, + /*profiling_settings=*/ProfilingSettings{0, 0}, + /*device_handle=*/device_handle, + /*iteration_config=*/FFIterationConfig{1_p}); + } + }); + result.wait(); + } +} + +TEST_SUITE(FF_CUDA_TEST_SUITE) { + TEST_CASE("RealmBackend e2e Training Replicate Op (GPU Model Parallelism)") { + std::vector fake_args = + make_fake_realm_args(/*num_cpus=*/1_p, /*num_gpus=*/2_n); + int fake_argc = fake_args.size(); + char **fake_argv = fake_args.data(); + + RealmManager manager = RealmManager{&fake_argc, &fake_argv}; + + ControllerTaskResult result = + manager.start_controller([](RealmContext &ctx) { + Allocator allocator = ctx.get_current_device_allocator(); + + positive_int batch_size = 10_p; + positive_int data_dim = 16_p; + positive_int hidden_dim = 32_p; + positive_int output_dim = 1_p; + + // 10,2 + TensorShape output_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + // 10,2 + TensorShape label_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + GenericTensorAccessorW label_tensor = + allocator.allocate_tensor(label_tensor_shape); + + // construct computation graph + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + // input tensor + // 10, 16 + TensorShape input_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; + + // parallel layer -> input tensor + ParallelLayerAddedResult inputs_layer = + pcg_add_input_layer(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input = + require_only_key(inputs_layer.outputs, TensorSlotName::OUTPUT); + + // parallel layer -> input tensor 2 + ParallelLayerAddedResult inputs_layer_2 = + pcg_add_input_layer(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input_2 = + require_only_key(inputs_layer_2.outputs, TensorSlotName::OUTPUT); + + // binary ADD attribute + ElementBinaryAttrs add_attrs = ElementBinaryAttrs{ + OperatorType::EW_ADD, + DataType::FLOAT, + false, + false, + }; + + // parallel layer -> perform add + ParallelLayerAddedResult add_operator_1 = + add_parallel_layer(pcg, + make_layer_attrs(add_attrs), + { + { + TensorSlotName::LHS_INPUT, + t_input, + }, + { + TensorSlotName::RHS_INPUT, + t_input_2, + }, + }, + {/* weight */}); + + parallel_tensor_guid_t t_add_1 = + require_only_key(add_operator_1.outputs, TensorSlotName::OUTPUT); + + // parallel layer -> perform replicate + const positive_int replicate_degree = 2_p; + ReplicateAttrs repl_attrs = ReplicateAttrs(replicate_degree); + ParallelLayerAddedResult repl_operator_1 = + add_parallel_layer(pcg, + make_layer_attrs(repl_attrs), + { + { + TensorSlotName::INPUT, + t_add_1, + }, + }, 
+ /*weight=*/{}); + // output of replicate layer + parallel_tensor_guid_t t_repl_1 = + require_only_key(repl_operator_1.outputs, TensorSlotName::OUTPUT); + + // parallel layer -> perform RelU + ParallelLayerAddedResult relu_operator_1 = + add_parallel_layer(pcg, + make_layer_attrs(make_relu_attrs()), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_repl_1, + }, + }, + /*weights=*/{}); + // output of relu layer + parallel_tensor_guid_t t_relu_1 = + require_only_key(relu_operator_1.outputs, TensorSlotName::OUTPUT); + + // machine + MachineSpaceCoordinate gpu0{0_n, 0_n, DeviceType::GPU}; + MachineSpaceCoordinate gpu1{0_n, 1_n, DeviceType::GPU}; + ParallelTensorSpaceCoordinate tensor_coord0{0_n, 0_n, FFOrdered{0_n}}; + ParallelTensorSpaceCoordinate tensor_coord1{0_n, 1_n, FFOrdered{0_n}}; + MappedParallelComputationGraph mpcg{ + pcg, + { + {inputs_layer.parallel_layer, + MappedOperatorTaskGroup{ + {{gpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + {inputs_layer_2.parallel_layer, + MappedOperatorTaskGroup{ + {{gpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + {add_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + {{gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::LHS_INPUT, tensor_coord0}, + {TensorSlotName::RHS_INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}}}}, + {repl_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + {{gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}, + {gpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord1}, + }}}}}}, + {relu_operator_1.parallel_layer, + MappedOperatorTaskGroup{{ + {gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}, + {gpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord1}, + {TensorSlotName::OUTPUT, tensor_coord1}, + }}}, + }}}, + }, + }; + + MappedOperatorTaskGroup loss_mapping{ + {{gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::LOGIT, tensor_coord0}, + }}}}}; + + // instantiate computation graph + LossAttrs loss_attrs = LossAttrs{ + NonconfigurableLossAttrs{LossFunction::CATEGORICAL_CROSSENTROPY}}; + OptimizerAttrs optimizer_attrs = + OptimizerAttrs{SGDOptimizerAttrs{/*lr=*/0.001, + /*momentum=*/0.9, + /*nesterov=*/false, + /*weight_decay=*/0.001}}; + + std::unordered_map + input_tensors; + + DistributedFfHandle device_handle = create_distributed_ff_handle( + ctx, + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true); + + PCGInstance pcg_instance = create_pcg_instance( + /*ctx=*/ctx, + /*mpcg=*/mpcg, + /*optimizer=*/optimizer_attrs, + /*loss=*/std::nullopt, + /*input_tensors=*/input_tensors, + /*profiling_settings=*/ProfilingSettings{0, 0}, + /*device_handle=*/device_handle, + /*iteration_config=*/FFIterationConfig{1_p}); + + // begin training loop + int num_epochs = 1; + for (int i = 0; i < num_epochs; i++) { + perform_all_passes_for_pcg_instance( + /*instance=*/pcg_instance, + /*profiling_settings=*/ProfilingSettings{0, 0}, + /*device_handle=*/device_handle, + /*iteration_config=*/FFIterationConfig{1_p}); + } + }); + result.wait(); + } +} +} // namespace test diff --git a/lib/task-spec/include/task-spec/dynamic_graph/parallel_op_utils.h b/lib/task-spec/include/task-spec/dynamic_graph/parallel_op_utils.h new file mode 100644 index 0000000000..095c9edc41 --- /dev/null +++ 
b/lib/task-spec/include/task-spec/dynamic_graph/parallel_op_utils.h
@@ -0,0 +1,28 @@
+#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_PARALLEL_OP_UTILS_H
+#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_PARALLEL_OP_UTILS_H
+
+#include "op-attrs/ops/combine.h"
+#include "op-attrs/ops/reduction.h"
+#include "op-attrs/ops/repartition.h"
+#include "op-attrs/ops/replicate.h"
+#include "op-attrs/pcg_operator_attrs.dtg.h"
+#include "task-spec/dynamic_graph/dynamic_node_attrs.dtg.h"
+#include "task-spec/dynamic_graph/training_operation_attrs.dtg.h"
+
+namespace FlexFlow {
+
+inline bool is_parallel_op_attrs(DynamicNodeAttrs const &n) {
+  if (!n.op_attrs.has_value()) {
+    return false;
+  }
+  if (!n.op_attrs.value().has<PCGOperatorAttrs>()) {
+    return false;
+  }
+  PCGOperatorAttrs pcg = n.op_attrs.value().get<PCGOperatorAttrs>();
+  return pcg.has<CombineAttrs>() || pcg.has<ReductionAttrs>() ||
+         pcg.has<RepartitionAttrs>() || pcg.has<ReplicateAttrs>();
+}
+
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_PARALLEL_OP_UTILS_H
diff --git a/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc b/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc
index 4c1b9d4609..becb068a1d 100644
--- a/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc
+++ b/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc
@@ -9,6 +9,7 @@
 #include "task-spec/dynamic_graph/dynamic_task_type.h"
 #include "task-spec/dynamic_graph/dynamic_tensor_slot.dtg.h"
 #include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h"
+#include "task-spec/dynamic_graph/parallel_op_utils.h"
 #include "utils/bidict/algorithms/bidict_from_pairs.h"
 #include "utils/bidict/algorithms/unordered_set_of.h"
 #include "utils/containers/contains_key.h"
@@ -31,9 +32,26 @@ bool value_is_mapped(DynamicValueAttrs const &n) {
 bool no_part_of_graph_is_copy_inserted(DynamicOpenDataflowGraph const &g) {
   auto slot_is_mapped = [](DynamicTensorSlot const &) -> bool { return false; };
-
-  return no_part_of_dynamic_graph_satisfies(
-      g, node_is_copy, value_is_mapped, slot_is_mapped);
+  // check every invocation except parallel-op nodes
+  for (DynamicNodeInvocation const &i : g.invocations) {
+    if (is_parallel_op_attrs(i.node_attrs)) {
+      continue; // parallel-op tensors carry a mapping by design
+    }
+    if (node_is_copy(i.node_attrs)) {
+      return false;
+    }
+    for (auto const &[slot, value] : i.inputs) {
+      if (value_is_mapped(value)) {
+        return false;
+      }
+    }
+    for (auto const &[slot, value] : i.outputs) {
+      if (value_is_mapped(value)) {
+        return false;
+      }
+    }
+  }
+  return true;
 }
 
 bool graph_is_fully_copy_inserted(DynamicOpenDataflowGraph const &g) {
@@ -85,6 +103,11 @@ std::unordered_set<DynamicNodeInvocation>
     perform_copy_insertion_for_invocation(
     std::unordered_map const &unmapped_value_to_mapped_source_value) {
+  // parallel-op nodes have no MappedOperatorTaskGroup;
+  // pass them through unchanged, no copies needed
+  if (is_parallel_op_attrs(i.node_attrs)) {
+    return {i};
+  }
   MappedOperatorTaskGroup mapping = assert_unwrap(i.node_attrs.mapping);
 
   auto map_tensor = [&](DynamicTensorSlot const &slot,
@@ -157,6 +180,14 @@ DynamicOpenDataflowGraph
   std::unordered_map<DynamicValueAttrs, DynamicValueAttrs>
       unmapped_value_to_mapped_source_value;
   for (DynamicNodeInvocation const &i : g.invocations) {
+    // parallel-op nodes have no MappedOperatorTaskGroup;
+    // their output mapping is already fully set, so each maps to itself
+    if (is_parallel_op_attrs(i.node_attrs)) {
+      for (auto const &[slot, value] : i.outputs) {
+        unmapped_value_to_mapped_source_value.insert(std::pair{value, value});
+      }
+      continue;
+    }
    for (auto const &[slot, value] : i.outputs) {
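+      // ordinary (non-parallel-op) nodes: each output value is recorded
+      // against the mapped source value constructed for it below, so later
+      // consumers can look up where their input actually lives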
unmapped_value_to_mapped_source_value.insert( std::pair{value, diff --git a/lib/task-spec/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc b/lib/task-spec/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc index bf9fe1d3a0..3a668feba1 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc @@ -1,4 +1,5 @@ #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" +#include "task-spec/dynamic_graph/parallel_op_utils.h" #include "utils/containers/all_of.h" #include "utils/containers/contains_duplicates.h" #include "utils/containers/flatmap.h" @@ -149,6 +150,13 @@ std::pair #include #include namespace FlexFlow { +static bidict + get_input_mapping_for_parallel_op( + MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &layer) { + + // get_incoming_edges returns map + // replicate has exactly one input + auto [input_slot_name, input_edge] = + get_only(get_incoming_edges(mpcg.pcg, layer)); + + parallel_layer_guid_t producer_layer = get_src_layer(input_edge); + TensorSlotName producer_slot = get_src_layer_output_slot_name(input_edge); + + return get_tensor_bindings_for_slot_name(mpcg.mapped_tasks.at(producer_layer), + producer_slot); +} + +static std::unordered_map + get_consumers_of_tensor(MappedParallelComputationGraph const &mpcg, + parallel_tensor_guid_t const &tensor) { + parallel_layer_guid_t producer_layer = get_source_layer(mpcg.pcg, tensor); + + std::unordered_map result; + // get_outgoing_edges returns unordered_set + for (ParallelComputationGraphEdge const &edge : + get_outgoing_edges(mpcg.pcg, producer_layer)) { + if (get_parallel_tensor(edge) == tensor) { + result.insert( + std::pair{get_dst_layer(edge), get_dst_layer_input_slot_name(edge)}); + } + } + return result; +} + +static bidict + build_output_mapping_for_parallel_op( + MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &layer) { + + auto [output_slot_name, output_tensor_guid] = + get_only(get_outgoing_tensors(mpcg.pcg, layer)); + + auto consumers = get_consumers_of_tensor(mpcg, output_tensor_guid); + ASSERT(!consumers.empty()); + + // union all consumer bindings — each consumer shard maps to a distinct + // (discard_copy, machine) pair since replicas are always on different machines + bidict result; + for (auto const &[consumer_layer, slot_name] : consumers) { + MappedOperatorTaskGroup consumer_mapping = + mpcg.mapped_tasks.at(consumer_layer); + bidict binding = + get_tensor_bindings_for_slot_name(consumer_mapping, slot_name); + for (auto const &[p, m] : binding) { + result.equate(p, m); + } + } + return result; +} + +static DynamicNodeInvocation + build_parallel_op_invocation(parallel_layer_guid_t const &layer, + ParallelLayerAttrs const &attrs, + MappedParallelComputationGraph const &mpcg) { + auto [input_slot_name, input_tensor_guid] = + get_only(get_incoming_tensors(mpcg.pcg, layer)); + auto incoming = get_incoming_tensors(mpcg.pcg, layer); + ASSERT(!incoming.empty(), + "replicate layer has no incoming tensors — " + "check PCG edge construction in test"); + + ParallelTensorAttrs input_attrs = + get_parallel_tensor_attrs(mpcg.pcg, input_tensor_guid); + + DynamicValueAttrs input_value{ + /*tensor_guid=*/dynamic_tensor_guid_t{input_tensor_guid}, + /*parallel_tensor_shape=*/input_attrs.shape, + /*shard_coord=*/std::nullopt, + /*mapping=*/get_input_mapping_for_parallel_op(mpcg, layer), + /*accessor=*/std::nullopt, + /*role=*/std::nullopt, + }; + + auto 
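+  // e.g. for the replicate(degree = 2) -> relu graphs in the tests above,
+  // the output mapping unioned from the relu consumer works out to
+  //   { (sum=0, dc=0, [0]) <-> cpu0, (sum=0, dc=1, [0]) <-> cpu1 }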
[output_slot_name, output_tensor_guid] = + get_only(get_outgoing_tensors(mpcg.pcg, layer)); + ParallelTensorAttrs output_attrs = + get_parallel_tensor_attrs(mpcg.pcg, output_tensor_guid); + + DynamicValueAttrs output_value{ + /*tensor_guid=*/dynamic_tensor_guid_t{output_tensor_guid}, + /*parallel_tensor_shape=*/output_attrs.shape, + /*shard_coord=*/std::nullopt, + /*mapping=*/build_output_mapping_for_parallel_op(mpcg, layer), + /*accessor=*/std::nullopt, + /*role=*/std::nullopt, + }; + DynamicNodeAttrs node_attrs{ + /*task_type=*/std::nullopt, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs=*/TrainingOperationAttrs{attrs.op_attrs}, + /*pcg_layer_guid=*/dynamic_layer_guid_t{layer}, + /*per_device_op_state=*/std::nullopt, + }; + + DynamicNodeInvocation invocation_node{ + /*inputs=*/{ + {DynamicTensorSlot{input_slot_name, std::nullopt}, input_value}}, + /*node_attrs=*/node_attrs, + /*outputs=*/ + {{DynamicTensorSlot{output_slot_name, std::nullopt}, output_value}}, + }; + return invocation_node; +} + DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mapped_pcg( MappedParallelComputationGraph const &mpcg) { DynamicOpenDataflowGraph result = make_empty_dynamic_open_dataflow_graph(); for (auto const &[layer, attrs] : get_parallel_layer_attrs_mapping(mpcg.pcg)) { + + if (is_parallel_op(attrs.op_attrs)) { + // build replicate invocation + DynamicNodeInvocation parallel_inv = + build_parallel_op_invocation(layer, attrs, mpcg); + result.invocations.emplace(parallel_inv); + continue; + } + DynamicNodeAttrs result_attrs{ /*task_type=*/std::nullopt, /*device_coord=*/std::nullopt, @@ -73,7 +199,6 @@ DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mapped_pcg( result.invocations.emplace(result_inputs, result_attrs, result_outputs); } - return result; } diff --git a/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc b/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc index 0cee06368f..036579c80a 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc @@ -1,7 +1,9 @@ #include "task-spec/dynamic_graph/pass_expansion.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_role.h" +#include "task-spec/dynamic_graph/parallel_op_utils.h" #include "utils/containers/are_all_same.h" +#include "utils/containers/get_only.h" #include "utils/containers/merge_disjoint_maps.h" #include "utils/containers/transform.h" @@ -110,6 +112,51 @@ DynamicNodeInvocation perform_bwd_pass_expansion_for_invocation( }; } +static std::unordered_set + perform_pass_expansion_for_parallel_op( + DynamicNodeInvocation const &invocation) { + + auto const &[input_slot, input] = get_only(invocation.inputs); + + auto to_fwd = [](DynamicTensorSlot const &k, DynamicValueAttrs const &v) { + return std::pair{ + pass_expand_slot(k, FwbTensorType::FORWARD), + pass_expand_value(v, FwbTensorType::FORWARD), + }; + }; + + auto to_grad = [](DynamicTensorSlot const &k, DynamicValueAttrs const &v) { + return std::pair{ + pass_expand_slot(k, FwbTensorType::GRADIENT), + pass_expand_value(v, FwbTensorType::GRADIENT), + }; + }; + + DynamicNodeInvocation fwd{ + /*inputs=*/{{pass_expand_slot(input_slot, FwbTensorType::FORWARD), + pass_expand_value(input, FwbTensorType::FORWARD)}}, + /*node_attrs=*/ + pass_expand_node(invocation.node_attrs, DynamicTaskType::FWD), + /*outputs=*/transform(invocation.outputs, to_fwd), + }; + + DynamicNodeAttrs bwd_node = 
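+  // The BWD node mirrors the FWD one: it consumes the op's FORWARD outputs
+  // together with their GRADIENT counterparts and emits the gradient of the
+  // single input, e.g. for replicate the bwd node takes the replicated
+  // output (value and grad) and produces the gradient of the original input.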
invocation.node_attrs; + bwd_node.task_type = DynamicTaskType::BWD; + + DynamicNodeInvocation bwd{ + /*inputs=*/merge_disjoint_maps(std::vector{ + transform(invocation.outputs, to_fwd), + transform(invocation.outputs, to_grad), + }), + /*node_attrs=*/bwd_node, + /*outputs=*/ + {{pass_expand_slot(input_slot, FwbTensorType::GRADIENT), + pass_expand_value(input, FwbTensorType::GRADIENT)}}, + }; + + return {fwd, bwd}; +} + DynamicOpenDataflowGraph perform_pass_expansion(DynamicOpenDataflowGraph const &g) { @@ -117,6 +164,9 @@ DynamicOpenDataflowGraph DynamicOpenDataflowGraph result = flatmap_dynamic_invocation_set( g, [](DynamicNodeInvocation const &invocation) { + if (is_parallel_op_attrs(invocation.node_attrs)) { + return perform_pass_expansion_for_parallel_op(invocation); + } if (invocation.inputs.empty()) { return std::unordered_set{ perform_fwd_pass_expansion_for_invocation(invocation), diff --git a/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc b/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc index fb6efb96d0..c049a35cb1 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc @@ -1,6 +1,7 @@ #include "task-spec/dynamic_graph/shard_expansion.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" +#include "task-spec/dynamic_graph/parallel_op_utils.h" #include "utils/bidict/algorithms/filter_keys.h" #include "utils/containers/get_only.h" #include "utils/containers/map_values2.h" @@ -18,6 +19,10 @@ bool value_is_shard_expanded(DynamicValueAttrs const &n) { return n.shard_coord.has_value(); } +static bool has_task_type(DynamicNodeAttrs const &n, DynamicTaskType t) { + return n.task_type.has_value() && n.task_type.value() == t; +} + bool no_part_of_graph_is_shard_expanded(DynamicOpenDataflowGraph const &g) { auto slot_is_shard_expanded = [](DynamicTensorSlot const &) -> bool { return false; @@ -39,7 +44,6 @@ bool graph_is_fully_shard_expanded(DynamicOpenDataflowGraph const &g) { value_is_shard_expanded, slot_is_shard_expanded); } - static bidict restrict_tensor_mapping_keys_to_coord( bidict const @@ -85,6 +89,339 @@ static DynamicNodeInvocation shard_invocation_for_binding( }; } +static std::unordered_set + perform_shard_expansion_one_to_many( + DynamicNodeInvocation const &i, + std::function output_to_input_coord) { + + if (has_task_type(i.node_attrs, DynamicTaskType::FWD)) { + auto const &[input_slot, input] = get_only(i.inputs); + auto const &[output_slot, output] = get_only(i.outputs); + + bidict + output_mapping = assert_unwrap(output.mapping); + + return transform(output_mapping.left_values(), + [&](ParallelTensorSpaceCoordinate const &p) { + ParallelTensorSpaceCoordinate input_p = + output_to_input_coord(p); + return shard_invocation_for_binding( + i, + output_mapping.at_l(p), + OperatorAtomicTaskShardBinding{{ + {input_slot.slot_name, input_p}, + {output_slot.slot_name, p}, + }}); + }); + } + + // BWD case — inputs are OUTPUT/FWD and OUTPUT/GRAD, output is INPUT/GRAD + std::optional output_grad_opt; + std::optional output_fwd_opt; + std::optional output_grad_slot_opt; + std::optional output_fwd_slot_opt; + + for (auto const &[slot, value] : i.inputs) { + if (slot.slot_tensor_role == DynamicTensorRole{FwbTensorType::GRADIENT}) { + output_grad_slot_opt = slot; + output_grad_opt = value; + } else { + output_fwd_slot_opt = slot; + output_fwd_opt = value; + } + } + + DynamicValueAttrs output_grad = 
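+  // Concretely, for reduction(degree = 2) bwd: input_grad carries coords
+  // (sum=0, ...) and (sum=1, ...), each mapping back to the single
+  // output_grad coord (sum=0, ...), so one bwd shard is emitted per
+  // input_grad coordinate.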
assert_unwrap(output_grad_opt); + DynamicValueAttrs output_fwd = assert_unwrap(output_fwd_opt); + DynamicTensorSlot output_grad_slot = assert_unwrap(output_grad_slot_opt); + DynamicTensorSlot output_fwd_slot = assert_unwrap(output_fwd_slot_opt); + auto const &[input_grad_slot, input_grad] = get_only(i.outputs); + + bidict + input_grad_mapping = assert_unwrap(input_grad.mapping); + + // iterate over input_grad coords (the "many" side) + return transform( + input_grad_mapping.left_values(), + [&](ParallelTensorSpaceCoordinate const &p) { + // map input_grad coord to output_grad coord + ParallelTensorSpaceCoordinate output_p = output_to_input_coord(p); + MachineSpaceCoordinate dst_machine = input_grad_mapping.at_l(p); + + bidict + output_grad_mapping = assert_unwrap(output_grad.mapping); + + DynamicValueAttrs sharded_output_grad = output_grad; + sharded_output_grad.mapping = + bidict{ + {output_p, output_grad_mapping.at_l(output_p)}}; + sharded_output_grad.shard_coord = output_p; + + DynamicValueAttrs sharded_output_fwd = output_fwd; + sharded_output_fwd.mapping = + bidict{ + {output_p, output_grad_mapping.at_l(output_p)}}; + sharded_output_fwd.shard_coord = output_p; + + DynamicValueAttrs sharded_input_grad = input_grad; + sharded_input_grad.mapping = + bidict{ + {p, dst_machine}}; + sharded_input_grad.shard_coord = p; + + DynamicNodeAttrs sharded_node = i.node_attrs; + sharded_node.device_coord = dst_machine; + + return DynamicNodeInvocation{ + /*inputs=*/{ + {output_fwd_slot, sharded_output_fwd}, + {output_grad_slot, sharded_output_grad}, + }, + /*node_attrs=*/sharded_node, + /*outputs=*/ + { + {input_grad_slot, sharded_input_grad}, + }, + }; + }); +} +static std::unordered_set + perform_shard_expansion_many_to_one( + DynamicNodeInvocation const &i, + std::function input_to_output_coord) { + + if (has_task_type(i.node_attrs, DynamicTaskType::FWD)) { + auto const &[input_slot, input] = get_only(i.inputs); + auto const &[output_slot, output] = get_only(i.outputs); + + bidict + input_mapping = assert_unwrap(input.mapping); + bidict + output_mapping = assert_unwrap(output.mapping); + + return transform(input_mapping.left_values(), + [&](ParallelTensorSpaceCoordinate const &p) { + ParallelTensorSpaceCoordinate output_p = + input_to_output_coord(p); + MachineSpaceCoordinate dst_machine = + output_mapping.at_l(output_p); + return shard_invocation_for_binding( + i, + dst_machine, + OperatorAtomicTaskShardBinding{{ + {input_slot.slot_name, p}, + {output_slot.slot_name, output_p}, + }}); + }); + } + + // BWD case + std::optional output_grad_opt; + std::optional output_fwd_opt; + std::optional output_grad_slot_opt; + std::optional output_fwd_slot_opt; + + for (auto const &[slot, value] : i.inputs) { + if (slot.slot_tensor_role == DynamicTensorRole{FwbTensorType::GRADIENT}) { + output_grad_slot_opt = slot; + output_grad_opt = value; + } else { + output_fwd_slot_opt = slot; + output_fwd_opt = value; + } + } + + DynamicValueAttrs output_grad = assert_unwrap(output_grad_opt); + DynamicValueAttrs output_fwd = assert_unwrap(output_fwd_opt); + DynamicTensorSlot output_grad_slot = assert_unwrap(output_grad_slot_opt); + DynamicTensorSlot output_fwd_slot = assert_unwrap(output_fwd_slot_opt); + auto const &[input_grad_slot, input_grad] = get_only(i.outputs); + + bidict + output_grad_mapping = assert_unwrap(output_grad.mapping); + bidict + input_grad_mapping = assert_unwrap(input_grad.mapping); + + // group output_grad coords by their corresponding input_grad coord + std::unordered_map> + 
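+  // e.g. for replicate(degree = 2) bwd: output_grad coords (dc=0) and (dc=1)
+  // both fall into the group for input_grad coord (dc=0), so the bwd shard
+  // sees both replicas' gradients and can accumulate them into one input grad.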
input_grad_to_output_grads; + for (auto const &p : output_grad_mapping.left_values()) { + input_grad_to_output_grads[input_to_output_coord(p)].insert(p); + } + + std::unordered_set result; + for (auto const &[input_grad_p, output_grad_coords] : + input_grad_to_output_grads) { + + MachineSpaceCoordinate dst_machine = input_grad_mapping.at_l(input_grad_p); + + // subset output_grad mapping to just this group's coords + bidict + replica_mapping; + for (auto const &p : output_grad_coords) { + replica_mapping.equate(p, output_grad_mapping.at_l(p)); + } + + DynamicValueAttrs sharded_output_grad = output_grad; + sharded_output_grad.mapping = replica_mapping; + sharded_output_grad.shard_coord = input_grad_p; + + DynamicValueAttrs sharded_output_fwd = output_fwd; + sharded_output_fwd.mapping = replica_mapping; + sharded_output_fwd.shard_coord = input_grad_p; + + DynamicValueAttrs sharded_input_grad = input_grad; + sharded_input_grad.mapping = + bidict{ + {input_grad_p, dst_machine}}; + sharded_input_grad.shard_coord = input_grad_p; + + DynamicNodeAttrs sharded_node = i.node_attrs; + sharded_node.device_coord = dst_machine; + + result.insert(DynamicNodeInvocation{ + /*inputs=*/{ + {output_fwd_slot, sharded_output_fwd}, + {output_grad_slot, sharded_output_grad}, + }, + /*node_attrs=*/sharded_node, + /*outputs=*/ + { + {input_grad_slot, sharded_input_grad}, + }, + }); + } + return result; +} + +// Replicate/Reduction FWD — output has discard_copy=0..N-1, input always discard_copy=0 +static std::unordered_set + perform_shard_expansion_for_replicate(DynamicNodeInvocation const &i) { + return perform_shard_expansion_one_to_many( + i, [](ParallelTensorSpaceCoordinate const &p) { + return ParallelTensorSpaceCoordinate{ + p.sum_component, nonnegative_int{0}, p.shard_components}; + }); +} + +// Replicate BWD — many discard_copy inputs → one discard_copy=0 output +static std::unordered_set + perform_shard_expansion_for_replicate_bwd(DynamicNodeInvocation const &i) { + return perform_shard_expansion_many_to_one( + i, [](ParallelTensorSpaceCoordinate const &p) { + return ParallelTensorSpaceCoordinate{ + p.sum_component, nonnegative_int{0}, p.shard_components}; + }); +} + +// Repartition FWD — output coord (high) → input coord (low) +static std::unordered_set + perform_shard_expansion_for_repartition(DynamicNodeInvocation const &i) { + RepartitionAttrs attrs = i.node_attrs.op_attrs.value() + .get() + .get(); + relative_ff_dim_t rel_dim = + relative_ff_dim_t_from_ff_dim_t(attrs.repartition_dim); + nonnegative_int degree = + attrs.repartition_degree.nonnegative_int_from_positive_int(); + + return perform_shard_expansion_one_to_many( + i, [=](ParallelTensorSpaceCoordinate const &p) { + FFOrdered input_shard = p.shard_components; + input_shard.at(rel_dim) = + p.shard_components.at(rel_dim) / degree; // ← / not % + return ParallelTensorSpaceCoordinate{ + p.sum_component, p.discard_copy_component, input_shard}; + }); +} + +// Repartition BWD — output_grad coord (high) → input_grad coord (low) +static std::unordered_set + perform_shard_expansion_for_repartition_bwd( + DynamicNodeInvocation const &i) { + RepartitionAttrs attrs = i.node_attrs.op_attrs.value() + .get() + .get(); + relative_ff_dim_t rel_dim = + relative_ff_dim_t_from_ff_dim_t(attrs.repartition_dim); + nonnegative_int degree = + attrs.repartition_degree.nonnegative_int_from_positive_int(); + + return perform_shard_expansion_many_to_one( + i, [=](ParallelTensorSpaceCoordinate const &p) { + FFOrdered input_shard = p.shard_components; + input_shard.at(rel_dim) 
=
+            p.shard_components.at(rel_dim) / degree; // integer-divide, not
+        // modulo: e.g. with degree 2, output_grad shard coords {0, 1} both
+        // collapse onto input_grad shard coord 0
+        return ParallelTensorSpaceCoordinate{
+            p.sum_component, p.discard_copy_component, input_shard};
+      });
+}
+
+// Combine FWD: input coord (finer sharding) -> output coord (coarser sharding)
+static std::unordered_set<DynamicNodeInvocation>
+    perform_shard_expansion_for_combine(DynamicNodeInvocation const &i) {
+  CombineAttrs attrs =
+      i.node_attrs.op_attrs.value().get<PCGOperatorAttrs>().get<CombineAttrs>();
+  relative_ff_dim_t rel_dim =
+      relative_ff_dim_t_from_ff_dim_t(attrs.combine_dim);
+  nonnegative_int degree =
+      attrs.combine_degree.nonnegative_int_from_positive_int();
+
+  return perform_shard_expansion_many_to_one(
+      i, [=](ParallelTensorSpaceCoordinate const &p) {
+        FFOrdered<nonnegative_int> output_shard = p.shard_components;
+        output_shard.at(rel_dim) =
+            p.shard_components.at(rel_dim) / degree; // fine -> coarse grid
+        return ParallelTensorSpaceCoordinate{
+            p.sum_component, p.discard_copy_component, output_shard};
+      });
+}
+
+// Combine BWD: input_grad coord (finer sharding) -> output_grad coord
+// (coarser sharding)
+static std::unordered_set<DynamicNodeInvocation>
+    perform_shard_expansion_for_combine_bwd(DynamicNodeInvocation const &i) {
+  CombineAttrs attrs =
+      i.node_attrs.op_attrs.value().get<PCGOperatorAttrs>().get<CombineAttrs>();
+  relative_ff_dim_t rel_dim =
+      relative_ff_dim_t_from_ff_dim_t(attrs.combine_dim);
+  nonnegative_int degree =
+      attrs.combine_degree.nonnegative_int_from_positive_int();
+
+  return perform_shard_expansion_one_to_many(
+      i, [=](ParallelTensorSpaceCoordinate const &p) {
+        FFOrdered<nonnegative_int> output_shard = p.shard_components;
+        output_shard.at(rel_dim) =
+            p.shard_components.at(rel_dim) / degree; // integer-divide, not modulo
+        return ParallelTensorSpaceCoordinate{
+            p.sum_component, p.discard_copy_component, output_shard};
+      });
+}
+
+// Reduction FWD: input coords (sum = 0..N-1) -> output coord (sum = 0)
+static std::unordered_set<DynamicNodeInvocation>
+    perform_shard_expansion_for_reduction(DynamicNodeInvocation const &i) {
+  return perform_shard_expansion_many_to_one(
+      i, [](ParallelTensorSpaceCoordinate const &p) {
+        return ParallelTensorSpaceCoordinate{
+            nonnegative_int{0}, // the reduced output always has sum = 0
+            p.discard_copy_component,
+            p.shard_components};
+      });
+}
+
+// Reduction BWD: the single output_grad coord (sum = 0) fans out to N
+// input_grad coords (sum = 0..N-1), so this is a one-to-many expansion; the
+// lambda maps each input_grad coord back to the one output_grad coord by
+// zeroing the sum component.
+static std::unordered_set<DynamicNodeInvocation>
+    perform_shard_expansion_for_reduction_bwd(DynamicNodeInvocation const &i) {
+  return perform_shard_expansion_one_to_many(
+      i, [](ParallelTensorSpaceCoordinate const &p) {
+        return ParallelTensorSpaceCoordinate{
+            nonnegative_int{0}, p.discard_copy_component, p.shard_components};
+      });
+}
+
 static std::unordered_set<DynamicNodeInvocation>
     perform_shard_expansion_for_copy(DynamicNodeInvocation const &i) {
   auto [input_slot, input] = get_only(i.inputs);
@@ -114,6 +451,47 @@ static std::unordered_set<DynamicNodeInvocation>
       });
 }
 
+static std::unordered_set<DynamicNodeInvocation>
+    perform_shard_expansion_for_parallel_op(DynamicNodeInvocation const &i) {
+  ASSERT(is_parallel_op_attrs(i.node_attrs));
+
+  PCGOperatorAttrs const pcg =
+      i.node_attrs.op_attrs.value().get<PCGOperatorAttrs>();
+
+  // forward dispatch
+  if (has_task_type(i.node_attrs, DynamicTaskType::FWD)) {
+    if (pcg.has<ReplicateAttrs>()) {
+      return perform_shard_expansion_for_replicate(i);
+    }
+    if (pcg.has<RepartitionAttrs>()) {
+      return perform_shard_expansion_for_repartition(i);
+    }
+    if (pcg.has<CombineAttrs>()) {
+      return perform_shard_expansion_for_combine(i);
+    }
+    if (pcg.has<ReductionAttrs>()) {
+      return perform_shard_expansion_for_reduction(i);
+    }
+  }
+
+  // backward dispatch
+  if (has_task_type(i.node_attrs, DynamicTaskType::BWD)) {
+    if (pcg.has<ReplicateAttrs>()) {
+      return perform_shard_expansion_for_replicate_bwd(i);
+    }
+    if (pcg.has<RepartitionAttrs>()) {
+      return perform_shard_expansion_for_repartition_bwd(i);
+    }
+    if (pcg.has<CombineAttrs>()) {
+      return perform_shard_expansion_for_combine_bwd(i);
+    }
+    if (pcg.has<ReductionAttrs>()) {
+      return perform_shard_expansion_for_reduction_bwd(i);
+    }
+  }
+
+  PANIC("unhandled parallel op task_type: {}", i.node_attrs.task_type);
+}
+
 std::unordered_set<DynamicNodeInvocation>
     perform_shard_expansion_for_invocation(DynamicNodeInvocation const &i) {
   if (i.node_attrs.op_attrs.has_value() &&
@@ -121,6 +499,10 @@ std::unordered_set<DynamicNodeInvocation>
     return perform_shard_expansion_for_copy(i);
   }
 
+  if (is_parallel_op_attrs(i.node_attrs)) {
+    return perform_shard_expansion_for_parallel_op(i);
+  }
+
   MappedOperatorTaskGroup mapping = assert_unwrap(i.node_attrs.mapping);
 
   std::unordered_set<MachineSpaceCoordinate> shard_machine_coords =
diff --git a/lib/task-spec/src/task-spec/ops/impl/element_binary.cc b/lib/task-spec/src/task-spec/ops/impl/element_binary.cc
index 13465d7a5f..c8460af538 100644
--- a/lib/task-spec/src/task-spec/ops/impl/element_binary.cc
+++ b/lib/task-spec/src/task-spec/ops/impl/element_binary.cc
@@ -36,8 +36,8 @@ static std::optional
     forward_task_impl(TaskArgumentAccessor const &acc) {
   ProfilingSettings profiling = acc.get_profiling_settings();
   DeviceType kernel_device_type = acc.get_kernel_device_type();
-  ElementBinaryPerDeviceState per_device_state =
-      acc.get_per_device_op_state().require_element_binary().value();
+  std::optional<ElementBinaryPerDeviceState> per_device_state =
+      acc.get_per_device_op_state().require_element_binary();
   ElementBinaryAttrs attrs = acc.get_op_attrs().require_element_binary();
   device_handle_t handle = acc.get_ff_handle();
@@ -62,8 +62,8 @@ static std::optional
     backward_task_impl(TaskArgumentAccessor const &acc) {
   ProfilingSettings profiling = acc.get_profiling_settings();
   DeviceType kernel_device_type = acc.get_kernel_device_type();
-  ElementBinaryPerDeviceState per_device_state =
-      acc.get_per_device_op_state().require_element_binary().value();
+  std::optional<ElementBinaryPerDeviceState> per_device_state =
+      acc.get_per_device_op_state().require_element_binary();
   ElementBinaryAttrs attrs = acc.get_op_attrs().require_element_binary();
   device_handle_t handle = acc.get_ff_handle();
diff --git a/lib/task-spec/src/task-spec/ops/impl/element_unary.cc b/lib/task-spec/src/task-spec/ops/impl/element_unary.cc
index d66ff9ab8d..9a092b90b8 100644
--- a/lib/task-spec/src/task-spec/ops/impl/element_unary.cc
+++ b/lib/task-spec/src/task-spec/ops/impl/element_unary.cc
@@ -35,8 +35,8 @@ static std::optional
   ProfilingSettings profiling = acc.get_profiling_settings();
   DeviceType kernel_device_type = acc.get_kernel_device_type();
-  ElementUnaryPerDeviceState per_device_state =
-      acc.get_per_device_op_state().require_element_unary().value();
+  std::optional<ElementUnaryPerDeviceState> per_device_state =
+      acc.get_per_device_op_state().require_element_unary();
 
   return profile(forward_kernel,
                  profiling,
@@ -62,8 +62,8 @@
   ProfilingSettings profiling = acc.get_profiling_settings();
   DeviceType kernel_device_type = acc.get_kernel_device_type();
-  ElementUnaryPerDeviceState per_device_state =
-      acc.get_per_device_op_state().require_element_unary().value();
+  std::optional<ElementUnaryPerDeviceState> per_device_state =
+      acc.get_per_device_op_state().require_element_unary();
 
   return profile(backward_kernel,
                  profiling,