Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 159 additions & 1 deletion src/native/cambricon/common.h
Comment thread
bitzyz marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace infini::ops::reduce {

// Number of `float` elements that fit in 128 bytes. The reduction helpers in
// this namespace divide their element counts by this value.
constexpr int batch_size = 128 / sizeof(float);

__mlu_func__ void SumInternal(float* dst, float* src, int max_batch) {
__mlu_func__ void SumInternal(float* src, float* dst, int max_batch) {
const int width = max_batch / batch_size;

if (width >= 4) {
Expand All @@ -30,6 +30,164 @@ __mlu_func__ void SumInternal(float* dst, float* src, int max_batch) {
}
}

// Reduces `len` elements held in `data` through `SumInternal`, which writes
// its output to `result`.
//
// For 16-bit element types (`__half`, `__bang_bfloat16`) the raw values are
// read from `data + len` and widened to `float` starting at the beginning of
// `data` before the reduction runs. Any other `T` is handed to `SumInternal`
// unchanged.
template <typename T>
__mlu_func__ void SumTyped(T* data, float* result, size_t len) {
  constexpr bool is_half = std::is_same_v<T, __half>;
  constexpr bool is_bf16 = std::is_same_v<T, __bang_bfloat16>;

  if constexpr (is_half || is_bf16) {
    float* widened = reinterpret_cast<float*>(data);
    if constexpr (is_half) {
      __bang_half2float(widened, reinterpret_cast<half*>(data) + len, len);
    } else {
      __bang_bfloat162float(widened, data + len, len);
    }
    SumInternal(widened, result, len);
  } else {
    SumInternal(data, result, len);
  }
}

// Returns the sum of the `num_elements` values at `source`.
//
// The input is copied (GDRAM2NRAM) into the scratch buffer `src` in chunks of
// at most `max_batch` elements; 16-bit types are staged `max_batch` elements
// into `src`, other types at its start. Every chunk is reduced over the full
// `max_batch` length by `SumTyped`, and `dst[0]` is accumulated.
template <typename T>
__mlu_func__ float Sum(const T* source, T* src, float* dst, int num_elements,
                       int max_batch) {
  const int offset = (sizeof(T) == 2) ? max_batch : 0;
  float total = 0.0f;

  for (size_t done = 0; done < num_elements;) {
    const size_t chunk = std::min<size_t>(max_batch, num_elements - done);

    // A short chunk leaves part of the buffer unwritten by the copy below;
    // zero it first so those slots add nothing to the sum.
    const bool is_partial = chunk < max_batch;
    if (is_partial) {
      __bang_write_value(src, max_batch + offset, 0);
    }

    __memcpy(src + offset, source + done, chunk * sizeof(T), GDRAM2NRAM);
    SumTyped(src, dst, max_batch);

    total += dst[0];
    done += chunk;
  }

  return total;
}

// Returns the sum of the `num_elements` values at `source`.
//
// Inputs shorter than 32 elements are forwarded to `Sum`. Otherwise the input
// is processed in chunks of at most `max_batch` elements: each chunk is copied
// (GDRAM2NRAM) to `src + offset`, widened to `float` when `T` is a 16-bit
// type, and then summed in two parts: the leading multiple of `batch_size`
// through `SumInternal`, the leftover elements one by one.
//
// NOTE(review): for 16-bit `T` the widened floats are written starting at the
// same address the raw values were copied to (`src + offset`), whereas
// `SumTyped` widens to the start of the buffer. Confirm the conversion
// intrinsics permit this overlap and that the buffer is large enough.
template <typename T>
__mlu_func__ float SumBatched(const T* source, T* src, float* dst,
                              int num_elements, int max_batch) {
  constexpr int min_vector_size = 32;
  if (num_elements < min_vector_size) {
    return Sum(source, src, dst, num_elements, max_batch);
  }

  const int offset = (sizeof(T) == 2) ? max_batch : 0;
  T* const staged = src + offset;
  float* const values = (float*)staged;
  float total = 0.0f;

  for (size_t done = 0; done < num_elements;) {
    const size_t chunk = std::min<size_t>(max_batch, num_elements - done);
    const size_t tail = chunk % batch_size;
    const size_t vector_part = (chunk / batch_size) * batch_size;

    // Clear the scratch buffer, then bring the chunk in.
    __bang_write_value(src, max_batch + offset, 0);
    __memcpy(staged, source + done, chunk * sizeof(T), GDRAM2NRAM);

    if constexpr (std::is_same_v<T, __half>) {
      __bang_half2float(values, reinterpret_cast<half*>(src) + offset, chunk);
    } else if constexpr (std::is_same_v<T, __bang_bfloat16>) {
      __bang_bfloat162float(values, staged, chunk);
    }

    if (vector_part > 0) {
      SumInternal(values, dst, vector_part);
      total += dst[0];
    }

    // Leftover elements past the last full group of `batch_size`.
    for (size_t i = chunk - tail; i < chunk; ++i) {
      total += values[i];
    }

    done += chunk;
  }

  return total;
}

// Max-reduces `max_batch` floats at `src` into `dst`.
//
// `__bang_maxpool` is applied with `batch_size` as its first size argument and
// `max_batch / batch_size` as both the width and the kernel width, then
// `__bang_argmax` runs in place over `batch_size` elements of `dst`. Callers
// in this file read the resulting maximum from `dst[0]`.
__mlu_func__ void MaxInternal(float* src, float* dst, int max_batch) {
  const int width = max_batch / batch_size;
  __bang_maxpool(dst, src, batch_size, 1, width, 1, width, 1, 1);
  __bang_argmax(dst, dst, batch_size);
}

// Max-reduces `len` elements held in `data` through `MaxInternal`, which
// writes its output to `result`.
//
// For 16-bit element types (`__half`, `__bang_bfloat16`) the raw values are
// read from `data + len` and widened to `float` starting at the beginning of
// `data` before the reduction runs. Any other `T` is handed to `MaxInternal`
// unchanged.
template <typename T>
__mlu_func__ void MaxTyped(T* data, float* result, size_t len) {
  constexpr bool is_half = std::is_same_v<T, __half>;
  constexpr bool is_bf16 = std::is_same_v<T, __bang_bfloat16>;

  if constexpr (is_half || is_bf16) {
    float* widened = reinterpret_cast<float*>(data);
    if constexpr (is_half) {
      __bang_half2float(widened, reinterpret_cast<half*>(data) + len, len);
    } else {
      __bang_bfloat162float(widened, data + len, len);
    }
    MaxInternal(widened, result, len);
  } else {
    MaxInternal(data, result, len);
  }
}

// Returns the maximum of the `num_elements` values at `source`, or -INFINITY
// when there are none.
//
// The input is copied (GDRAM2NRAM) into the scratch buffer `src` in batches of
// at most `max_batch` elements. 16-bit types are staged `max_batch` elements
// into `src` and widened to `float` from the start of `src`; `float` input is
// used where it was copied. `dst` receives the output of `MaxInternal`.
//
// Only the elements actually loaded take part in the reduction. An earlier
// version zero-filled a short batch and reduced the full `max_batch` length;
// review pointed out that the padding zeros then win whenever every real
// input is negative, lifting the result to 0 (and, for softmax, possibly
// underflowing every exp(x - max)). Unlike for a sum, 0 is not a neutral
// element for max, so no padding is used here: the leading multiple of
// `batch_size` goes through `MaxInternal` and the leftover elements are
// compared one by one.
template <typename T>
__mlu_func__ float Max(const T* source, T* src, float* dst, int num_elements,
                       int max_batch) {
  static_assert(std::is_same_v<T, __half> ||
                    std::is_same_v<T, __bang_bfloat16> ||
                    std::is_same_v<T, float>,
                "Max supports __half, __bang_bfloat16 and float only");

  float max_val = -INFINITY;
  const int offset = (sizeof(T) == 2 ? max_batch : 0);
  float* const values = reinterpret_cast<float*>(src);

  size_t processed = 0;
  while (processed < num_elements) {
    const size_t curr_batch =
        std::min<size_t>(max_batch, num_elements - processed);
    const size_t aligned_batch = (curr_batch / batch_size) * batch_size;

    __memcpy(src + offset, source + processed, curr_batch * sizeof(T),
             GDRAM2NRAM);

    // Same widening call `MaxTyped` issues (full `max_batch` length). Slots
    // past `curr_batch` may hold stale data, but they are never read below.
    if constexpr (std::is_same_v<T, __half>) {
      __bang_half2float(values, reinterpret_cast<half*>(src) + offset,
                        max_batch);
    } else if constexpr (std::is_same_v<T, __bang_bfloat16>) {
      __bang_bfloat162float(values, src + offset, max_batch);
    }

    if (aligned_batch > 0) {
      MaxInternal(values, dst, static_cast<int>(aligned_batch));
      max_val = std::max(max_val, dst[0]);
    }
    // Fewer than `batch_size` elements remain past the aligned prefix.
    for (size_t i = aligned_batch; i < curr_batch; ++i) {
      max_val = std::max(max_val, values[i]);
    }

    processed += curr_batch;
  }

  return max_val;
}

// Returns the maximum of the `num_elements` values at `source`, or -INFINITY
// when there are none.
//
// The input is copied (GDRAM2NRAM) into the scratch buffer `src` in batches of
// at most `max_batch` elements. 16-bit types are staged `max_batch` elements
// into `src` and widened to `float` from the start of `src`; `float` input is
// used where it was copied. `dst` receives the output of `MaxInternal`.
//
// A short batch is NOT padded: 0 is not a neutral element for max, so
// zero-filling the unused slots and reducing the full `max_batch` length
// returns 0 whenever every real input is negative. Instead, only the loaded
// elements are reduced: the leading multiple of `batch_size` goes through
// `MaxInternal` and the leftover elements are compared one by one, the same
// split `SumBatched` uses.
//
// Inputs shorter than 32 elements used to be forwarded to `Max`, which ran
// the same loop; they are handled by the loop below so that this function's
// result does not depend on how `Max` pads.
template <typename T>
__mlu_func__ float MaxBatched(const T* source, T* src, float* dst,
                              int num_elements, int max_batch) {
  static_assert(std::is_same_v<T, __half> ||
                    std::is_same_v<T, __bang_bfloat16> ||
                    std::is_same_v<T, float>,
                "MaxBatched supports __half, __bang_bfloat16 and float only");

  float max_val = -INFINITY;
  const int offset = (sizeof(T) == 2 ? max_batch : 0);
  float* const values = reinterpret_cast<float*>(src);

  size_t processed = 0;
  while (processed < num_elements) {
    const size_t curr_batch =
        std::min<size_t>(max_batch, num_elements - processed);
    const size_t aligned_batch = (curr_batch / batch_size) * batch_size;

    __memcpy(src + offset, source + processed, curr_batch * sizeof(T),
             GDRAM2NRAM);

    // Same widening call `MaxTyped` issues (full `max_batch` length). Slots
    // past `curr_batch` may hold stale data, but they are never read below.
    if constexpr (std::is_same_v<T, __half>) {
      __bang_half2float(values, reinterpret_cast<half*>(src) + offset,
                        max_batch);
    } else if constexpr (std::is_same_v<T, __bang_bfloat16>) {
      __bang_bfloat162float(values, src + offset, max_batch);
    }

    if (aligned_batch > 0) {
      MaxInternal(values, dst, static_cast<int>(aligned_batch));
      max_val = std::max(max_val, dst[0]);
    }
    // Fewer than `batch_size` elements remain past the aligned prefix.
    for (size_t i = aligned_batch; i < curr_batch; ++i) {
      max_val = std::max(max_val, values[i]);
    }

    processed += curr_batch;
  }

  return max_val;
}

} // namespace infini::ops::reduce

#endif // __BANG__
Expand Down
63 changes: 63 additions & 0 deletions src/native/cambricon/ops/causal_softmax/causal_softmax.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#ifndef INFINI_OPS_CAMBRICON_CAUSAL_SOFTMAX_H
#define INFINI_OPS_CAMBRICON_CAUSAL_SOFTMAX_H

#include "base/causal_softmax.h"
#include "native/cambricon/common.h"
#include "native/cambricon/data_type_.h"

namespace infini::ops {

// TODO: Remove forward declaration.
//
// Forward declaration of the launcher that
// `Operator<CausalSoftmax, Device::Type::kCambricon>::operator()` calls with
// the input element type as `T`. No definition appears in this header.
// NOTE(review): the `_b` / `_i` / `_j` stride suffixes presumably stand for
// batch / row / column; confirm against the definition.
template <typename T>
void CausalSoftmaxUnion(void* workspace, int core_per_cluster,
int cluster_count, cnrtQueue_t queue, const void* x,
void* y, size_t batch_size_, size_t seq_len_,
size_t total_seq_len_, ptrdiff_t y_stride_b,
ptrdiff_t y_stride_i, ptrdiff_t y_stride_j,
ptrdiff_t x_stride_b, ptrdiff_t x_stride_i,
ptrdiff_t x_stride_j);

template <>
class Operator<CausalSoftmax, Device::Type::kCambricon> : public CausalSoftmax {
public:
Operator(const Tensor input, Tensor out) : CausalSoftmax{input, out} {
cnrt_utils::GetLaunchConfig(input.device(), &core_per_cluster,
&cluster_count);
}

void operator()(const Tensor input, Tensor out) const override {
Comment thread
bitzyz marked this conversation as resolved.
auto queue = static_cast<cnrtQueue_t>(stream_ ? stream_ : 0);
auto workspace{workspace_ ? workspace_ : default_workspace_};
ptrdiff_t y_stride_b = ndim_ == 3 ? out_strides_[0] : 1;
ptrdiff_t y_stride_i = ndim_ == 3 ? out_strides_[1] : out_strides_[0];
ptrdiff_t y_stride_j = ndim_ == 3 ? out_strides_[2] : out_strides_[1];
ptrdiff_t x_stride_b = ndim_ == 3 ? input_strides_[0] : 1;
ptrdiff_t x_stride_i = ndim_ == 3 ? input_strides_[1] : input_strides_[0];
ptrdiff_t x_stride_j = ndim_ == 3 ? input_strides_[2] : input_strides_[1];

DispatchFunc<
List<DataType::kFloat16, DataType::kBFloat16, DataType::kFloat32>>(
{static_cast<int64_t>(input.dtype())},
[&](auto input_tag) {
using InputT = infini::ops::TypeMapType<Device::Type::kCambricon,
ListGet<0>(input_tag)>;
CausalSoftmaxUnion<InputT>(
workspace, core_per_cluster, cluster_count, queue, input.data(),
out.data(), batch_size_, seq_len_, total_seq_len_, y_stride_b,
y_stride_i, y_stride_j, x_stride_b, x_stride_i, x_stride_j);
},
"CambriconCausalSoftmax::operator() - output dispatch");
}

std::size_t workspace_size_in_bytes() const override { return 0; }

~Operator() {}

void* default_workspace_{nullptr};
int core_per_cluster = 0;
int cluster_count = 0;
};

} // namespace infini::ops

#endif
Loading
Loading