diff --git a/runtime/core/portable_type/tensor.h b/runtime/core/portable_type/tensor.h index 775bccc1b52..1a8644928b5 100644 --- a/runtime/core/portable_type/tensor.h +++ b/runtime/core/portable_type/tensor.h @@ -8,6 +8,7 @@ #pragma once +#include #include #include @@ -118,33 +119,57 @@ class Tensor { /// Returns a pointer of type T to the constant underlying data blob. template inline const T* const_data_ptr() const { + ET_CHECK_MSG( + numel() == 0 || impl_->data() != nullptr, + "Tensor has non-zero numel (%zd) but null data pointer", + (ssize_t)numel()); return impl_->data(); } /// Returns a pointer to the constant underlying data blob. inline const void* const_data_ptr() const { + ET_CHECK_MSG( + numel() == 0 || impl_->data() != nullptr, + "Tensor has non-zero numel (%zd) but null data pointer", + (ssize_t)numel()); return impl_->data(); } /// Returns a pointer of type T to the mutable underlying data blob. template inline T* mutable_data_ptr() const { + ET_CHECK_MSG( + numel() == 0 || impl_->data() != nullptr, + "Tensor has non-zero numel (%zd) but null data pointer", + (ssize_t)numel()); return impl_->mutable_data(); } /// Returns a pointer to the mutable underlying data blob. inline void* mutable_data_ptr() const { + ET_CHECK_MSG( + numel() == 0 || impl_->data() != nullptr, + "Tensor has non-zero numel (%zd) but null data pointer", + (ssize_t)numel()); return impl_->mutable_data(); } /// DEPRECATED: Use const_data_ptr or mutable_data_ptr instead. template ET_DEPRECATED inline T* data_ptr() const { + ET_CHECK_MSG( + numel() == 0 || impl_->data() != nullptr, + "Tensor has non-zero numel (%zd) but null data pointer", + (ssize_t)numel()); return impl_->mutable_data(); } /// DEPRECATED: Use const_data_ptr or mutable_data_ptr instead. ET_DEPRECATED inline void* data_ptr() const { + ET_CHECK_MSG( + numel() == 0 || impl_->data() != nullptr, + "Tensor has non-zero numel (%zd) but null data pointer", + (ssize_t)numel()); return impl_->mutable_data(); }