6 changes: 0 additions & 6 deletions infini_train/include/autograd/linear.h
@@ -12,12 +12,6 @@ class Tensor;

namespace infini_train::autograd {

struct LinearGradFlags {
bool input = false;
bool weight = false;
bool bias = false;
};

class Linear : public Function {
public:
static constexpr char kType[] = "LinearFunction";
2 changes: 2 additions & 0 deletions infini_train/include/autograd/matmul.h
@@ -23,5 +23,7 @@ class Matmul : public Function {

private:
int64_t out_features_ = 0;
std::vector<int64_t> input1_dims_;
std::vector<int64_t> input2_dims_;
};
} // namespace infini_train::autograd
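
The two cached shape vectors are needed because Backward may now find nullptr in a saved-tensor slot (see the selective save in matmul.cc below) and can no longer ask that tensor for its dims. A minimal standalone sketch of the underlying shape algebra in plain Eigen, not the project's Tensor API; all names are local to the example:

#include <Eigen/Dense>
#include <iostream>

int main() {
    const int m = 2, k = 3, n = 4;
    Eigen::MatrixXf input1 = Eigen::MatrixXf::Random(m, k);
    Eigen::MatrixXf input2 = Eigen::MatrixXf::Random(k, n);
    Eigen::MatrixXf output = input1 * input2; // forward: [m, n]

    // The upstream gradient always has the output's shape.
    Eigen::MatrixXf grad_output = Eigen::MatrixXf::Ones(m, n);

    // grad_input1 = grad_output * input2^T  -> [m, k], exactly input1_dims_
    Eigen::MatrixXf grad_input1 = grad_output * input2.transpose();
    // grad_input2 = input1^T * grad_output  -> [k, n], exactly input2_dims_
    Eigen::MatrixXf grad_input2 = input1.transpose() * grad_output;

    std::cout << output.rows() << "x" << output.cols() << " "            // 2x4
              << grad_input1.rows() << "x" << grad_input1.cols() << " "  // 2x3
              << grad_input2.rows() << "x" << grad_input2.cols() << "\n"; // 3x4
    return 0;
}
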
30 changes: 21 additions & 9 deletions infini_train/src/autograd/linear.cc
@@ -56,17 +56,29 @@ std::vector<std::shared_ptr<Tensor>> Linear::Backward(const std::vector<std::sha
const auto &grad_output = grad_outputs[0];

CHECK(!needs_input_grad_.empty()) << "needs_input_grad_ not populated in Linear::Backward";
LinearGradFlags grad_flags = {.input = needs_input_grad_[0],
.weight = needs_input_grad_.size() > 1 && needs_input_grad_[1],
.bias = bias_ && needs_input_grad_.size() > 2 && needs_input_grad_[2]};
bool need_grad_input = needs_input_grad_[0];
bool need_grad_weight = needs_input_grad_.size() > 1 && needs_input_grad_[1];
bool need_grad_bias = bias_ && needs_input_grad_.size() > 2 && needs_input_grad_[2];

auto device = grad_output->GetDevice().type();
// TODO: skip autograd graph construction entirely when no input requires grad
auto [grad_input, grad_weight, grad_bias]
= Dispatcher::Instance()
.Call<std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>>(
{device, "LinearBackward"}, input, weight, transpose_, in_features_, out_features_, input_dims_,
grad_output, bias_, grad_flags);

std::shared_ptr<Tensor> grad_input = nullptr;
std::shared_ptr<Tensor> grad_weight = nullptr;
std::shared_ptr<Tensor> grad_bias = nullptr;

if (need_grad_input) {
grad_input = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>(
{device, "LinearBackwardInput"}, weight, grad_output, transpose_, in_features_, out_features_, input_dims_);
}
if (need_grad_weight) {
grad_weight = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>(
{device, "LinearBackwardWeight"}, input, grad_output, transpose_, in_features_, out_features_);
}
if (need_grad_bias) {
grad_bias = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "LinearBackwardBias"}, grad_output,
out_features_);
}

if (bias_) {
return {grad_input, grad_weight, grad_bias};
} else {
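
An aside on the gradient split above: each piece is small enough to check numerically. Below is a self-contained sketch in plain Eigen (not the project's Tensor API; every name is local to the example) comparing the analytic grad_input formula used by LinearBackwardInput, in the transpose layout y = x * W^T + b, against a finite difference:

#include <Eigen/Dense>
#include <iostream>

int main() {
    const int batch = 2, in_features = 3, out_features = 4;
    Eigen::MatrixXf x = Eigen::MatrixXf::Random(batch, in_features);
    Eigen::MatrixXf W = Eigen::MatrixXf::Random(out_features, in_features);
    Eigen::VectorXf b = Eigen::VectorXf::Random(out_features);

    // loss = sum(y) with y = x * W^T + b, bias broadcast over rows.
    auto loss = [&](const Eigen::MatrixXf &xx) {
        return ((xx * W.transpose()).rowwise() + b.transpose()).sum();
    };

    // Analytic gradient, same formula as LinearBackwardInput (transpose=true):
    // grad_x = grad_y * W, where grad_y is all ones for loss = sum(y).
    Eigen::MatrixXf grad_y = Eigen::MatrixXf::Ones(batch, out_features);
    Eigen::MatrixXf grad_x = grad_y * W;

    // Finite difference in one coordinate; y is linear in x, so the forward
    // difference is exact up to float rounding.
    const float eps = 1e-3f;
    Eigen::MatrixXf x_bumped = x;
    x_bumped(0, 0) += eps;
    const float fd = (loss(x_bumped) - loss(x)) / eps;

    std::cout << "analytic " << grad_x(0, 0) << " vs numeric " << fd << "\n";
    return 0;
}
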
41 changes: 33 additions & 8 deletions infini_train/src/autograd/matmul.cc
@@ -31,10 +31,19 @@ void Matmul::SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tens
// FIXME: compute_dtype is not necessarily the dtype of output_tensor; it should be
// determined by autocast, not derived from output->Dtype().
auto compute_dtype = output->Dtype();
saved_tensors_ = {
input1->Dtype() == compute_dtype ? input1 : std::make_shared<Tensor>(input1->To(compute_dtype)),
input2->Dtype() == compute_dtype ? input2 : std::make_shared<Tensor>(input2->To(compute_dtype)),

// grad_input1 = grad_output @ input2^T, so input2 is needed
// grad_input2 = input1^T @ grad_output, so input1 is needed
bool need_grad_input1 = needs_input_grad_.size() > 0 && needs_input_grad_[0];
bool need_grad_input2 = needs_input_grad_.size() > 1 && needs_input_grad_[1];

auto cast = [&](const std::shared_ptr<Tensor> &t) {
return t->Dtype() == compute_dtype ? t : std::make_shared<Tensor>(t->To(compute_dtype));
};

saved_tensors_ = {need_grad_input2 ? cast(input1) : nullptr, need_grad_input1 ? cast(input2) : nullptr};
input1_dims_ = input1->Dims();
input2_dims_ = input2->Dims();
out_features_ = output->Dims()[0];
}
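
One note on why the unneeded slot stores nullptr rather than the tensor itself: it is about lifetime. Once SetupContext declines to save an activation, the caller's reference is the only one left, and the memory can be reclaimed right after forward instead of being pinned until backward. A minimal sketch of that effect, with a bare std::shared_ptr standing in for a Tensor (all names hypothetical):

#include <cassert>
#include <memory>

int main() {
    auto input1 = std::make_shared<int>(42); // stand-in for a forward activation
    const bool need_grad_input2 = false;     // input1 only feeds grad_input2

    // Mirror of the selective save above: keep the tensor only if backward needs it.
    std::shared_ptr<int> saved = need_grad_input2 ? input1 : nullptr;

    input1.reset();           // the caller releases the activation after forward
    assert(saved == nullptr); // nothing in the autograd graph keeps it alive
    return 0;
}
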

@@ -45,10 +54,26 @@ std::vector<std::shared_ptr<Tensor>> Matmul::Backward(const std::vector<std::sha
CHECK_EQ(grad_outputs.size(), 1);
const auto &grad_output = grad_outputs[0];

auto device = input1->GetDevice().type();
auto [grad_input1, grad_input2]
= Dispatcher::Instance().Call<std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>>(
{device, "MatmulBackward"}, input1, input2, grad_output);
return {grad_input1, grad_input2};
CHECK(!needs_input_grad_.empty()) << "needs_input_grad_ not populated in Matmul::Backward";
bool need_grad_input1 = needs_input_grad_.size() > 0 && needs_input_grad_[0];
bool need_grad_input2 = needs_input_grad_.size() > 1 && needs_input_grad_[1];

auto device = grad_output->GetDevice().type();

std::shared_ptr<Tensor> grad_input = nullptr;
std::shared_ptr<Tensor> grad_other = nullptr;

if (need_grad_input1) {
CHECK(input2 != nullptr) << "input2 not saved but need_grad_input1 is true";
grad_input = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "MatmulBackwardInput"}, input2,
grad_output, input1_dims_);
}
if (need_grad_input2) {
CHECK(input1 != nullptr) << "input1 not saved but need_grad_input2 is true";
grad_other = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "MatmulBackwardOther"}, input1,
grad_output, input2_dims_);
}

return {grad_input, grad_other};
}
} // namespace infini_train::autograd
179 changes: 40 additions & 139 deletions infini_train/src/kernels/cpu/linear.cc
@@ -5,103 +5,10 @@

#include "glog/logging.h"

#include "infini_train/include/autograd/linear.h"
#include "infini_train/include/dispatcher.h"
#include "infini_train/include/tensor.h"

namespace infini_train::kernels::cpu {
std::shared_ptr<Tensor> MatmulForward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &other) {
/*
output[*, m, n] = input[*, m, k] * other[*, k, n]
*/
// TODO(dcj): support broadcast later
const auto &input_dims = input->Dims();
const auto &other_dims = other->Dims();

CHECK_GE(input_dims.size(), 2);
CHECK_GE(other_dims.size(), 2);
CHECK_EQ(input_dims.size(), other_dims.size());

const int64_t m = input_dims[input_dims.size() - 2];
const int64_t k = input_dims[input_dims.size() - 1];
CHECK_EQ(k, other_dims[other_dims.size() - 2]);
const int64_t n = other_dims[other_dims.size() - 1];

const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies<int64_t>{});
for (int64_t i = 0; i < input_dims.size() - 2; ++i) {
CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match";
}

std::vector<int64_t> output_dims = input_dims;
output_dims[output_dims.size() - 1] = n;
auto output = std::make_shared<Tensor>(output_dims, DataType::kFLOAT32);

for (int64_t b = 0; b < bs; ++b) {
for (int64_t i = 0; i < m; ++i) {
for (int64_t j = 0; j < n; ++j) {
float acc = 0.0f;
for (int64_t p = 0; p < k; ++p) {
acc += static_cast<const float *>(input->DataPtr())[b * m * k + i * k + p]
* static_cast<const float *>(other->DataPtr())[b * k * n + p * n + j];
}
static_cast<float *>(output->DataPtr())[b * m * n + i * n + j] = acc;
}
}
}
return {output};
}

std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>
MatmulBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &other,
const std::shared_ptr<Tensor> &grad_output) {
/*
grad_input[*, m, k] = grad_output[*, m, n] * other[*, k, n]^T
grad_other[*, k, n] = input[*, m, k]^T * grad_output[*, m, n]
*/
const auto &input_dims = input->Dims();
const auto &other_dims = other->Dims();
const auto &grad_output_dims = grad_output->Dims();

CHECK_GE(input_dims.size(), 2);
CHECK_EQ(input_dims.size(), other_dims.size());
CHECK_EQ(input_dims.size(), grad_output_dims.size());

const int64_t m = input_dims[input_dims.size() - 2];
const int64_t k = input_dims[input_dims.size() - 1];
CHECK_EQ(k, other_dims[other_dims.size() - 2]);
const int64_t n = other_dims[other_dims.size() - 1];

CHECK_EQ(m, grad_output_dims[grad_output_dims.size() - 2]);
CHECK_EQ(n, grad_output_dims[grad_output_dims.size() - 1]);

const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies<int64_t>{});
for (int64_t i = 0; i < input_dims.size() - 2; ++i) {
CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match";
CHECK_EQ(input_dims[i], grad_output_dims[i]) << "Batch dims must match";
}

auto grad_input = std::make_shared<Tensor>(input_dims, DataType::kFLOAT32);
auto grad_other = std::make_shared<Tensor>(other_dims, DataType::kFLOAT32);
grad_input->Fill(0.0);
grad_other->Fill(0.0);

for (int64_t b = 0; b < bs; ++b) {
for (int64_t i = 0; i < m; ++i) {
for (int64_t j = 0; j < n; ++j) {
const float grad = static_cast<float *>(grad_output->DataPtr())[b * m * n + i * n + j];
for (int64_t p = 0; p < k; ++p) {
const auto input_idx = b * m * k + i * k + p;
const auto other_idx = b * k * n + p * n + j;
static_cast<float *>(grad_input->DataPtr())[input_idx]
+= grad * static_cast<const float *>(other->DataPtr())[other_idx];
static_cast<float *>(grad_other->DataPtr())[other_idx]
+= grad * static_cast<const float *>(input->DataPtr())[input_idx];
}
}
}
}
return {grad_input, grad_other};
}

std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &weight,
bool transpose, const std::shared_ptr<Tensor> &bias) {
@@ -146,71 +53,65 @@ std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, cons
return output;
}

// TODO(dcj): support linear without bias later
std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>
LinearBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &weight, bool transpose,
int64_t in_features, int64_t out_features, const std::vector<int64_t> &input_dims,
const std::shared_ptr<Tensor> &grad_output, bool bias,
infini_train::autograd::LinearGradFlags grad_flags) {
std::shared_ptr<Tensor> LinearBackwardInput(const std::shared_ptr<Tensor> &weight,
const std::shared_ptr<Tensor> &grad_output, bool transpose,
int64_t in_features, int64_t out_features,
const std::vector<int64_t> &input_dims) {
/*
transpose: grad_input = grad_output * weight
grad_input[*, in_features] = grad_output[*, out_features] * weight[out_features, in_features]

!transpose: grad_input = grad_output * weight^T
grad_input[*, in_features] = grad_output[*, out_features] * weight[in_features, out_features]^T
*/
const auto compute_grad_input = grad_flags.input;
const auto compute_grad_weight = grad_flags.weight;
const auto compute_grad_bias = grad_flags.bias;

CHECK_GE(input_dims.size(), 2);

std::vector<int64_t> weight_dims
= transpose ? std::vector<int64_t>{out_features, in_features} : std::vector<int64_t>{in_features, out_features};

std::shared_ptr<Tensor> grad_input = nullptr;
std::shared_ptr<Tensor> grad_weight = nullptr;
std::shared_ptr<Tensor> grad_bias = nullptr;

if (compute_grad_input) {
CHECK(weight != nullptr) << "compute_grad_input=true but weight is nullptr (selective save mismatch)";
grad_input = std::make_shared<Tensor>(input_dims, DataType::kFLOAT32);
if (transpose) {
grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix();
} else {
grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix().transpose();
}
auto grad_input = std::make_shared<Tensor>(input_dims, DataType::kFLOAT32);
if (transpose) {
grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix();
} else {
grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix().transpose();
}
return grad_input;
}

if (compute_grad_weight) {
CHECK(input != nullptr) << "compute_grad_weight=true but input is nullptr (selective save mismatch)";
grad_weight = std::make_shared<Tensor>(weight_dims, DataType::kFLOAT32);
if (transpose) {
grad_weight->EigenMatrix() = grad_output->EigenMatrix().transpose() * input->EigenMatrix();
} else {
grad_weight->EigenMatrix() = input->EigenMatrix().transpose() * grad_output->EigenMatrix();
}
}
std::shared_ptr<Tensor> LinearBackwardWeight(const std::shared_ptr<Tensor> &input,
const std::shared_ptr<Tensor> &grad_output, bool transpose,
int64_t in_features, int64_t out_features) {
/*
transpose:
grad_weight[out_features, in_features] = grad_output[*, out_features]^T * input[*, in_features]

if (compute_grad_bias && bias) {
grad_bias = std::make_shared<Tensor>(std::vector<int64_t>{out_features}, DataType::kFLOAT32);
grad_bias->EigenVector() = grad_output->EigenMatrix().colwise().sum();
!transpose:
grad_weight[in_features, out_features] = input[*, in_features]^T * grad_output[*, out_features]
*/
std::vector<int64_t> weight_dims
= transpose ? std::vector<int64_t>{out_features, in_features} : std::vector<int64_t>{in_features, out_features};
auto grad_weight = std::make_shared<Tensor>(weight_dims, DataType::kFLOAT32);
if (transpose) {
grad_weight->EigenMatrix() = grad_output->EigenMatrix().transpose() * input->EigenMatrix();
} else {
grad_weight->EigenMatrix() = input->EigenMatrix().transpose() * grad_output->EigenMatrix();
}
return grad_weight;
}

return {grad_input, grad_weight, grad_bias};
std::shared_ptr<Tensor> LinearBackwardBias(const std::shared_ptr<Tensor> &grad_output, int64_t out_features) {
/*
grad_bias[out_features] = grad_output[*, out_features].sum(axis=0)
*/
auto grad_bias = std::make_shared<Tensor>(std::vector<int64_t>{out_features}, DataType::kFLOAT32);
grad_bias->EigenVector() = grad_output->EigenMatrix().colwise().sum();
return grad_bias;
}

} // namespace infini_train::kernels::cpu

#define REGISTER_CPU_LINEAR_KERNEL(kernel_name) \
REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name)

REGISTER_CPU_LINEAR_KERNEL(MatmulForward)
REGISTER_CPU_LINEAR_KERNEL(MatmulBackward)
REGISTER_CPU_LINEAR_KERNEL(LinearForward)
REGISTER_CPU_LINEAR_KERNEL(LinearBackward)
REGISTER_CPU_LINEAR_KERNEL(LinearBackwardInput)
REGISTER_CPU_LINEAR_KERNEL(LinearBackwardWeight)
REGISTER_CPU_LINEAR_KERNEL(LinearBackwardBias)

#undef REGISTER_CPU_LINEAR_KERNEL
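
For completeness, the LinearBackwardBias reduction registered above can be verified by hand: grad_bias is a column-wise sum of grad_output, because the bias is broadcast across the batch dimension in forward. A standalone sketch in plain Eigen with hand-picked values:

#include <Eigen/Dense>
#include <iostream>

int main() {
    Eigen::MatrixXf grad_output(2, 3); // [batch, out_features]
    grad_output << 1, 2, 3,
                   4, 5, 6;

    // Same reduction as grad_bias->EigenVector() = grad_output->EigenMatrix().colwise().sum();
    Eigen::RowVectorXf grad_bias = grad_output.colwise().sum();

    std::cout << grad_bias << "\n"; // prints: 5 7 9
    return 0;
}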