Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 96 additions & 13 deletions src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,33 @@ class kernel_and_merge {
Tab *acc_buff);
};

// Dispatch helper: call strat.kernel() using the extended argument list that
// carries row sums.
//
// This overload only takes part in overload resolution when the expression
// strat.kernel(..., acc_buff, row_sum, kern_k) is well-formed; otherwise the
// decltype in the trailing return type fails substitution and the overload is
// silently discarded (SFINAE). The return type itself is always void.
//
// The final unnamed `int` parameter is a dispatch tag: when the caller passes a
// literal 0 and both overloads are viable, the exact match on `int` is ranked
// ahead of the sibling overload whose tag parameter is `long`.
template<typename strategy, typename Tlo, typename Tro, typename Tr, typename Tab>
auto run_dequantized_integrated_kernel(
strategy &strat, const Tlo *a_ptr, const Tro *b_panel, Tr *c_ptr, int ldc,
unsigned int m_size, unsigned int n_size, int kern_k, const int32_t *offset_col_bias,
const DequantizeFloat &dq, const Tr *offset_bias, const Activation &act, bool accumulate,
Tab *acc_buff, const int32_t *row_sum, int)
// Detection expression: must stay in step with the call in the body below.
-> decltype(strat.kernel(a_ptr, b_panel, c_ptr, ldc, m_size, n_size, kern_k,
offset_col_bias, dq, offset_bias, act, accumulate, acc_buff,
row_sum, kern_k),
void())
{
// kern_k is forwarded twice: once as the K size and once more as the last
// argument. NOTE(review): the last argument is presumably the total K depth
// used for the offset correction - confirm against the kernel signature.
strat.kernel(a_ptr, b_panel, c_ptr, ldc, m_size, n_size, kern_k,
offset_col_bias, dq, offset_bias, act, accumulate, acc_buff,
row_sum, kern_k);
}

// Dispatch helper, fallback form: call strat.kernel() without row sums.
//
// The row-sum pointer argument is accepted but left unnamed and ignored, and
// the kernel is invoked with the shorter argument list ending at acc_buff.
//
// The final unnamed `long` parameter is a dispatch tag: a literal 0 passed by
// the caller needs an int -> long conversion here, so this overload is chosen
// only when the sibling overload taking an `int` tag is not viable. This
// overload is not SFINAE-constrained, so strat.kernel() must accept the shorter
// argument list whenever this overload is the one instantiated.
template<typename strategy, typename Tlo, typename Tro, typename Tr, typename Tab>
void run_dequantized_integrated_kernel(
strategy &strat, const Tlo *a_ptr, const Tro *b_panel, Tr *c_ptr, int ldc,
unsigned int m_size, unsigned int n_size, int kern_k, const int32_t *offset_col_bias,
const DequantizeFloat &dq, const Tr *offset_bias, const Activation &act, bool accumulate,
Tab *acc_buff, const int32_t *, long)
{
strat.kernel(a_ptr, b_panel, c_ptr, ldc, m_size, n_size, kern_k,
offset_col_bias, dq, offset_bias, act, accumulate, acc_buff);
}

// Run a kernel and call the separate merge step
template<>
template<typename strategy, typename Tlo, typename Tro, typename Tr, typename Tri, typename Tab>
Expand Down Expand Up @@ -281,14 +308,19 @@ void kernel_and_merge<false, false, DequantizeFloat>::run(
offset_bias = bias + n_0;
}

strat.kernel(// A and B pointers are just the packed panels.
a_ptr, b_panel,
// Provide relevant part of output array and row stride.
c_ptr ? (c_ptr + m_0 * ldc + n_0) : nullptr, ldc,
// M, N, K sizes
m_max-m_0, n_max - n_0, kern_k,
// Bias, activation, accumulation. Need to offset the bias as needed.
offset_col_bias, dq, offset_bias, act, accumulate, acc_buff);
// When b_offset != 0, row sums of A are packed at the end of the A panel
// (appended by the quantized PrepareA transform with multiplier=1). Read them
// to pass to dequantize_block_32 for per-row offset correction.
const int32_t *row_sum = nullptr;
if (dq.b_offset != 0) {
row_sum = reinterpret_cast<const int32_t *>(a_ptr + strategy::out_height() * kern_k);
}

run_dequantized_integrated_kernel(
strat, a_ptr, b_panel,
c_ptr ? (c_ptr + m_0 * ldc + n_0) : nullptr, ldc,
m_max - m_0, n_max - n_0, kern_k, offset_col_bias, dq, offset_bias, act,
accumulate, acc_buff, row_sum, 0);
}

