Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions infini_train/include/autograd/linear.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,6 @@ class Tensor;

namespace infini_train::autograd {

struct LinearGradFlags {
bool input = false;
bool weight = false;
bool bias = false;
};

class Linear : public Function {
public:
static constexpr char kType[] = "LinearFunction";
Expand Down
96 changes: 96 additions & 0 deletions infini_train/include/common/cuda/gemm.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#pragma once

#include <cublas_v2.h>
#include <cuda_runtime_api.h>

#include "infini_train/include/datatype.h"
#include "infini_train/include/device.h"

namespace infini_train::kernels::cuda {

/**
* Return the cuBLAS handle associated with the given device.
* Shared by linear.cu, matmul.cu, and any future GEMM-using kernels.
*/
cublasHandle_t GetCublasHandle(const Device &device);

/**
* Return the CUDA stream associated with the given device.
* Shared by kernels that need to launch device-side code directly.
*/
cudaStream_t GetCudaStream(const Device &device);

/**
* Parameter bundle for a single GEMM call:
* C = alpha * op(A) * op(B) + beta * C
*
* batch_count == 1 → non-batched path (cublasGemmEx)
* batch_count > 1 → strided-batched (cublasGemmStridedBatchedEx)
*
* When batch_count == 1, stride_a/b/c are unused and must be left at 0.
*/
struct GemmParams {
    cublasOperation_t trans_a = CUBLAS_OP_N; // op applied to A
    cublasOperation_t trans_b = CUBLAS_OP_N; // op applied to B

    int m = 0; // rows of op(A) and C
    int n = 0; // cols of op(B) and C
    int k = 0; // cols of op(A) == rows of op(B)

    const void *A = nullptr; // non-owning device pointer to A
    int lda = 0;             // leading dimension of A
    const void *B = nullptr; // non-owning device pointer to B
    int ldb = 0;             // leading dimension of B
    void *C = nullptr;       // non-owning device pointer to C (read-write when beta != 0)
    int ldc = 0;             // leading dimension of C

    float alpha = 1.0f; // scale on op(A) * op(B)
    float beta = 0.0f;  // scale on the existing contents of C

    // batch_count=1: non-batched (Linear path); stride_a/b/c must be 0
    // batch_count>1: strided-batched (Matmul path)
    int batch_count = 1;
    long long stride_a = 0;
    long long stride_b = 0;
    long long stride_c = 0;

    // Value-initialized ({}) so a default-constructed GemmParams has no
    // indeterminate members, matching every other field above; callers must
    // still set both explicitly before calling GemmCuda.
    // NOTE(review): assumes DataType is default-constructible (e.g. an enum) — confirm.
    DataType input_dtype{};  // dtype of A and B
    DataType output_dtype{}; // dtype of C (may differ, e.g. bf16 in → fp32 out)

