Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions infini_train/include/autograd/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

Expand All @@ -13,6 +14,34 @@ template <typename HookType> class HookHandleImpl;

namespace infini_train::autograd {

class FunctionCtx {
public:
FunctionCtx();

void SaveForBackward(const std::vector<std::shared_ptr<Tensor>> &tensors);

const std::vector<std::shared_ptr<Tensor>> &saved_tensors() const;

void MarkNonDifferentiable(const std::vector<std::shared_ptr<Tensor>> &outputs);

const std::vector<bool> &needs_input_grad() const;

private:
friend class Function;

void set_needs_input_grad(std::vector<bool> needs_input_grad);

void SaveVariables(const std::vector<std::shared_ptr<Tensor>> &outputs);
void ReleaseVariables();

bool IsNonDifferentiable(const std::shared_ptr<Tensor> &output) const;

std::vector<std::shared_ptr<Tensor>> to_save_;
std::vector<std::shared_ptr<Tensor>> saved_tensors_;
std::vector<bool> needs_input_grad_;
std::vector<Tensor *> non_differentiable_;
};

class Function : public std::enable_shared_from_this<Function> {
public:
template <typename HookType> using FunctionHookHandleImpl = infini_train::HookHandleImpl<HookType>;
Expand All @@ -23,14 +52,14 @@ class Function : public std::enable_shared_from_this<Function> {

static constexpr char kUndefinedType[] = "Undefined";

Function() : type_(kUndefinedType) {}
explicit Function(const std::string &type) : type_(type) {}
Function();
explicit Function(const std::string &type);

virtual ~Function() = default;
virtual ~Function();

virtual std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) = 0;
virtual void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
const std::vector<std::shared_ptr<Tensor>> &output_tensors) {}
const std::vector<std::shared_ptr<Tensor>> &output_tensors);
virtual std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) = 0;