template<>
Expand All @@ -300,7 +332,7 @@ void kernel_and_merge<true, false, DequantizeFloat>::run(
strategy &strat, const Tlo *a_ptr, const Tro *b_panel, size_t, Tri *c_panel,
Tr *c_ptr, int ldc, int kern_k, unsigned int m_0,
unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *bias,
const Activation &act, bool not_first_pass, const DequantizeFloat &qp, const int32_t *,
const Activation &act, bool not_first_pass, const DequantizeFloat &qp, const int32_t *col_bias,
Tab *)
{
const int bblocks = iceildiv(n_max - n_0, strategy::out_width());
Expand All @@ -317,14 +349,20 @@ void kernel_and_merge<true, false, DequantizeFloat>::run(
#ifdef CYCLE_PROFILING
auto p=prof.ScopedProfiler(PROFILE_QUANTIZE, ((m_max-m_0) * bblocks * strategy::out_width() * sizeof(Tr)));
#endif
// When b_offset != 0, row sums are packed after the A panel data
const int32_t *row_sum = (qp.b_offset != 0)
? reinterpret_cast<const int32_t *>(a_ptr + strategy::out_height() * kern_k)
: nullptr;

for (int i=0; i<bblocks; i++) {
unsigned int n_start = n_0 + (strategy::out_width() * i);
unsigned int n_end = std::min(n_start + strategy::out_width(), n_max);

dequantize_block_32(qp, (n_end - n_start), (m_max - m_0),
c_panel + (i * strategy::out_width() * strategy::out_height()), strategy::out_width(),
c_ptr + m_0 * ldc + n_start, ldc,
bias != nullptr ? bias + n_start : nullptr, not_first_pass, act);
bias != nullptr ? bias + n_start : nullptr, not_first_pass, act,
col_bias != nullptr ? col_bias + n_start : nullptr, row_sum, kern_k);

}
}
Expand Down Expand Up @@ -475,6 +513,13 @@ class GemmInterleaved : public GemmCommon<Tlo, Tro, Tr> {
return _Nsize * _nmulti * sizeof(int32_t);
}

if (std::is_same<OutputStage, DequantizeFloat>::value) {
const DequantizeFloat *dq = reinterpret_cast<const DequantizeFloat *>(&_os);
if (dq->a_offset != 0) {
return _Nsize * _nmulti * sizeof(int32_t);
}
}

return 0;
}

Expand Down Expand Up @@ -557,6 +602,12 @@ class GemmInterleaved : public GemmCommon<Tlo, Tro, Tr> {
k_depth += sizeof(int32_t) / sizeof(Tloi);
}

if (std::is_same<OutputStage, DequantizeFloat>::value && MergeStep) {
// transforms_quantized always packs row sum slots (zeros when multiplier=0, actual
// sums when b_offset != 0). Reserve space unconditionally when MergeStep is enabled.
k_depth += sizeof(int32_t) / sizeof(Tloi);
}

return k_depth;
}

Expand Down Expand Up @@ -647,6 +698,13 @@ class GemmInterleaved : public GemmCommon<Tlo, Tro, Tr> {
return -qp->b_offset;
}

if (std::is_same<OutputStage, DequantizeFloat>::value) {
const DequantizeFloat *dq = reinterpret_cast<const DequantizeFloat *>(&_os);
// Pack row sums into the A panel when b_offset is non-zero so that the
// merge step can apply the b_offset correction per output position.
return (dq->b_offset != 0) ? 1 : 0;
}

return 0;
}

Expand Down Expand Up @@ -693,6 +751,14 @@ class GemmInterleaved : public GemmCommon<Tlo, Tro, Tr> {
return get_ktotal(args);
}

// K blocking is not supported for DequantizeFloat with MergeStep when b_offset != 0,
// because row sums of A must cover the full K depth. We cannot check b_offset here
// (static function), so we conservatively disable K-blocking for all DequantizeFloat
// MergeStep cases. The working-memory cost is minimal and correctness is guaranteed.
if (std::is_same<OutputStage, DequantizeFloat>::value && MergeStep) {
return get_ktotal(args);
}

