diff --git a/infini_train/include/autograd/function.h b/infini_train/include/autograd/function.h index e1e1390e..791afbc6 100644 --- a/infini_train/include/autograd/function.h +++ b/infini_train/include/autograd/function.h @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -13,6 +14,34 @@ template class HookHandleImpl; namespace infini_train::autograd { +class FunctionCtx { +public: + FunctionCtx(); + + void SaveForBackward(const std::vector> &tensors); + + const std::vector> &saved_tensors() const; + + void MarkNonDifferentiable(const std::vector> &outputs); + + const std::vector &needs_input_grad() const; + +private: + friend class Function; + + void set_needs_input_grad(std::vector needs_input_grad); + + void SaveVariables(const std::vector> &outputs); + void ReleaseVariables(); + + bool IsNonDifferentiable(const std::shared_ptr &output) const; + + std::vector> to_save_; + std::vector> saved_tensors_; + std::vector needs_input_grad_; + std::vector non_differentiable_; +}; + class Function : public std::enable_shared_from_this { public: template using FunctionHookHandleImpl = infini_train::HookHandleImpl; @@ -23,14 +52,14 @@ class Function : public std::enable_shared_from_this { static constexpr char kUndefinedType[] = "Undefined"; - Function() : type_(kUndefinedType) {} - explicit Function(const std::string &type) : type_(type) {} + Function(); + explicit Function(const std::string &type); - virtual ~Function() = default; + virtual ~Function(); virtual std::vector> Forward(const std::vector> &input_tensors) = 0; virtual void SetupContext(const std::vector> &input_tensors, - const std::vector> &output_tensors) {} + const std::vector> &output_tensors); virtual std::vector> Backward(const std::vector> &grad_outputs) = 0; std::vector> Apply(const std::vector> &input_tensors); @@ -43,11 +72,10 @@ class Function : public std::enable_shared_from_this { std::shared_ptr RegisterBackwardPreHook(FunctionPreHook hook); std::shared_ptr RegisterBackwardPostHook(FunctionPostHook hook); - const std::string &type() const { return type_; } + const std::string &type() const; protected: - std::vector> saved_tensors_; - std::vector needs_input_grad_; + FunctionCtx ctx_; private: std::vector, int>> next_functions_; diff --git a/infini_train/src/autograd/activations.cc b/infini_train/src/autograd/activations.cc index 3641865a..1115ad57 100644 --- a/infini_train/src/autograd/activations.cc +++ b/infini_train/src/autograd/activations.cc @@ -17,12 +17,12 @@ std::vector> Sigmoid::Forward(const std::vector> &, const std::vector> &output_tensors) { const auto &output = output_tensors[0]; - saved_tensors_ = {output}; + ctx_.SaveForBackward({output}); } std::vector> Sigmoid::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 1); - const auto &output = saved_tensors_[0]; + CHECK_EQ(ctx_.saved_tensors().size(), 1); + const auto &output = ctx_.saved_tensors()[0]; CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; diff --git a/infini_train/src/autograd/elementwise.cc b/infini_train/src/autograd/elementwise.cc index 655cd309..56d0aba9 100644 --- a/infini_train/src/autograd/elementwise.cc +++ b/infini_train/src/autograd/elementwise.cc @@ -33,12 +33,12 @@ std::vector> Reciprocal::Forward(const std::vector> &input_tensors, const std::vector> &) { const auto &input = input_tensors[0]; - saved_tensors_ = {input}; + ctx_.SaveForBackward({input}); } std::vector> Reciprocal::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 1); - const auto &input = saved_tensors_[0]; + CHECK_EQ(ctx_.saved_tensors().size(), 1); + const auto &input = ctx_.saved_tensors()[0]; CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; @@ -57,12 +57,12 @@ std::vector> Sin::Forward(const std::vector> &input_tensors, const std::vector> &) { const auto &input = input_tensors[0]; - saved_tensors_ = {input}; + ctx_.SaveForBackward({input}); } std::vector> Sin::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 1); - const auto &input = saved_tensors_[0]; + CHECK_EQ(ctx_.saved_tensors().size(), 1); + const auto &input = ctx_.saved_tensors()[0]; CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; @@ -81,12 +81,12 @@ std::vector> Cos::Forward(const std::vector> &input_tensors, const std::vector> &) { const auto &input = input_tensors[0]; - saved_tensors_ = {input}; + ctx_.SaveForBackward({input}); } std::vector> Cos::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 1); - const auto &input = saved_tensors_[0]; + CHECK_EQ(ctx_.saved_tensors().size(), 1); + const auto &input = ctx_.saved_tensors()[0]; CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; @@ -105,12 +105,12 @@ std::vector> Tanh::Forward(const std::vector> &, const std::vector> &output_tensors) { const auto &output = output_tensors[0]; - saved_tensors_ = {output}; + ctx_.SaveForBackward({output}); } std::vector> Tanh::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 1); - const auto &output = saved_tensors_[0]; + CHECK_EQ(ctx_.saved_tensors().size(), 1); + const auto &output = ctx_.saved_tensors()[0]; CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; @@ -130,12 +130,12 @@ std::vector> Pow::Forward(const std::vector> &input_tensors, const std::vector> &) { const auto &input = input_tensors[0]; - saved_tensors_ = {input}; + ctx_.SaveForBackward({input}); } std::vector> Pow::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 1); - const auto &input = saved_tensors_[0]; + CHECK_EQ(ctx_.saved_tensors().size(), 1); + const auto &input = ctx_.saved_tensors()[0]; CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; @@ -155,12 +155,12 @@ std::vector> Rsqrt::Forward(const std::vector> &input_tensors, const std::vector> &) { const auto &input = input_tensors[0]; - saved_tensors_ = {input}; + ctx_.SaveForBackward({input}); } std::vector> Rsqrt::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 1); - const auto &input = saved_tensors_[0]; + CHECK_EQ(ctx_.saved_tensors().size(), 1); + const auto &input = ctx_.saved_tensors()[0]; CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; @@ -195,12 +195,12 @@ std::vector> Log::Forward(const std::vector> &input_tensors, const std::vector> &) { const auto &input = input_tensors[0]; - saved_tensors_ = {input}; + ctx_.SaveForBackward({input}); } std::vector> Log::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 1); - const auto &input = saved_tensors_[0]; + CHECK_EQ(ctx_.saved_tensors().size(), 1); + const auto &input = ctx_.saved_tensors()[0]; CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; @@ -455,13 +455,13 @@ void Mul::SetupContext(const std::vector> &input_tensors const std::vector> &) { const auto &a = input_tensors[0]; const auto &b = input_tensors[1]; - saved_tensors_ = {a, b}; + ctx_.SaveForBackward({a, b}); } std::vector> Mul::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 2); - const auto &a = saved_tensors_[0]; - const auto &b = saved_tensors_[1]; + CHECK_EQ(ctx_.saved_tensors().size(), 2); + const auto &a = ctx_.saved_tensors()[0]; + const auto &b = ctx_.saved_tensors()[1]; CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; @@ -500,13 +500,13 @@ void Div::SetupContext(const std::vector> &input_tensors const std::vector> &) { const auto &a = input_tensors[0]; const auto &b = input_tensors[1]; - saved_tensors_ = {a, b}; + ctx_.SaveForBackward({a, b}); } std::vector> Div::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 2); - const auto &a = saved_tensors_[0]; - const auto &b = saved_tensors_[1]; + CHECK_EQ(ctx_.saved_tensors().size(), 2); + const auto &a = ctx_.saved_tensors()[0]; + const auto &b = ctx_.saved_tensors()[1]; CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; diff --git a/infini_train/src/autograd/function.cc b/infini_train/src/autograd/function.cc index a09d2004..f6ba6792 100644 --- a/infini_train/src/autograd/function.cc +++ b/infini_train/src/autograd/function.cc @@ -16,6 +16,83 @@ namespace infini_train::autograd { +namespace { +std::shared_ptr ShallowCopyWithoutAutogradMeta(const std::shared_ptr &tensor) { + if (!tensor) { + return nullptr; + } + return std::make_shared(*tensor, 0, tensor->Dims()); +} +} // namespace + +FunctionCtx::FunctionCtx() = default; + +const std::vector> &FunctionCtx::saved_tensors() const { return saved_tensors_; } + +const std::vector &FunctionCtx::needs_input_grad() const { return needs_input_grad_; } + +void FunctionCtx::SaveForBackward(const std::vector> &tensors) { to_save_ = tensors; } + +void FunctionCtx::MarkNonDifferentiable(const std::vector> &outputs) { + non_differentiable_.clear(); + non_differentiable_.reserve(outputs.size()); + for (const auto &output : outputs) { + if (output) { + non_differentiable_.push_back(output.get()); + } + } +} + +void FunctionCtx::set_needs_input_grad(std::vector needs_input_grad) { + needs_input_grad_ = std::move(needs_input_grad); +} + +void FunctionCtx::SaveVariables(const std::vector> &outputs) { + saved_tensors_.clear(); + saved_tensors_.reserve(to_save_.size()); + for (const auto &tensor : to_save_) { + bool is_output = false; + for (const auto &output : outputs) { + if (tensor && tensor.get() == output.get()) { + is_output = true; + break; + } + } + saved_tensors_.push_back(is_output ? ShallowCopyWithoutAutogradMeta(tensor) : tensor); + } + to_save_.clear(); +} + +void FunctionCtx::ReleaseVariables() { + to_save_.clear(); + saved_tensors_.clear(); + needs_input_grad_.clear(); + non_differentiable_.clear(); +} + +bool FunctionCtx::IsNonDifferentiable(const std::shared_ptr &output) const { + if (!output) { + return false; + } + for (const auto *non_differentiable : non_differentiable_) { + if (output.get() == non_differentiable) { + return true; + } + } + return false; +} + +Function::Function() : ctx_(), type_(kUndefinedType) {} + +Function::Function(const std::string &type) : ctx_(), type_(type) {} + +Function::~Function() = default; + +void Function::SetupContext(const std::vector> &, + const std::vector> &) {} + +const std::string &Function::type() const { return type_; } + std::vector> Function::Apply(const std::vector> &input_tensors) { CHECK_GE(input_tensors.size(), 1); auto device = input_tensors[0]->GetDevice(); @@ -37,14 +114,16 @@ std::vector> Function::Apply(const std::vector needs_input_grad(input_tensors.size()); for (size_t idx = 0; idx < input_tensors.size(); ++idx) { - needs_input_grad_[idx] = input_tensors[idx]->requires_grad(); + needs_input_grad[idx] = input_tensors[idx]->requires_grad(); } + ctx_.set_needs_input_grad(std::move(needs_input_grad)); } // Apply autocast once at the autograd boundary so Forward / SetupContext receive @@ -62,6 +141,8 @@ std::vector> Function::Apply(const std::vector> Function::Apply(const std::vectorrequires_grad(); } + + grad_outputs_reached_ = 0; + grad_outputs_.resize(output_tensors.size(), nullptr); + std::vector differentiable_outputs(output_tensors.size(), false); + bool has_differentiable_output = false; + for (int output_idx = 0; output_idx < output_tensors.size(); ++output_idx) { + differentiable_outputs[output_idx] + = output_requires_grad && !ctx_.IsNonDifferentiable(output_tensors[output_idx]); + has_differentiable_output |= differentiable_outputs[output_idx]; + if (!differentiable_outputs[output_idx]) { + ++grad_outputs_reached_; + } + } + + if (!has_differentiable_output) { + next_functions_.clear(); + return output_tensors; + } + for (int idx = 0; idx < input_tensors.size(); ++idx) { const auto &input_tensor = input_tensors[idx]; if (input_tensor->requires_grad() && input_tensor->is_leaf()) { @@ -86,18 +187,14 @@ std::vector> Function::Apply(const std::vectorgrad_fn()->IncreaseDependenciesNumber(); } } - output_requires_grad |= input_tensor->requires_grad(); } - grad_outputs_reached_ = 0; - grad_outputs_.resize(output_tensors.size(), nullptr); for (int output_idx = 0; output_idx < output_tensors.size(); ++output_idx) { auto &output_tensor = output_tensors[output_idx]; - // TODO(dcj): Mark if an output tensor need differentiable or not. - output_tensor->set_requires_grad(output_requires_grad); - output_tensor->set_grad_fn(output_requires_grad ? shared_from_this() : nullptr); - output_tensor->set_is_leaf(!output_requires_grad - || ((output_tensor->grad_fn() == nullptr) && output_requires_grad)); + const bool differentiable_output = differentiable_outputs[output_idx]; + output_tensor->set_requires_grad(differentiable_output); + output_tensor->set_grad_fn(differentiable_output ? shared_from_this() : nullptr); + output_tensor->set_is_leaf(!differentiable_output); output_tensor->set_output_idx(output_idx); } @@ -145,9 +242,8 @@ void Function::BackwardPartial(std::shared_ptr grad_output, int grad_out } } - saved_tensors_.clear(); + ctx_.ReleaseVariables(); grad_outputs_.clear(); - needs_input_grad_.clear(); grad_outputs_reached_ = 0; dependencies_reached_ = 0; diff --git a/infini_train/src/autograd/gather.cc b/infini_train/src/autograd/gather.cc index a30cb013..31dbd740 100644 --- a/infini_train/src/autograd/gather.cc +++ b/infini_train/src/autograd/gather.cc @@ -21,13 +21,13 @@ void Gather::SetupContext(const std::vector> &input_tens const auto &input = input_tensors[0]; const auto &index = input_tensors[1]; input_dims_ = input->Dims(); - saved_tensors_ = {index}; + ctx_.SaveForBackward({index}); } std::vector> Gather::Backward(const std::vector> &grad_outputs) { CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; - const auto &index = saved_tensors_[0]; + const auto &index = ctx_.saved_tensors()[0]; auto device = grad_outputs[0]->GetDevice(); auto kernel = Dispatcher::Instance().GetKernel({device.type(), "GatherBackward"}); diff --git a/infini_train/src/autograd/linear.cc b/infini_train/src/autograd/linear.cc index 76602b03..e7a8b28d 100644 --- a/infini_train/src/autograd/linear.cc +++ b/infini_train/src/autograd/linear.cc @@ -20,12 +20,26 @@ void Linear::SetupContext(const std::vector> &input_tens const std::vector> &output_tensors) { const auto &input = input_tensors[0]; const auto &weight = input_tensors[1]; + // Cast saved tensors to forward compute dtype (output dtype) so backward + // computes in the same precision as forward, matching PyTorch's behavior. - bool need_input = needs_input_grad_.size() > 0 && needs_input_grad_[0]; - bool need_weight = needs_input_grad_.size() > 1 && needs_input_grad_[1]; + // FIXME: An extra cast (input/weight -> compute_dtype) is performed here because + // autocast runs before autograd. The correct approach is to adjust the ordering or + // integration of autocast and autograd so that autograd receives already-cast tensors, + // avoiding the redundant cast. + + // FIXME: compute_dtype is not necessarily the dtype of output_tensor; it should be + // determined by autocast, not derived from output_tensors[0]->Dtype(). + auto compute_dtype = output_tensors[0]->Dtype(); + bool need_input = ctx_.needs_input_grad().size() > 0 && ctx_.needs_input_grad()[0]; + bool need_weight = ctx_.needs_input_grad().size() > 1 && ctx_.needs_input_grad()[1]; + + auto cast = [&](const std::shared_ptr &t) { + return t->Dtype() == compute_dtype ? t : std::make_shared(t->To(compute_dtype)); + }; // grad_input needs weight, grad_weight needs input - saved_tensors_ = {need_weight ? input : nullptr, need_input ? weight : nullptr}; + ctx_.SaveForBackward({need_weight ? cast(input) : nullptr, need_input ? cast(weight) : nullptr}); transpose_ = true; bias_ = input_tensors.size() == 3; @@ -35,16 +49,16 @@ void Linear::SetupContext(const std::vector> &input_tens } std::vector> Linear::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 2); - const auto &input = saved_tensors_[0]; - const auto &weight = saved_tensors_[1]; + CHECK_EQ(ctx_.saved_tensors().size(), 2); + const auto &input = ctx_.saved_tensors()[0]; + const auto &weight = ctx_.saved_tensors()[1]; CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; - CHECK(!needs_input_grad_.empty()) << "needs_input_grad_ not populated in Linear::Backward"; - bool need_grad_input = needs_input_grad_[0]; - bool need_grad_weight = needs_input_grad_.size() > 1 && needs_input_grad_[1]; - bool need_grad_bias = bias_ && needs_input_grad_.size() > 2 && needs_input_grad_[2]; + CHECK(!ctx_.needs_input_grad().empty()) << "needs_input_grad not populated in Linear::Backward"; + bool need_grad_input = ctx_.needs_input_grad()[0]; + bool need_grad_weight = ctx_.needs_input_grad().size() > 1 && ctx_.needs_input_grad()[1]; + bool need_grad_bias = bias_ && ctx_.needs_input_grad().size() > 2 && ctx_.needs_input_grad()[2]; auto device = grad_output->GetDevice().type(); diff --git a/infini_train/src/autograd/loss.cc b/infini_train/src/autograd/loss.cc index 657ea649..3c1fff48 100644 --- a/infini_train/src/autograd/loss.cc +++ b/infini_train/src/autograd/loss.cc @@ -19,13 +19,13 @@ void CrossEntropy::SetupContext(const std::vector> &inpu const std::vector> &) { const auto &input = input_tensors[0]; const auto &target = input_tensors[1]; - saved_tensors_ = {input, target}; + ctx_.SaveForBackward({input, target}); } std::vector> CrossEntropy::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 2); - const auto &input = saved_tensors_[0]; - const auto &target = saved_tensors_[1]; + CHECK_EQ(ctx_.saved_tensors().size(), 2); + const auto &input = ctx_.saved_tensors()[0]; + const auto &target = ctx_.saved_tensors()[1]; CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; diff --git a/infini_train/src/autograd/matmul.cc b/infini_train/src/autograd/matmul.cc index 1f24dc21..442ed536 100644 --- a/infini_train/src/autograd/matmul.cc +++ b/infini_train/src/autograd/matmul.cc @@ -20,28 +20,43 @@ void Matmul::SetupContext(const std::vector> &input_tens const auto &input1 = input_tensors[0]; const auto &input2 = input_tensors[1]; const auto &output = output_tensors[0]; + // Cast saved tensors to forward compute dtype (output dtype) so backward + // computes in the same precision as forward, matching PyTorch's behavior. + + // FIXME: An extra cast (input1/input2 -> compute_dtype) is performed here because + // autocast runs before autograd. The correct approach is to adjust the ordering or + // integration of autocast and autograd so that autograd receives already-cast tensors, + // avoiding the redundant cast. + + // FIXME: compute_dtype is not necessarily the dtype of output_tensor; it should be + // determined by autocast, not derived from output->Dtype(). + auto compute_dtype = output->Dtype(); // grad_input1 = grad_output @ input2^T, so input2 is needed // grad_input2 = grad_output^T @ input1, so input1 is needed - bool need_grad_input1 = needs_input_grad_.size() > 0 && needs_input_grad_[0]; - bool need_grad_input2 = needs_input_grad_.size() > 1 && needs_input_grad_[1]; + bool need_grad_input1 = ctx_.needs_input_grad().size() > 0 && ctx_.needs_input_grad()[0]; + bool need_grad_input2 = ctx_.needs_input_grad().size() > 1 && ctx_.needs_input_grad()[1]; + + auto cast = [&](const std::shared_ptr &t) { + return t->Dtype() == compute_dtype ? t : std::make_shared(t->To(compute_dtype)); + }; - saved_tensors_ = {need_grad_input2 ? input1 : nullptr, need_grad_input1 ? input2 : nullptr}; + ctx_.SaveForBackward({need_grad_input2 ? cast(input1) : nullptr, need_grad_input1 ? cast(input2) : nullptr}); input1_dims_ = input1->Dims(); input2_dims_ = input2->Dims(); out_features_ = output->Dims()[0]; } std::vector> Matmul::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 2); - const auto &input1 = saved_tensors_[0]; - const auto &input2 = saved_tensors_[1]; + CHECK_EQ(ctx_.saved_tensors().size(), 2); + const auto &input1 = ctx_.saved_tensors()[0]; + const auto &input2 = ctx_.saved_tensors()[1]; CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; - CHECK(!needs_input_grad_.empty()) << "needs_input_grad_ not populated in Matmul::Backward"; - bool need_grad_input1 = needs_input_grad_.size() > 0 && needs_input_grad_[0]; - bool need_grad_input2 = needs_input_grad_.size() > 1 && needs_input_grad_[1]; + CHECK(!ctx_.needs_input_grad().empty()) << "needs_input_grad not populated in Matmul::Backward"; + bool need_grad_input1 = ctx_.needs_input_grad().size() > 0 && ctx_.needs_input_grad()[0]; + bool need_grad_input2 = ctx_.needs_input_grad().size() > 1 && ctx_.needs_input_grad()[1]; auto device = grad_output->GetDevice().type(); diff --git a/infini_train/src/autograd/normalization.cc b/infini_train/src/autograd/normalization.cc index 79a14abb..ab789119 100644 --- a/infini_train/src/autograd/normalization.cc +++ b/infini_train/src/autograd/normalization.cc @@ -18,26 +18,29 @@ std::vector> LayerNorm::Forward(const std::vector, std::shared_ptr, std::shared_ptr>>( {device, "LayerNormForward"}, input, weight, bias, eps_); - saved_tensors_ = {mean, rstd}; - return {output}; + return {output, mean, rstd}; } void LayerNorm::SetupContext(const std::vector> &input_tensors, - const std::vector> &) { + const std::vector> &output_tensors) { + CHECK_EQ(output_tensors.size(), 3); const auto &input = input_tensors[0]; const auto &weight = input_tensors[1]; const auto &bias = input_tensors[2]; - saved_tensors_.insert(saved_tensors_.begin(), {input, weight, bias}); + const auto &mean = output_tensors[1]; + const auto &rstd = output_tensors[2]; + ctx_.MarkNonDifferentiable({mean, rstd}); + ctx_.SaveForBackward({input, weight, bias, mean, rstd}); } std::vector> LayerNorm::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 5); - const auto &input = saved_tensors_[0]; - const auto &weight = saved_tensors_[1]; - const auto &bias = saved_tensors_[2]; - const auto &mean = saved_tensors_[3]; - const auto &rstd = saved_tensors_[4]; - CHECK_EQ(grad_outputs.size(), 1); + CHECK_EQ(ctx_.saved_tensors().size(), 5); + const auto &input = ctx_.saved_tensors()[0]; + const auto &weight = ctx_.saved_tensors()[1]; + const auto &bias = ctx_.saved_tensors()[2]; + const auto &mean = ctx_.saved_tensors()[3]; + const auto &rstd = ctx_.saved_tensors()[4]; + CHECK_GE(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; auto device = input->GetDevice().type(); diff --git a/infini_train/src/autograd/outer.cc b/infini_train/src/autograd/outer.cc index 85a8c9ca..b53a8b5c 100644 --- a/infini_train/src/autograd/outer.cc +++ b/infini_train/src/autograd/outer.cc @@ -22,13 +22,13 @@ void Outer::SetupContext(const std::vector> &input_tenso const std::vector> &output_tensors) { const auto &input1 = input_tensors[0]; const auto &input2 = input_tensors[1]; - saved_tensors_ = {input1, input2}; + ctx_.SaveForBackward({input1, input2}); } std::vector> Outer::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 2); - const auto &input1 = saved_tensors_[0]; - const auto &input2 = saved_tensors_[1]; + CHECK_EQ(ctx_.saved_tensors().size(), 2); + const auto &input1 = ctx_.saved_tensors()[0]; + const auto &input2 = ctx_.saved_tensors()[1]; CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; diff --git a/infini_train/src/autograd/reduction.cc b/infini_train/src/autograd/reduction.cc index 5a6e086f..9acf65ce 100644 --- a/infini_train/src/autograd/reduction.cc +++ b/infini_train/src/autograd/reduction.cc @@ -67,15 +67,15 @@ void Max::SetupContext(const std::vector> &input_tensors const std::vector> &output_tensors) { const auto &input = input_tensors[0]; const auto &output = output_tensors[0]; - saved_tensors_ = {input, output}; + ctx_.SaveForBackward({input, output}); } std::vector> Max::Backward(const std::vector> &grad_outputs) { CHECK_EQ(grad_outputs.size(), 1); - CHECK_EQ(saved_tensors_.size(), 2); + CHECK_EQ(ctx_.saved_tensors().size(), 2); const auto &grad_output = grad_outputs[0]; - const auto &input = saved_tensors_[0]; - const auto &reduced = saved_tensors_[1]; + const auto &input = ctx_.saved_tensors()[0]; + const auto &reduced = ctx_.saved_tensors()[1]; auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "MaxBackward"}, grad_output, input, reduced, @@ -94,15 +94,15 @@ void Min::SetupContext(const std::vector> &input_tensors const std::vector> &output_tensors) { const auto &input = input_tensors[0]; const auto &output = output_tensors[0]; - saved_tensors_ = {input, output}; + ctx_.SaveForBackward({input, output}); } std::vector> Min::Backward(const std::vector> &grad_outputs) { CHECK_EQ(grad_outputs.size(), 1); - CHECK_EQ(saved_tensors_.size(), 2); + CHECK_EQ(ctx_.saved_tensors().size(), 2); const auto &grad_output = grad_outputs[0]; - const auto &input = saved_tensors_[0]; - const auto &reduced = saved_tensors_[1]; + const auto &input = ctx_.saved_tensors()[0]; + const auto &reduced = ctx_.saved_tensors()[1]; auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "MinBackward"}, grad_output, input, reduced, diff --git a/infini_train/src/autograd/scatter.cc b/infini_train/src/autograd/scatter.cc index 472fd543..a9e7b2b9 100644 --- a/infini_train/src/autograd/scatter.cc +++ b/infini_train/src/autograd/scatter.cc @@ -18,13 +18,13 @@ std::vector> Scatter::Forward(const std::vector> &input_tensors, const std::vector> &) { - saved_tensors_ = {input_tensors[1]}; + ctx_.SaveForBackward({input_tensors[1]}); } std::vector> Scatter::Backward(const std::vector> &grad_outputs) { CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; - const auto &indices = saved_tensors_[0]; + const auto &indices = ctx_.saved_tensors()[0]; auto device = grad_output->GetDevice().type(); auto grad_values = Dispatcher::Instance().Call>({device, "ScatterBackward"}, grad_output, indices); diff --git a/infini_train/src/autograd/softmax.cc b/infini_train/src/autograd/softmax.cc index 39569a8c..76b30abe 100644 --- a/infini_train/src/autograd/softmax.cc +++ b/infini_train/src/autograd/softmax.cc @@ -17,12 +17,12 @@ std::vector> Softmax::Forward(const std::vector> &, const std::vector> &output_tensors) { const auto &output = output_tensors[0]; - saved_tensors_ = {output}; + ctx_.SaveForBackward({output}); } std::vector> Softmax::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 1); - const auto &output = saved_tensors_[0]; + CHECK_EQ(ctx_.saved_tensors().size(), 1); + const auto &output = ctx_.saved_tensors()[0]; CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; diff --git a/infini_train/src/autograd/sparse.cc b/infini_train/src/autograd/sparse.cc index 93315b4f..98412f1f 100644 --- a/infini_train/src/autograd/sparse.cc +++ b/infini_train/src/autograd/sparse.cc @@ -20,12 +20,12 @@ void Embedding::SetupContext(const std::vector> &input_t const auto &input = input_tensors[0]; const auto &weight = input_tensors[1]; weight_dims_ = weight->Dims(); - saved_tensors_ = {input}; + ctx_.SaveForBackward({input}); } std::vector> Embedding::Backward(const std::vector> &grad_outputs) { CHECK_EQ(grad_outputs.size(), 1); - const auto &input = saved_tensors_[0]; + const auto &input = ctx_.saved_tensors()[0]; const auto &grad_output = grad_outputs[0]; auto device = input->GetDevice().type(); diff --git a/infini_train/src/autograd/transform.cc b/infini_train/src/autograd/transform.cc index e38d5616..66f64fd3 100644 --- a/infini_train/src/autograd/transform.cc +++ b/infini_train/src/autograd/transform.cc @@ -166,12 +166,12 @@ void Slice::SetupContext(const std::vector> &input_tenso const std::vector> &) { // FIXME(dcj): only input's dim need to be saved const auto &input = input_tensors[0]; - saved_tensors_ = {input}; + ctx_.SaveForBackward({input}); } std::vector> Slice::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 1); - const auto &input = saved_tensors_[0]; + CHECK_EQ(ctx_.saved_tensors().size(), 1); + const auto &input = ctx_.saved_tensors()[0]; const auto &grad_output = grad_outputs[0]; auto device = input->GetDevice().type(); diff --git a/infini_train/src/nn/modules/normalization.cc b/infini_train/src/nn/modules/normalization.cc index 4a11cc1b..388b04de 100644 --- a/infini_train/src/nn/modules/normalization.cc +++ b/infini_train/src/nn/modules/normalization.cc @@ -22,8 +22,9 @@ LayerNorm::LayerNorm(const std::vector &normalized_shape, float eps, De } std::vector> LayerNorm::Forward(const std::vector> &input_tensors) { - return std::make_shared(eps_)->Apply( + auto outputs = std::make_shared(eps_)->Apply( {input_tensors[0], parameters_[kParamWeightName], parameters_[kParamBiasName]}); + return {outputs[0]}; } void LayerNorm::ResetParameters() { diff --git a/infini_train/src/nn/parallel/tensor_parallel.cc b/infini_train/src/nn/parallel/tensor_parallel.cc index fc01007b..205cb773 100644 --- a/infini_train/src/nn/parallel/tensor_parallel.cc +++ b/infini_train/src/nn/parallel/tensor_parallel.cc @@ -535,7 +535,7 @@ VocabParallelCrossEntropy::Forward(const std::vector> &i } // 8. Save for backward - saved_tensors_ = {softmax_local, target_mask, masked_target, valid_mask_local}; + ctx_.SaveForBackward({softmax_local, target_mask, masked_target, valid_mask_local}); return {loss}; } @@ -545,10 +545,10 @@ VocabParallelCrossEntropy::Backward(const std::vector> & CHECK_EQ(grad_outputs.size(), 1); auto grad_output = grad_outputs[0]; - auto softmax_local = saved_tensors_[0]; - auto target_mask = std::make_shared(saved_tensors_[1]->To(softmax_local->Dtype())); - auto masked_target = saved_tensors_[2]; - auto valid_mask_local = saved_tensors_[3]; + auto softmax_local = ctx_.saved_tensors()[0]; + auto target_mask = std::make_shared(ctx_.saved_tensors()[1]->To(softmax_local->Dtype())); + auto masked_target = ctx_.saved_tensors()[2]; + auto valid_mask_local = ctx_.saved_tensors()[3]; auto device = grad_output->GetDevice().type(); auto grad_input = Dispatcher::Instance().Call>( diff --git a/tests/autograd/test_autograd.cc b/tests/autograd/test_autograd.cc index 1d6d129a..f59dd066 100644 --- a/tests/autograd/test_autograd.cc +++ b/tests/autograd/test_autograd.cc @@ -27,6 +27,131 @@ using namespace infini_train; class AutogradForwardTest : public infini_train::test::InfiniTrainTest {}; class AutogradBackwardTest : public infini_train::test::InfiniTrainTest {}; +class SaveOutputForBackwardFunction : public autograd::Function { +public: + static constexpr char kType[] = "SaveOutputForBackwardFunction"; + + SaveOutputForBackwardFunction() : autograd::Function(kType) {} + + std::vector> Forward(const std::vector> &input_tensors) override { + const auto &input = input_tensors[0]; + auto output = std::make_shared(input->Dims(), input->Dtype(), input->GetDevice()); + output->CopyFrom(input); + return {output}; + } + + void SetupContext(const std::vector> &, + const std::vector> &output_tensors) override { + ctx_.SaveForBackward({output_tensors[0]}); + } + + std::vector> Backward(const std::vector> &grad_outputs) override { + return {grad_outputs[0]}; + } + + const std::shared_ptr &saved_tensor() const { return ctx_.saved_tensors()[0]; } +}; + +class NeedsInputGradFunction : public autograd::Function { +public: + static constexpr char kType[] = "NeedsInputGradFunction"; + + NeedsInputGradFunction() : autograd::Function(kType) {} + + std::vector> Forward(const std::vector> &input_tensors) override { + const auto &input = input_tensors[0]; + auto output = std::make_shared(input->Dims(), input->Dtype(), input->GetDevice()); + output->CopyFrom(input); + return {output}; + } + + void SetupContext(const std::vector> &input_tensors, + const std::vector> &) override { + observed_needs_input_grad_ = ctx_.needs_input_grad(); + ctx_.SaveForBackward({input_tensors[0]}); + } + + std::vector> Backward(const std::vector> &grad_outputs) override { + return {grad_outputs[0], nullptr}; + } + + const std::vector &observed_needs_input_grad() const { return observed_needs_input_grad_; } + const std::shared_ptr &saved_tensor() const { return ctx_.saved_tensors()[0]; } + +private: + std::vector observed_needs_input_grad_; +}; + +class MarkNonDifferentiableFunction : public autograd::Function { +public: + static constexpr char kType[] = "MarkNonDifferentiableFunction"; + + MarkNonDifferentiableFunction() : autograd::Function(kType) {} + + std::vector> Forward(const std::vector> &input_tensors) override { + const auto &input = input_tensors[0]; + auto differentiable = std::make_shared(input->Dims(), input->Dtype(), input->GetDevice()); + auto non_differentiable = std::make_shared(input->Dims(), input->Dtype(), input->GetDevice()); + differentiable->CopyFrom(input); + non_differentiable->CopyFrom(input); + return {differentiable, non_differentiable}; + } + + void SetupContext(const std::vector> &, + const std::vector> &output_tensors) override { + ctx_.MarkNonDifferentiable({output_tensors[1]}); + } + + std::vector> Backward(const std::vector> &grad_outputs) override { + return {grad_outputs[0]}; + } +}; + +TEST_P(AutogradForwardTest, SavedOutputIsPackedWithoutAutogradMeta) { + auto input = std::make_shared(std::vector{2, 2}, DataType::kFLOAT32, GetDevice(), true); + input->Fill(1.0f); + + auto fn = std::make_shared(); + auto outputs = fn->Apply({input}); + + ASSERT_EQ(outputs.size(), 1); + ASSERT_NE(fn->saved_tensor(), nullptr); + EXPECT_NE(fn->saved_tensor().get(), outputs[0].get()); + EXPECT_EQ(fn->saved_tensor()->DataPtr(), outputs[0]->DataPtr()); + EXPECT_FALSE(fn->saved_tensor()->requires_grad()); + EXPECT_EQ(fn->saved_tensor()->grad_fn(), nullptr); +} + +TEST_P(AutogradForwardTest, FunctionCtxNeedsInputGradAndSaveForBackward) { + auto requires_grad_input = std::make_shared(std::vector{2, 2}, DataType::kFLOAT32, GetDevice(), true); + auto no_grad_input = std::make_shared(std::vector{2, 2}, DataType::kFLOAT32, GetDevice(), false); + requires_grad_input->Fill(1.0f); + no_grad_input->Fill(2.0f); + + auto fn = std::make_shared(); + auto outputs = fn->Apply({requires_grad_input, no_grad_input}); + + ASSERT_EQ(outputs.size(), 1); + ASSERT_EQ(fn->observed_needs_input_grad().size(), 2); + EXPECT_TRUE(fn->observed_needs_input_grad()[0]); + EXPECT_FALSE(fn->observed_needs_input_grad()[1]); + EXPECT_EQ(fn->saved_tensor().get(), requires_grad_input.get()); +} + +TEST_P(AutogradForwardTest, MarkNonDifferentiableOutput) { + auto input = std::make_shared(std::vector{2, 2}, DataType::kFLOAT32, GetDevice(), true); + input->Fill(1.0f); + + auto outputs = std::make_shared()->Apply({input}); + + ASSERT_EQ(outputs.size(), 2); + EXPECT_TRUE(outputs[0]->requires_grad()); + EXPECT_NE(outputs[0]->grad_fn(), nullptr); + EXPECT_FALSE(outputs[1]->requires_grad()); + EXPECT_EQ(outputs[1]->grad_fn(), nullptr); + EXPECT_TRUE(outputs[1]->is_leaf()); +} + TEST_P(AutogradForwardTest, AddForward) { auto a = std::make_shared(std::vector{2, 3}, DataType::kFLOAT32, GetDevice(), true); a->Fill(1.0f); @@ -188,7 +313,11 @@ TEST_P(AutogradForwardTest, LayerNormForward) { auto bias = std::make_shared(std::vector{4}, DataType::kFLOAT32, GetDevice(), true); bias->Fill(0.0f); auto result = std::make_shared(1e-5f)->Apply({a, weight, bias}); - EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result.size(), 3); + EXPECT_FALSE(result[1]->requires_grad()); + EXPECT_EQ(result[1]->grad_fn(), nullptr); + EXPECT_FALSE(result[2]->requires_grad()); + EXPECT_EQ(result[2]->grad_fn(), nullptr); } TEST_P(AutogradForwardTest, LinearForward) { diff --git a/tests/autograd/test_autograd_normalization_forward.cc b/tests/autograd/test_autograd_normalization_forward.cc index 076c97c3..ebb876d8 100644 --- a/tests/autograd/test_autograd_normalization_forward.cc +++ b/tests/autograd/test_autograd_normalization_forward.cc @@ -21,7 +21,11 @@ TEST_P(AutogradNormalizationForwardTest, LayerNormForward) { bias->Fill(0.0f); auto layernorm_fn = std::make_shared(1e-5f); auto result = layernorm_fn->Apply({a, weight, bias}); - EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result.size(), 3); + EXPECT_FALSE(result[1]->requires_grad()); + EXPECT_EQ(result[1]->grad_fn(), nullptr); + EXPECT_FALSE(result[2]->requires_grad()); + EXPECT_EQ(result[2]->grad_fn(), nullptr); } TEST_P(AutogradNormalizationForwardTest, LayerNormZeroBias) { @@ -33,7 +37,11 @@ TEST_P(AutogradNormalizationForwardTest, LayerNormZeroBias) { bias->Fill(0.0f); auto layernorm_fn = std::make_shared(1e-5f); auto result = layernorm_fn->Apply({a, weight, bias}); - EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result.size(), 3); + EXPECT_FALSE(result[1]->requires_grad()); + EXPECT_EQ(result[1]->grad_fn(), nullptr); + EXPECT_FALSE(result[2]->requires_grad()); + EXPECT_EQ(result[2]->grad_fn(), nullptr); } TEST_P(AutogradNormalizationForwardTest, LayerNormThreeDim) { @@ -45,7 +53,11 @@ TEST_P(AutogradNormalizationForwardTest, LayerNormThreeDim) { bias->Fill(0.0f); auto layernorm_fn = std::make_shared(1e-5f); auto result = layernorm_fn->Apply({a, weight, bias}); - EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result.size(), 3); + EXPECT_FALSE(result[1]->requires_grad()); + EXPECT_EQ(result[1]->grad_fn(), nullptr); + EXPECT_FALSE(result[2]->requires_grad()); + EXPECT_EQ(result[2]->grad_fn(), nullptr); EXPECT_EQ(result[0]->Dims(), (std::vector{2, 1, 4})); }