    cublasHandle_t blas_handle = nullptr; // handle the GEMM is issued on
};

/**
* Execute the GEMM described by `p` via cuBLAS.
* Dispatches to cublasGemmEx (batch_count==1) or
* cublasGemmStridedBatchedEx (batch_count>1).
* Uses CUBLAS_COMPUTE_32F for all input dtypes to ensure precision.
* Aborts on cuBLAS error (via CUBLAS_CHECK / LOG(FATAL)).
*/
void GemmCuda(const GemmParams &p);

/**
* Parameter bundle for a single SGEMV call (fp32 only):
* y = alpha * op(A) * x + beta * y
*
* op(A) is m_phys-by-n_phys when trans==N, or n_phys-by-m_phys when trans==T,
* where m_phys and n_phys are the physical (pre-transpose) row/col counts of A.
*/
struct SgemvParams {
    cublasOperation_t trans = CUBLAS_OP_N; // op applied to A: CUBLAS_OP_N or CUBLAS_OP_T
    int m = 0;                             // physical rows of A (pre-transpose; see m_phys above)
    int n = 0;                             // physical cols of A (pre-transpose; see n_phys above)
    const float *A = nullptr;              // non-owning device pointer to matrix A
    int lda = 0;                           // leading dimension of A
    const float *x = nullptr;              // non-owning device pointer to input vector x
    int incx = 1;                          // stride between consecutive elements of x
    float *y = nullptr;                    // non-owning device pointer to in/out vector y
    int incy = 1;                          // stride between consecutive elements of y
    float alpha = 1.0f;                    // scale on op(A) * x
    float beta = 0.0f;                     // scale on the existing contents of y
    cublasHandle_t blas_handle = nullptr;  // handle the SGEMV is issued on
};

void SgemvCuda(const SgemvParams &p);

} // namespace infini_train::kernels::cuda
30 changes: 21 additions & 9 deletions infini_train/src/autograd/linear.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,29 @@ std::vector<std::shared_ptr<Tensor>> Linear::Backward(const std::vector<std::sha
const auto &grad_output = grad_outputs[0];

CHECK(!needs_input_grad_.empty()) << "needs_input_grad_ not populated in Linear::Backward";
LinearGradFlags grad_flags = {.input = needs_input_grad_[0],
.weight = needs_input_grad_.size() > 1 && needs_input_grad_[1],
.bias = bias_ && needs_input_grad_.size() > 2 && needs_input_grad_[2]};
bool need_grad_input = needs_input_grad_[0];
bool need_grad_weight = needs_input_grad_.size() > 1 && needs_input_grad_[1];
bool need_grad_bias = bias_ && needs_input_grad_.size() > 2 && needs_input_grad_[2];

auto device = grad_output->GetDevice().type();
// TODO: skip autograd graph construction entirely when no input requires grad
auto [grad_input, grad_weight, grad_bias]
= Dispatcher::Instance()
.Call<std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>>(
{device, "LinearBackward"}, input, weight, transpose_, in_features_, out_features_, input_dims_,
grad_output, bias_, grad_flags);

std::shared_ptr<Tensor> grad_input = nullptr;
std::shared_ptr<Tensor> grad_weight = nullptr;
std::shared_ptr<Tensor> grad_bias = nullptr;

if (need_grad_input) {
grad_input = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>(
{device, "LinearBackwardInput"}, weight, grad_output, transpose_, in_features_, out_features_, input_dims_);
}
if (need_grad_weight) {
grad_weight = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>(
{device, "LinearBackwardWeight"}, input, grad_output, transpose_, in_features_, out_features_);
}
if (need_grad_bias) {
grad_bias = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "LinearBackwardBias"}, grad_output,
out_features_);
}

if (bias_) {
return {grad_input, grad_weight, grad_bias};
} else {
Expand Down
35 changes: 28 additions & 7 deletions infini_train/src/autograd/matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,17 @@ void Matmul::SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tens
// FIXME: compute_dtype is not necessarily the dtype of output_tensor; it should be
// determined by autocast, not derived from output->Dtype().
auto compute_dtype = output->Dtype();
saved_tensors_ = {
input1->Dtype() == compute_dtype ? input1 : std::make_shared<Tensor>(input1->To(compute_dtype)),
input2->Dtype() == compute_dtype ? input2 : std::make_shared<Tensor>(input2->To(compute_dtype)),

// grad_input1 = grad_output @ input2^T, so input2 is needed
// grad_input2 = grad_output^T @ input1, so input1 is needed
bool need_grad_input1 = needs_input_grad_.size() > 0 && needs_input_grad_[0];
bool need_grad_input2 = needs_input_grad_.size() > 1 && needs_input_grad_[1];

auto cast = [&](const std::shared_ptr<Tensor> &t) {
return t->Dtype() == compute_dtype ? t : std::make_shared<Tensor>(t->To(compute_dtype));
};

saved_tensors_ = {need_grad_input2 ? cast(input1) : nullptr, need_grad_input1 ? cast(input2) : nullptr};
out_features_ = output->Dims()[0];
}

Expand All @@ -45,10 +52,24 @@ std::vector<std::shared_ptr<Tensor>> Matmul::Backward(const std::vector<std::sha
CHECK_EQ(grad_outputs.size(), 1);
const auto &grad_output = grad_outputs[0];

CHECK(!needs_input_grad_.empty()) << "needs_input_grad_ not populated in Matmul::Backward";
bool need_grad_input1 = needs_input_grad_.size() > 0 && needs_input_grad_[0];
bool need_grad_input2 = needs_input_grad_.size() > 1 && needs_input_grad_[1];

auto device = input1->GetDevice().type();
auto [grad_input1, grad_input2]
= Dispatcher::Instance().Call<std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>>(
{device, "MatmulBackward"}, input1, input2, grad_output);
return {grad_input1, grad_input2};

std::shared_ptr<Tensor> grad_input = nullptr;
std::shared_ptr<Tensor> grad_other = nullptr;

if (need_grad_input1) {
grad_input = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "MatmulBackwardInput"}, input2,
grad_output, input1->Dims());
}
if (need_grad_input2) {
grad_other = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "MatmulBackwardOther"}, input1,
grad_output, input2->Dims());
}

return {grad_input, grad_other};
}
} // namespace infini_train::autograd
Loading
Loading