// We can't K block non-fast FP16 cases without an accumulation buffer.
#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(ARM_COMPUTE_ENABLE_FP16))
if (std::is_same<Tlo, __fp16>::value && std::is_same<Tr, __fp16>::value && !args._fast_mode && MergeStep) {
Expand Down Expand Up @@ -937,7 +1003,7 @@ class GemmInterleaved : public GemmCommon<Tlo, Tro, Tr> {
#endif
// See comment above on transform_type<> class: this extracts either 'transforms' or
// 'transforms_quantized' as appropriate.
typename transform_type<strategy, MergeStep && std::is_same<OutputStage, Requantize32>::value>::type transforms;
typename transform_type<strategy, MergeStep && (std::is_same<OutputStage, Requantize32>::value || std::is_same<OutputStage, DequantizeFloat>::value)>::type transforms;

if (_indirect_buf != nullptr) {
transforms.PrepareA_indirect(a_panel,
Expand Down Expand Up @@ -1027,7 +1093,7 @@ class GemmInterleaved : public GemmCommon<Tlo, Tro, Tr> {
#endif
// See comment above on transform_type<> class: this extracts either 'transforms' or
// 'transforms_quantized' as appropriate.
typename transform_type<strategy, MergeStep && std::is_same<OutputStage, Requantize32>::value>::type transforms;
typename transform_type<strategy, MergeStep && (std::is_same<OutputStage, Requantize32>::value || std::is_same<OutputStage, DequantizeFloat>::value)>::type transforms;

for (unsigned int batch = batch_0; batch <= batch_end; batch++) {
unsigned int first_m = (batch == batch_0) ? m_0 : 0;
Expand Down Expand Up @@ -1060,6 +1126,10 @@ class GemmInterleaved : public GemmCommon<Tlo, Tro, Tr> {

if(std::is_same<OutputStage, Requantize32>::value) {
a_panel_stride = kern_k + (sizeof(int32_t) / sizeof(Tloi));
} else if (std::is_same<OutputStage, DequantizeFloat>::value && MergeStep) {
// transforms_quantized always packs row-sum slots (zeros when b_offset=0,
// actual sums when b_offset != 0), so the stride must include the slot.
a_panel_stride = kern_k + (sizeof(int32_t) / sizeof(Tloi));
} else {
a_panel_stride = kern_k;
}
Expand Down Expand Up @@ -1212,6 +1282,20 @@ class GemmInterleaved : public GemmCommon<Tlo, Tro, Tr> {
compute_col_sums(*qp_ptr, _Nsize, _Ksize * _Ksections, B + (i * B_multi_stride), ldb, col_bias + (i * _Nsize), _Ksize * _Ksections, i, 0);
}
}

if (std::is_same<OutputStage, DequantizeFloat>::value) {
const DequantizeFloat *dq = reinterpret_cast<const DequantizeFloat *>(&_os);
if (dq->a_offset != 0) {
// Compute raw column sums of B (weight matrix) for use in a_offset correction.
// dequantize_block_32 applies: -a_offset * col_sums[n] * scale per output channel.
col_bias = reinterpret_cast<int32_t *>(in_buffer);
for (unsigned int i = 0; i < _nmulti; ++i) {
compute_raw_col_sums(_Nsize, _Ksize * _Ksections,
B + (i * B_multi_stride), ldb,
col_bias + (i * _Nsize));
}
}
}
}

// Support for transposed B is a property of the strategy::transpose type
Expand Down Expand Up @@ -1431,4 +1515,3 @@ template<typename strategy, typename Tlo, typename Tro, typename Tr>
using GemmInterleavedDequantized = GemmInterleaved<strategy, Tlo, Tro, Tr, DequantizeFloat>;

} // namespace arm_gemm

42 changes: 35 additions & 7 deletions src/core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,48 +49,77 @@

namespace arm_gemm {

#if defined(ARM_COMPUTE_ENABLE_SME2) || defined(ARM_COMPUTE_ENABLE_SME)
namespace {

// Predicate used by the no-merge dequantizing kernel selectors: true only when
// the GEMM is not accumulating into the output and both quantization offsets
// (a_offset and b_offset) are zero.
bool supports_symmetric_dequant_no_merge(const GemmArgs &args, const DequantizeFloat &dq)
{
    // Accumulation requested: not supported on this path.
    if (args._accumulate) {
        return false;
    }

    // Any non-zero offset rules this path out as well.
    if (dq.a_offset != 0) {
        return false;
    }

    return dq.b_offset == 0;
}

} // namespace
#endif // defined(ARM_COMPUTE_ENABLE_SME2) || defined(ARM_COMPUTE_ENABLE_SME)

static const GemmImplementation<int8_t, int8_t, float, DequantizeFloat> gemm_s8fp32_methods[] =
{
#ifdef ARM_COMPUTE_ENABLE_SME2
{
"sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL",
[](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sme2() && args._ci->has_sme_i8i32() && !args._accumulate; },
[](const GemmArgs &args, const DequantizeFloat &dq) {
return args._ci->has_sme2() && args._ci->has_sme_i8i32() &&
supports_symmetric_dequant_no_merge(args, dq);
},
[](const GemmArgs &args, const DequantizeFloat &) { const auto VL = sme::get_vector_length<float>();
return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
[](const GemmArgs &args, const DequantizeFloat &dq) { return new GemmInterleavedNoMergeDequantized<cls_sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL, int8_t, int8_t, float>(args, dq); }
},
{
"sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL",
[](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sme2() && args._ci->has_sme_i8i32() && !args._accumulate; },
[](const GemmArgs &args, const DequantizeFloat &dq) {
return args._ci->has_sme2() && args._ci->has_sme_i8i32() &&
supports_symmetric_dequant_no_merge(args, dq);
},
[](const GemmArgs &args, const DequantizeFloat &) { const auto VL = sme::get_vector_length<float>();
return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); },
[](const GemmArgs &args, const DequantizeFloat &dq) { return new GemmInterleavedNoMergeDequantized<cls_sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL, int8_t, int8_t, float>(args, dq); }
},
{
"sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL",
[](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sme2() && args._ci->has_sme_i8i32() && !args._accumulate; },
[](const GemmArgs &args, const DequantizeFloat &dq) {
return args._ci->has_sme2() && args._ci->has_sme_i8i32() &&
supports_symmetric_dequant_no_merge(args, dq);
},
nullptr,
[](const GemmArgs &args, const DequantizeFloat &dq) { return new GemmInterleavedNoMergeDequantized<cls_sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL, int8_t, int8_t, float>(args, dq); }
},
#endif // ARM_COMPUTE_ENABLE_SME2
#ifdef ARM_COMPUTE_ENABLE_SME
{
"sme_interleaved_nomerge_s8qfp32_mopa_1VLx4VL",
[](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sme() && args._ci->has_sme_i8i32() && !args._accumulate; },
[](const GemmArgs &args, const DequantizeFloat &dq) {
return args._ci->has_sme() && args._ci->has_sme_i8i32() &&
supports_symmetric_dequant_no_merge(args, dq);
},
[](const GemmArgs &args, const DequantizeFloat &) { const auto VL = sme::get_vector_length<float>();
return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
[](const GemmArgs &args, const DequantizeFloat &dq) { return new GemmInterleavedNoMergeDequantized<cls_sme_interleaved_nomerge_s8qfp32_mopa_1VLx4VL, int8_t, int8_t, float>(args, dq); }
},
{
"sme_interleaved_nomerge_s8qfp32_mopa_4VLx1VL",
[](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sme() && args._ci->has_sme_i8i32() && !args._accumulate; },
[](const GemmArgs &args, const DequantizeFloat &dq) {
return args._ci->has_sme() && args._ci->has_sme_i8i32() &&
supports_symmetric_dequant_no_merge(args, dq);
},
[](const GemmArgs &args, const DequantizeFloat &) { const auto VL = sme::get_vector_length<float>();
return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); },
[](const GemmArgs &args, const DequantizeFloat &dq) { return new GemmInterleavedNoMergeDequantized<cls_sme_interleaved_nomerge_s8qfp32_mopa_4VLx1VL, int8_t, int8_t, float>(args, dq); }
},
{
"sme_interleaved_nomerge_s8qfp32_mopa_2VLx2VL",
[](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sme() && args._ci->has_sme_i8i32() && !args._accumulate; },
[](const GemmArgs &args, const DequantizeFloat &dq) {
return args._ci->has_sme() && args._ci->has_sme_i8i32() &&
supports_symmetric_dequant_no_merge(args, dq);
},
nullptr,
[](const GemmArgs &args, const DequantizeFloat &dq) { return new GemmInterleavedNoMergeDequantized<cls_sme_interleaved_nomerge_s8qfp32_mopa_2VLx2VL, int8_t, int8_t, float>(args, dq); }
},
Expand Down Expand Up @@ -153,4 +182,3 @@ template std::vector<KernelDescription> get_compatible_kernels<int8_t, int8_t, f
} // namespace arm_gemm

#endif // __aarch64__

3 changes: 2 additions & 1 deletion src/core/NEON/kernels/arm_gemm/quantized-fp16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ namespace arm_gemm {
template<>
void dequantize_block_32<__fp16>(const DequantizeFloat &qp, unsigned int width, unsigned int height,
const int32_t * in_ptr, unsigned int in_stride, __fp16 *out_ptr, unsigned int out_stride,
const __fp16 * bias_ptr, bool not_first_pass, const Activation &act)
const __fp16 * bias_ptr, bool not_first_pass, const Activation &act,
const int32_t * /*col_bias*/, const int32_t * /*row_sum*/, int32_t /*k_total*/)
{
const float32x4_t vscale = vdupq_n_f32(qp.scale);
float maxval = std::numeric_limits<float>::infinity();
Expand Down
Loading