diff --git a/runtime/executor/program.cpp b/runtime/executor/program.cpp index 82991011127..6b48cf8aeb2 100644 --- a/runtime/executor/program.cpp +++ b/runtime/executor/program.cpp @@ -11,6 +11,7 @@ #include #include +#include #include #include #include @@ -560,8 +561,18 @@ Result Program::LoadSegment( // Could fail if offset and size are out of bound for the data, or if this // is reading from a file and fails, or for many other reasons depending on // the implementation of the loader. + uint64_t seg_offset = segment->offset(); + uint64_t absolute_offset = 0; + ET_CHECK_OR_RETURN_ERROR( + !c10::add_overflows( + segment_base_offset_, seg_offset, &absolute_offset) && + absolute_offset <= SIZE_MAX, + InvalidProgram, + "segment_base_offset %zu + segment offset %" PRIu64 " overflows", + segment_base_offset_, + seg_offset); return loader_->load( - segment_base_offset_ + segment->offset(), segment->size(), segment_info); + static_cast(absolute_offset), segment->size(), segment_info); } Error Program::load_mutable_subsegment_into( @@ -628,8 +639,15 @@ Error Program::load_mutable_subsegment_into( auto segment = internal_program_->segments()->Get(segment_offsets->segment_index()); - // Check size - if (offset + size > segment->size()) { + // Check size (with overflow protection) + size_t end_offset = 0; + ET_CHECK_OR_RETURN_ERROR( + !c10::add_overflows(offset, size, &end_offset), + InvalidProgram, + "offset %zu + size %zu overflows", + offset, + size); + if (end_offset > segment->size()) { ET_LOG( Error, "offset %zu + size %zu out of range > %" PRIu64, @@ -644,9 +662,26 @@ Error Program::load_mutable_subsegment_into( segment_offsets->segment_index(), nullptr); - // Load the data - return loader_->load_into( - segment_base_offset_ + segment->offset() + offset, size, info, buffer); + // Load the data (with overflow protection on the addition chain) + uint64_t seg_offset = segment->offset(); + uint64_t base_plus_seg_64 = 0; + ET_CHECK_OR_RETURN_ERROR( + !c10::add_overflows( + segment_base_offset_, seg_offset, &base_plus_seg_64) && + base_plus_seg_64 <= SIZE_MAX, + InvalidProgram, + "segment_base_offset %zu + segment offset %" PRIu64 " overflows", + segment_base_offset_, + seg_offset); + size_t base_plus_seg = static_cast(base_plus_seg_64); + size_t total_offset = 0; + ET_CHECK_OR_RETURN_ERROR( + !c10::add_overflows(base_plus_seg, offset, &total_offset), + InvalidProgram, + "segment base+offset %zu + subsegment offset %zu overflows", + base_plus_seg, + offset); + return loader_->load_into(total_offset, size, info, buffer); } } // namespace ET_RUNTIME_NAMESPACE