std::vector<std::shared_ptr<Tensor>> Apply(const std::vector<std::shared_ptr<Tensor>> &input_tensors);
Expand All @@ -43,11 +72,10 @@ class Function : public std::enable_shared_from_this<Function> {
std::shared_ptr<infini_train::HookHandle> RegisterBackwardPreHook(FunctionPreHook hook);
std::shared_ptr<infini_train::HookHandle> RegisterBackwardPostHook(FunctionPostHook hook);

const std::string &type() const { return type_; }
const std::string &type() const;

protected:
std::vector<std::shared_ptr<Tensor>> saved_tensors_;
std::vector<bool> needs_input_grad_;
FunctionCtx ctx_;

private:
std::vector<std::pair<std::shared_ptr<Function>, int>> next_functions_;
Expand Down
6 changes: 3 additions & 3 deletions infini_train/src/autograd/activations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ std::vector<std::shared_ptr<Tensor>> Sigmoid::Forward(const std::vector<std::sha
void Sigmoid::SetupContext(const std::vector<std::shared_ptr<Tensor>> &,
const std::vector<std::shared_ptr<Tensor>> &output_tensors) {
const auto &output = output_tensors[0];
saved_tensors_ = {output};
ctx_.SaveForBackward({output});
}

std::vector<std::shared_ptr<Tensor>> Sigmoid::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
CHECK_EQ(saved_tensors_.size(), 1);
const auto &output = saved_tensors_[0];
CHECK_EQ(ctx_.saved_tensors().size(), 1);
const auto &output = ctx_.saved_tensors()[0];
CHECK_EQ(grad_outputs.size(), 1);
const auto &grad_output = grad_outputs[0];

Expand Down
58 changes: 29 additions & 29 deletions infini_train/src/autograd/elementwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ std::vector<std::shared_ptr<Tensor>> Reciprocal::Forward(const std::vector<std::
void Reciprocal::SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
const std::vector<std::shared_ptr<Tensor>> &) {
const auto &input = input_tensors[0];
saved_tensors_ = {input};
ctx_.SaveForBackward({input});
}

std::vector<std::shared_ptr<Tensor>> Reciprocal::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
CHECK_EQ(saved_tensors_.size(), 1);
const auto &input = saved_tensors_[0];
CHECK_EQ(ctx_.saved_tensors().size(), 1);
const auto &input = ctx_.saved_tensors()[0];
CHECK_EQ(grad_outputs.size(), 1);
const auto &grad_output = grad_outputs[0];

Expand All @@ -57,12 +57,12 @@ std::vector<std::shared_ptr<Tensor>> Sin::Forward(const std::vector<std::shared_
void Sin::SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
const std::vector<std::shared_ptr<Tensor>> &) {
const auto &input = input_tensors[0];
saved_tensors_ = {input};
ctx_.SaveForBackward({input});
}

std::vector<std::shared_ptr<Tensor>> Sin::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
CHECK_EQ(saved_tensors_.size(), 1);
const auto &input = saved_tensors_[0];
CHECK_EQ(ctx_.saved_tensors().size(), 1);
const auto &input = ctx_.saved_tensors()[0];
CHECK_EQ(grad_outputs.size(), 1);
const auto &grad_output = grad_outputs[0];

Expand All @@ -81,12 +81,12 @@ std::vector<std::shared_ptr<Tensor>> Cos::Forward(const std::vector<std::shared_
void Cos::SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
const std::vector<std::shared_ptr<Tensor>> &) {
const auto &input = input_tensors[0];
saved_tensors_ = {input};
ctx_.SaveForBackward({input});
}

std::vector<std::shared_ptr<Tensor>> Cos::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
CHECK_EQ(saved_tensors_.size(), 1);
const auto &input = saved_tensors_[0];
CHECK_EQ(ctx_.saved_tensors().size(), 1);
const auto &input = ctx_.saved_tensors()[0];
CHECK_EQ(grad_outputs.size(), 1);
const auto &grad_output = grad_outputs[0];

Expand All @@ -105,12 +105,12 @@ std::vector<std::shared_ptr<Tensor>> Tanh::Forward(const std::vector<std::shared
void Tanh::SetupContext(const std::vector<std::shared_ptr<Tensor>> &,
const std::vector<std::shared_ptr<Tensor>> &output_tensors) {
const auto &output = output_tensors[0];
saved_tensors_ = {output};
ctx_.SaveForBackward({output});
}

std::vector<std::shared_ptr<Tensor>> Tanh::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
CHECK_EQ(saved_tensors_.size(), 1);
const auto &output = saved_tensors_[0];
CHECK_EQ(ctx_.saved_tensors().size(), 1);
const auto &output = ctx_.saved_tensors()[0];
CHECK_EQ(grad_outputs.size(), 1);
const auto &grad_output = grad_outputs[0];

Expand All @@ -130,12 +130,12 @@ std::vector<std::shared_ptr<Tensor>> Pow::Forward(const std::vector<std::shared_
void Pow::SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
const std::vector<std::shared_ptr<Tensor>> &) {
const auto &input = input_tensors[0];
saved_tensors_ = {input};
ctx_.SaveForBackward({input});
}

std::vector<std::shared_ptr<Tensor>> Pow::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
CHECK_EQ(saved_tensors_.size(), 1);
const auto &input = saved_tensors_[0];
CHECK_EQ(ctx_.saved_tensors().size(), 1);
const auto &input = ctx_.saved_tensors()[0];
CHECK_EQ(grad_outputs.size(), 1);
const auto &grad_output = grad_outputs[0];

Expand All @@ -155,12 +155,12 @@ std::vector<std::shared_ptr<Tensor>> Rsqrt::Forward(const std::vector<std::share
void Rsqrt::SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
const std::vector<std::shared_ptr<Tensor>> &) {
const auto &input = input_tensors[0];
saved_tensors_ = {input};
ctx_.SaveForBackward({input});
}

std::vector<std::shared_ptr<Tensor>> Rsqrt::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
CHECK_EQ(saved_tensors_.size(), 1);
const auto &input = saved_tensors_[0];
CHECK_EQ(ctx_.saved_tensors().size(), 1);
const auto &input = ctx_.saved_tensors()[0];
CHECK_EQ(grad_outputs.size(), 1);
const auto &grad_output = grad_outputs[0];

Expand Down Expand Up @@ -195,12 +195,12 @@ std::vector<std::shared_ptr<Tensor>> Log::Forward(const std::vector<std::shared_
void Log::SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
const std::vector<std::shared_ptr<Tensor>> &) {
const auto &input = input_tensors[0];
saved_tensors_ = {input};
ctx_.SaveForBackward({input});
}

std::vector<std::shared_ptr<Tensor>> Log::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
CHECK_EQ(saved_tensors_.size(), 1);
const auto &input = saved_tensors_[0];
CHECK_EQ(ctx_.saved_tensors().size(), 1);
const auto &input = ctx_.saved_tensors()[0];
CHECK_EQ(grad_outputs.size(), 1);
const auto &grad_output = grad_outputs[0];

Expand Down Expand Up @@ -455,13 +455,13 @@ void Mul::SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors
const std::vector<std::shared_ptr<Tensor>> &) {
const auto &a = input_tensors[0];
const auto &b = input_tensors[1];
saved_tensors_ = {a, b};
ctx_.SaveForBackward({a, b});
}

std::vector<std::shared_ptr<Tensor>> Mul::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
CHECK_EQ(saved_tensors_.size(), 2);
const auto &a = saved_tensors_[0];
const auto &b = saved_tensors_[1];
CHECK_EQ(ctx_.saved_tensors().size(), 2);
const auto &a = ctx_.saved_tensors()[0];
const auto &b = ctx_.saved_tensors()[1];
CHECK_EQ(grad_outputs.size(), 1);
const auto &grad_output = grad_outputs[0];

Expand Down Expand Up @@ -500,13 +500,13 @@ void Div::SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors
const std::vector<std::shared_ptr<Tensor>> &) {
const auto &a = input_tensors[0];
const auto &b = input_tensors[1];
saved_tensors_ = {a, b};
ctx_.SaveForBackward({a, b});
}

std::vector<std::shared_ptr<Tensor>> Div::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
CHECK_EQ(saved_tensors_.size(), 2);
const auto &a = saved_tensors_[0];
const auto &b = saved_tensors_[1];
CHECK_EQ(ctx_.saved_tensors().size(), 2);
const auto &a = ctx_.saved_tensors()[0];
const auto &b = ctx_.saved_tensors()[1];
CHECK_EQ(grad_outputs.size(), 1);
const auto &grad_output = grad_outputs[0];

Expand Down
Loading
Loading