diff --git a/source/op/pt/tabulate_multi_device.cc b/source/op/pt/tabulate_multi_device.cc index 530e9ddf4d..cede1d03d9 100644 --- a/source/op/pt/tabulate_multi_device.cc +++ b/source/op/pt/tabulate_multi_device.cc @@ -602,8 +602,11 @@ void TabulateFusionSeRGradGradForward(const torch::Tensor& table_tensor, } } -class TabulateFusionSeAOp - : public torch::autograd::Function { +class TabulateFusionSeAGradOp + : public torch::autograd::Function { + private: + std::string device; + public: static torch::autograd::variable_list forward( torch::autograd::AutogradContext* ctx, @@ -611,14 +614,16 @@ class TabulateFusionSeAOp const torch::Tensor& table_info_tensor, const torch::Tensor& em_x_tensor, const torch::Tensor& em_tensor, - int64_t last_layer_size) { + const torch::Tensor& dy_tensor, + const torch::Tensor& descriptor_tensor) { bool type_flag = (table_tensor.dtype() == torch::kDouble) ? true : false; if (type_flag) { return forward_t(ctx, table_tensor, table_info_tensor, - em_x_tensor, em_tensor, last_layer_size); + em_x_tensor, em_tensor, dy_tensor, + descriptor_tensor); } else { return forward_t(ctx, table_tensor, table_info_tensor, em_x_tensor, - em_tensor, last_layer_size); + em_tensor, dy_tensor, descriptor_tensor); } } @@ -629,26 +634,28 @@ class TabulateFusionSeAOp const torch::Tensor& table_info_tensor, const torch::Tensor& em_x_tensor, const torch::Tensor& em_tensor, - int64_t last_layer_size) { - // allocate output tensors - auto options = torch::TensorOptions() - .dtype(table_tensor.dtype()) - .device(table_tensor.device()); - torch::Tensor descriptor_tensor = - torch::empty({em_tensor.size(0), 4, last_layer_size}, options); + const torch::Tensor& dy_tensor, + const torch::Tensor& descriptor_tensor) { + // Allocate output tensors + torch::Tensor dy_dem_x_tensor = torch::zeros_like(em_x_tensor); + torch::Tensor dy_dem_tensor = torch::zeros_like(em_tensor); + torch::Tensor dy_dtwo_tensor = at::Tensor(); // compute - TabulateFusionSeAForward(table_tensor, table_info_tensor, - em_x_tensor, em_tensor, at::Tensor(), - last_layer_size, descriptor_tensor); + TabulateFusionSeAGradForward( + table_tensor, table_info_tensor, em_x_tensor, em_tensor, at::Tensor(), + dy_tensor, descriptor_tensor, dy_dem_x_tensor, dy_dem_tensor, + dy_dtwo_tensor); // save data ctx->save_for_backward({table_tensor, table_info_tensor, em_x_tensor, em_tensor, descriptor_tensor}); - return {descriptor_tensor}; + + return torch::autograd::variable_list{dy_dem_x_tensor, dy_dem_tensor}; } static torch::autograd::variable_list backward( torch::autograd::AutogradContext* ctx, torch::autograd::variable_list grad_output) { + // load data torch::autograd::variable_list saved_variables = ctx->get_saved_variables(); torch::Tensor table_tensor = saved_variables[0]; bool type_flag = (table_tensor.dtype() == torch::kDouble) ? true : false; @@ -669,31 +676,32 @@ class TabulateFusionSeAOp torch::Tensor table_info_tensor = saved_variables[1]; torch::Tensor em_x_tensor = saved_variables[2]; torch::Tensor em_tensor = saved_variables[3]; - torch::Tensor two_embed_tensor = at::Tensor(); torch::Tensor descriptor_tensor = saved_variables[4]; - // ensure the gradient output is contiguous - torch::Tensor dy_tensor = grad_output[0].contiguous(); + bool is_sorted = true; + + torch::Tensor dz_dy_dem_x_tensor = grad_output[0].defined() + ? grad_output[0].contiguous() + : torch::zeros_like(em_x_tensor); + torch::Tensor dz_dy_dem_tensor = grad_output[1].defined() + ? grad_output[1].contiguous() + : torch::zeros_like(em_tensor); // allocate output tensors - torch::Tensor dy_dem_x_tensor = torch::zeros_like(em_x_tensor); - torch::Tensor dy_dem_tensor = torch::zeros_like(em_tensor); - torch::Tensor dy_dtwo_tensor = at::Tensor(); + torch::Tensor dz_dy_tensor = torch::empty_like(descriptor_tensor); // compute - TabulateFusionSeAGradForward( - table_tensor, table_info_tensor, em_x_tensor, em_tensor, - two_embed_tensor, dy_tensor, descriptor_tensor, dy_dem_x_tensor, - dy_dem_tensor, dy_dtwo_tensor); + TabulateFusionSeAGradGradForward( + table_tensor, table_info_tensor, em_x_tensor, em_tensor, at::Tensor(), + dz_dy_dem_x_tensor, dz_dy_dem_tensor, at::Tensor(), descriptor_tensor, + is_sorted, dz_dy_tensor); - return {at::Tensor(), at::Tensor(), dy_dem_x_tensor, dy_dem_tensor, - at::Tensor()}; + return torch::autograd::variable_list{at::Tensor(), at::Tensor(), + at::Tensor(), at::Tensor(), + dz_dy_tensor, at::Tensor()}; } }; -class TabulateFusionSeAGradOp - : public torch::autograd::Function { - private: - std::string device; - +class TabulateFusionSeAGradGradOp + : public torch::autograd::Function { public: static torch::autograd::variable_list forward( torch::autograd::AutogradContext* ctx, @@ -701,16 +709,19 @@ class TabulateFusionSeAGradOp const torch::Tensor& table_info_tensor, const torch::Tensor& em_x_tensor, const torch::Tensor& em_tensor, - const torch::Tensor& dy_tensor, - const torch::Tensor& descriptor_tensor) { + const torch::Tensor& dz_dy_dem_x_tensor, + const torch::Tensor& dz_dy_dem_tensor, + const torch::Tensor& descriptor_tensor, + bool is_sorted) { bool type_flag = (table_tensor.dtype() == torch::kDouble) ? true : false; if (type_flag) { return forward_t(ctx, table_tensor, table_info_tensor, - em_x_tensor, em_tensor, dy_tensor, - descriptor_tensor); + em_x_tensor, em_tensor, dz_dy_dem_x_tensor, + dz_dy_dem_tensor, descriptor_tensor, is_sorted); } else { return forward_t(ctx, table_tensor, table_info_tensor, em_x_tensor, - em_tensor, dy_tensor, descriptor_tensor); + em_tensor, dz_dy_dem_x_tensor, dz_dy_dem_tensor, + descriptor_tensor, is_sorted); } } @@ -721,28 +732,69 @@ class TabulateFusionSeAGradOp const torch::Tensor& table_info_tensor, const torch::Tensor& em_x_tensor, const torch::Tensor& em_tensor, - const torch::Tensor& dy_tensor, - const torch::Tensor& descriptor_tensor) { - // Allocate output tensors - torch::Tensor dy_dem_x_tensor = torch::zeros_like(em_x_tensor); - torch::Tensor dy_dem_tensor = torch::zeros_like(em_tensor); - torch::Tensor dy_dtwo_tensor = at::Tensor(); + const torch::Tensor& dz_dy_dem_x_tensor, + const torch::Tensor& dz_dy_dem_tensor, + const torch::Tensor& descriptor_tensor, + bool is_sorted) { + // Allocate output tensor + torch::Tensor dz_dy_tensor = torch::empty_like(descriptor_tensor); // compute - TabulateFusionSeAGradForward( + TabulateFusionSeAGradGradForward( table_tensor, table_info_tensor, em_x_tensor, em_tensor, at::Tensor(), - dy_tensor, descriptor_tensor, dy_dem_x_tensor, dy_dem_tensor, - dy_dtwo_tensor); + dz_dy_dem_x_tensor, dz_dy_dem_tensor, at::Tensor(), descriptor_tensor, + is_sorted, dz_dy_tensor); + + return torch::autograd::variable_list{dz_dy_tensor}; + } +}; + +class TabulateFusionSeAOp + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const torch::Tensor& table_tensor, + const torch::Tensor& table_info_tensor, + const torch::Tensor& em_x_tensor, + const torch::Tensor& em_tensor, + int64_t last_layer_size) { + bool type_flag = (table_tensor.dtype() == torch::kDouble) ? true : false; + if (type_flag) { + return forward_t(ctx, table_tensor, table_info_tensor, + em_x_tensor, em_tensor, last_layer_size); + } else { + return forward_t(ctx, table_tensor, table_info_tensor, em_x_tensor, + em_tensor, last_layer_size); + } + } + + template + static torch::autograd::variable_list forward_t( + torch::autograd::AutogradContext* ctx, + const torch::Tensor& table_tensor, + const torch::Tensor& table_info_tensor, + const torch::Tensor& em_x_tensor, + const torch::Tensor& em_tensor, + int64_t last_layer_size) { + // allocate output tensors + auto options = torch::TensorOptions() + .dtype(table_tensor.dtype()) + .device(table_tensor.device()); + torch::Tensor descriptor_tensor = + torch::empty({em_tensor.size(0), 4, last_layer_size}, options); + // compute + TabulateFusionSeAForward(table_tensor, table_info_tensor, + em_x_tensor, em_tensor, at::Tensor(), + last_layer_size, descriptor_tensor); // save data ctx->save_for_backward({table_tensor, table_info_tensor, em_x_tensor, em_tensor, descriptor_tensor}); - - return torch::autograd::variable_list{dy_dem_x_tensor, dy_dem_tensor}; + return {descriptor_tensor}; } static torch::autograd::variable_list backward( torch::autograd::AutogradContext* ctx, torch::autograd::variable_list grad_output) { - // load data torch::autograd::variable_list saved_variables = ctx->get_saved_variables(); torch::Tensor table_tensor = saved_variables[0]; bool type_flag = (table_tensor.dtype() == torch::kDouble) ? true : false; @@ -765,26 +817,20 @@ class TabulateFusionSeAGradOp torch::Tensor em_tensor = saved_variables[3]; torch::Tensor descriptor_tensor = saved_variables[4]; - bool is_sorted = true; - - torch::Tensor dz_dy_dem_x_tensor = grad_output[0].contiguous(); - torch::Tensor dz_dy_dem_tensor = grad_output[1].contiguous(); - // allocate output tensors - torch::Tensor dz_dy_tensor = torch::empty_like(descriptor_tensor); - // compute - TabulateFusionSeAGradGradForward( - table_tensor, table_info_tensor, em_x_tensor, em_tensor, at::Tensor(), - dz_dy_dem_x_tensor, dz_dy_dem_tensor, at::Tensor(), descriptor_tensor, - is_sorted, dz_dy_tensor); + // ensure the gradient output is contiguous + torch::Tensor dy_tensor = grad_output[0].contiguous(); + torch::autograd::variable_list dy_dem_tensors = + TabulateFusionSeAGradOp::apply(table_tensor, table_info_tensor, + em_x_tensor, em_tensor, dy_tensor, + descriptor_tensor); - return torch::autograd::variable_list{at::Tensor(), at::Tensor(), - at::Tensor(), at::Tensor(), - dz_dy_tensor, at::Tensor()}; + return {at::Tensor(), at::Tensor(), dy_dem_tensors[0], dy_dem_tensors[1], + at::Tensor()}; } }; -class TabulateFusionSeAGradGradOp - : public torch::autograd::Function { +class TabulateFusionSeAttenGradOp + : public torch::autograd::Function { public: static torch::autograd::variable_list forward( torch::autograd::AutogradContext* ctx, @@ -792,18 +838,18 @@ class TabulateFusionSeAGradGradOp const torch::Tensor& table_info_tensor, const torch::Tensor& em_x_tensor, const torch::Tensor& em_tensor, - const torch::Tensor& dz_dy_dem_x_tensor, - const torch::Tensor& dz_dy_dem_tensor, + const torch::Tensor& two_embed_tensor, + const torch::Tensor& dy_tensor, const torch::Tensor& descriptor_tensor, bool is_sorted) { bool type_flag = (table_tensor.dtype() == torch::kDouble) ? true : false; if (type_flag) { return forward_t(ctx, table_tensor, table_info_tensor, - em_x_tensor, em_tensor, dz_dy_dem_x_tensor, - dz_dy_dem_tensor, descriptor_tensor, is_sorted); + em_x_tensor, em_tensor, two_embed_tensor, + dy_tensor, descriptor_tensor, is_sorted); } else { return forward_t(ctx, table_tensor, table_info_tensor, em_x_tensor, - em_tensor, dz_dy_dem_x_tensor, dz_dy_dem_tensor, + em_tensor, two_embed_tensor, dy_tensor, descriptor_tensor, is_sorted); } } @@ -815,19 +861,70 @@ class TabulateFusionSeAGradGradOp const torch::Tensor& table_info_tensor, const torch::Tensor& em_x_tensor, const torch::Tensor& em_tensor, - const torch::Tensor& dz_dy_dem_x_tensor, - const torch::Tensor& dz_dy_dem_tensor, + const torch::Tensor& two_embed_tensor, + const torch::Tensor& dy_tensor, const torch::Tensor& descriptor_tensor, bool is_sorted) { - // Allocate output tensor + torch::Tensor dy_dem_x_tensor = torch::zeros_like(em_x_tensor); + torch::Tensor dy_dem_tensor = torch::zeros_like(em_tensor); + torch::Tensor dy_dtwo_tensor = torch::zeros_like(two_embed_tensor); + TabulateFusionSeAGradForward( + table_tensor, table_info_tensor, em_x_tensor, em_tensor, + two_embed_tensor, dy_tensor, descriptor_tensor, dy_dem_x_tensor, + dy_dem_tensor, dy_dtwo_tensor); + + ctx->save_for_backward({table_tensor, table_info_tensor, em_x_tensor, + em_tensor, two_embed_tensor, descriptor_tensor}); + ctx->saved_data["is_sorted"] = is_sorted; + + return torch::autograd::variable_list{dy_dem_x_tensor, dy_dem_tensor, + dy_dtwo_tensor}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output) { + torch::autograd::variable_list saved_variables = ctx->get_saved_variables(); + torch::Tensor table_tensor = saved_variables[0]; + bool type_flag = (table_tensor.dtype() == torch::kDouble) ? true : false; + if (type_flag) { + return backward_t(ctx, grad_output); + } else { + return backward_t(ctx, grad_output); + } + } + + template + static torch::autograd::variable_list backward_t( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output) { + torch::autograd::variable_list saved_variables = ctx->get_saved_variables(); + torch::Tensor table_tensor = saved_variables[0]; + torch::Tensor table_info_tensor = saved_variables[1]; + torch::Tensor em_x_tensor = saved_variables[2]; + torch::Tensor em_tensor = saved_variables[3]; + torch::Tensor two_embed_tensor = saved_variables[4]; + torch::Tensor descriptor_tensor = saved_variables[5]; + bool is_sorted = ctx->saved_data["is_sorted"].toBool(); + + torch::Tensor dz_dy_dem_x_tensor = grad_output[0].defined() + ? grad_output[0].contiguous() + : torch::zeros_like(em_x_tensor); + torch::Tensor dz_dy_dem_tensor = grad_output[1].defined() + ? grad_output[1].contiguous() + : torch::zeros_like(em_tensor); + torch::Tensor dz_dy_dtwo_tensor = grad_output[2].defined() + ? grad_output[2].contiguous() + : torch::zeros_like(two_embed_tensor); torch::Tensor dz_dy_tensor = torch::empty_like(descriptor_tensor); - // compute TabulateFusionSeAGradGradForward( - table_tensor, table_info_tensor, em_x_tensor, em_tensor, at::Tensor(), - dz_dy_dem_x_tensor, dz_dy_dem_tensor, at::Tensor(), descriptor_tensor, - is_sorted, dz_dy_tensor); + table_tensor, table_info_tensor, em_x_tensor, em_tensor, + two_embed_tensor, dz_dy_dem_x_tensor, dz_dy_dem_tensor, + dz_dy_dtwo_tensor, descriptor_tensor, is_sorted, dz_dy_tensor); - return torch::autograd::variable_list{dz_dy_tensor}; + return torch::autograd::variable_list{ + at::Tensor(), at::Tensor(), at::Tensor(), at::Tensor(), + at::Tensor(), dz_dy_tensor, at::Tensor(), at::Tensor()}; } }; @@ -878,6 +975,7 @@ class TabulateFusionSeAttenOp // save data ctx->save_for_backward({table_tensor, table_info_tensor, em_x_tensor, em_tensor, two_embed_tensor, descriptor_tensor}); + ctx->saved_data["is_sorted"] = is_sorted; return {descriptor_tensor}; } @@ -906,20 +1004,101 @@ class TabulateFusionSeAttenOp torch::Tensor em_tensor = saved_variables[3]; torch::Tensor two_embed_tensor = saved_variables[4]; torch::Tensor descriptor_tensor = saved_variables[5]; + bool is_sorted = ctx->saved_data["is_sorted"].toBool(); torch::Tensor dy_tensor = grad_output[0].contiguous(); - // allocate output tensors + torch::autograd::variable_list dy_dem_tensors = + TabulateFusionSeAttenGradOp::apply( + table_tensor, table_info_tensor, em_x_tensor, em_tensor, + two_embed_tensor, dy_tensor, descriptor_tensor, is_sorted); + + return {at::Tensor(), at::Tensor(), dy_dem_tensors[0], + dy_dem_tensors[1], dy_dem_tensors[2], at::Tensor(), + at::Tensor()}; + } +}; + +class TabulateFusionSeTGradOp + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const torch::Tensor& table_tensor, + const torch::Tensor& table_info_tensor, + const torch::Tensor& em_x_tensor, + const torch::Tensor& em_tensor, + const torch::Tensor& dy_tensor, + const torch::Tensor& descriptor_tensor) { + bool type_flag = (table_tensor.dtype() == torch::kDouble) ? true : false; + if (type_flag) { + return forward_t(ctx, table_tensor, table_info_tensor, + em_x_tensor, em_tensor, dy_tensor, + descriptor_tensor); + } else { + return forward_t(ctx, table_tensor, table_info_tensor, em_x_tensor, + em_tensor, dy_tensor, descriptor_tensor); + } + } + + template + static torch::autograd::variable_list forward_t( + torch::autograd::AutogradContext* ctx, + const torch::Tensor& table_tensor, + const torch::Tensor& table_info_tensor, + const torch::Tensor& em_x_tensor, + const torch::Tensor& em_tensor, + const torch::Tensor& dy_tensor, + const torch::Tensor& descriptor_tensor) { torch::Tensor dy_dem_x_tensor = torch::zeros_like(em_x_tensor); torch::Tensor dy_dem_tensor = torch::zeros_like(em_tensor); - torch::Tensor dy_dtwo_tensor = torch::zeros_like(two_embed_tensor); - // compute - TabulateFusionSeAGradForward( + TabulateFusionSeTGradForward( + table_tensor, table_info_tensor, em_x_tensor, em_tensor, dy_tensor, + descriptor_tensor, dy_dem_x_tensor, dy_dem_tensor); + + ctx->save_for_backward({table_tensor, table_info_tensor, em_x_tensor, + em_tensor, descriptor_tensor}); + + return torch::autograd::variable_list{dy_dem_x_tensor, dy_dem_tensor}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output) { + torch::autograd::variable_list saved_variables = ctx->get_saved_variables(); + torch::Tensor table_tensor = saved_variables[0]; + bool type_flag = (table_tensor.dtype() == torch::kDouble) ? true : false; + if (type_flag) { + return backward_t(ctx, grad_output); + } else { + return backward_t(ctx, grad_output); + } + } + + template + static torch::autograd::variable_list backward_t( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output) { + torch::autograd::variable_list saved_variables = ctx->get_saved_variables(); + torch::Tensor table_tensor = saved_variables[0]; + torch::Tensor table_info_tensor = saved_variables[1]; + torch::Tensor em_x_tensor = saved_variables[2]; + torch::Tensor em_tensor = saved_variables[3]; + torch::Tensor descriptor_tensor = saved_variables[4]; + + torch::Tensor dz_dy_dem_x_tensor = grad_output[0].defined() + ? grad_output[0].contiguous() + : torch::zeros_like(em_x_tensor); + torch::Tensor dz_dy_dem_tensor = grad_output[1].defined() + ? grad_output[1].contiguous() + : torch::zeros_like(em_tensor); + torch::Tensor dz_dy_tensor = torch::empty_like(descriptor_tensor); + TabulateFusionSeTGradGradForward( table_tensor, table_info_tensor, em_x_tensor, em_tensor, - two_embed_tensor, dy_tensor, descriptor_tensor, dy_dem_x_tensor, - dy_dem_tensor, dy_dtwo_tensor); + dz_dy_dem_x_tensor, dz_dy_dem_tensor, descriptor_tensor, dz_dy_tensor); - return {at::Tensor(), at::Tensor(), dy_dem_x_tensor, dy_dem_tensor, - dy_dtwo_tensor, at::Tensor(), at::Tensor()}; + return torch::autograd::variable_list{at::Tensor(), at::Tensor(), + at::Tensor(), at::Tensor(), + dz_dy_tensor, at::Tensor()}; } }; @@ -993,19 +1172,91 @@ class TabulateFusionSeTOp torch::Tensor descriptor_tensor = saved_variables[4]; torch::Tensor dy_tensor = grad_output[0].contiguous(); - // allocate output tensors - torch::Tensor dy_dem_x_tensor = torch::zeros_like(em_x_tensor); - torch::Tensor dy_dem_tensor = torch::zeros_like(em_tensor); - // compute - TabulateFusionSeTGradForward( - table_tensor, table_info_tensor, em_x_tensor, em_tensor, dy_tensor, - descriptor_tensor, dy_dem_x_tensor, dy_dem_tensor); + torch::autograd::variable_list dy_dem_tensors = + TabulateFusionSeTGradOp::apply(table_tensor, table_info_tensor, + em_x_tensor, em_tensor, dy_tensor, + descriptor_tensor); - return {at::Tensor(), at::Tensor(), dy_dem_x_tensor, dy_dem_tensor, + return {at::Tensor(), at::Tensor(), dy_dem_tensors[0], dy_dem_tensors[1], at::Tensor()}; } }; +class TabulateFusionSeRGradOp + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const torch::Tensor& table_tensor, + const torch::Tensor& table_info_tensor, + const torch::Tensor& em_tensor, + const torch::Tensor& dy_tensor, + const torch::Tensor& descriptor_tensor) { + bool type_flag = (table_tensor.dtype() == torch::kDouble) ? true : false; + if (type_flag) { + return forward_t(ctx, table_tensor, table_info_tensor, em_tensor, + dy_tensor, descriptor_tensor); + } else { + return forward_t(ctx, table_tensor, table_info_tensor, em_tensor, + dy_tensor, descriptor_tensor); + } + } + + template + static torch::autograd::variable_list forward_t( + torch::autograd::AutogradContext* ctx, + const torch::Tensor& table_tensor, + const torch::Tensor& table_info_tensor, + const torch::Tensor& em_tensor, + const torch::Tensor& dy_tensor, + const torch::Tensor& descriptor_tensor) { + torch::Tensor dy_dem_tensor = torch::zeros_like(em_tensor); + TabulateFusionSeRGradForward(table_tensor, table_info_tensor, + em_tensor, dy_tensor, + descriptor_tensor, dy_dem_tensor); + + ctx->save_for_backward( + {table_tensor, table_info_tensor, em_tensor, descriptor_tensor}); + + return torch::autograd::variable_list{dy_dem_tensor}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output) { + torch::autograd::variable_list saved_variables = ctx->get_saved_variables(); + torch::Tensor table_tensor = saved_variables[0]; + bool type_flag = (table_tensor.dtype() == torch::kDouble) ? true : false; + if (type_flag) { + return backward_t(ctx, grad_output); + } else { + return backward_t(ctx, grad_output); + } + } + + template + static torch::autograd::variable_list backward_t( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output) { + torch::autograd::variable_list saved_variables = ctx->get_saved_variables(); + torch::Tensor table_tensor = saved_variables[0]; + torch::Tensor table_info_tensor = saved_variables[1]; + torch::Tensor em_tensor = saved_variables[2]; + torch::Tensor descriptor_tensor = saved_variables[3]; + + torch::Tensor dz_dy_dem_tensor = grad_output[0].defined() + ? grad_output[0].contiguous() + : torch::zeros_like(em_tensor); + torch::Tensor dz_dy_tensor = torch::empty_like(descriptor_tensor); + TabulateFusionSeRGradGradForward(table_tensor, table_info_tensor, + em_tensor, dz_dy_dem_tensor, + descriptor_tensor, dz_dy_tensor); + + return torch::autograd::variable_list{ + at::Tensor(), at::Tensor(), at::Tensor(), dz_dy_tensor, at::Tensor()}; + } +}; + class TabulateFusionSeROp : public torch::autograd::Function { public: @@ -1072,14 +1323,93 @@ class TabulateFusionSeROp torch::Tensor descriptor_tensor = saved_variables[3]; torch::Tensor dy_tensor = grad_output[0].contiguous(); - // allocate output tensors - torch::Tensor dy_dem_tensor = torch::zeros_like(em_tensor); - // compute - TabulateFusionSeRGradForward(table_tensor, table_info_tensor, - em_tensor, dy_tensor, - descriptor_tensor, dy_dem_tensor); + torch::autograd::variable_list dy_dem_tensors = + TabulateFusionSeRGradOp::apply(table_tensor, table_info_tensor, + em_tensor, dy_tensor, descriptor_tensor); + + return {at::Tensor(), at::Tensor(), dy_dem_tensors[0], at::Tensor()}; + } +}; + +class TabulateFusionSeTTebdGradOp + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const torch::Tensor& table_tensor, + const torch::Tensor& table_info_tensor, + const torch::Tensor& em_x_tensor, + const torch::Tensor& em_tensor, + const torch::Tensor& dy_tensor, + const torch::Tensor& descriptor_tensor) { + bool type_flag = (table_tensor.dtype() == torch::kDouble) ? true : false; + if (type_flag) { + return forward_t(ctx, table_tensor, table_info_tensor, + em_x_tensor, em_tensor, dy_tensor, + descriptor_tensor); + } else { + return forward_t(ctx, table_tensor, table_info_tensor, em_x_tensor, + em_tensor, dy_tensor, descriptor_tensor); + } + } + + template + static torch::autograd::variable_list forward_t( + torch::autograd::AutogradContext* ctx, + const torch::Tensor& table_tensor, + const torch::Tensor& table_info_tensor, + const torch::Tensor& em_x_tensor, + const torch::Tensor& em_tensor, + const torch::Tensor& dy_tensor, + const torch::Tensor& descriptor_tensor) { + torch::Tensor dy_dem_x_tensor = torch::zeros_like(em_x_tensor); + TabulateFusionSeTTebdGradForward( + table_tensor, table_info_tensor, em_x_tensor, em_tensor, dy_tensor, + descriptor_tensor, dy_dem_x_tensor); + + ctx->save_for_backward({table_tensor, table_info_tensor, em_x_tensor, + em_tensor, descriptor_tensor}); + + return torch::autograd::variable_list{dy_dem_x_tensor}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output) { + torch::autograd::variable_list saved_variables = ctx->get_saved_variables(); + torch::Tensor table_tensor = saved_variables[0]; + bool type_flag = (table_tensor.dtype() == torch::kDouble) ? true : false; + if (type_flag) { + return backward_t(ctx, grad_output); + } else { + return backward_t(ctx, grad_output); + } + } + + template + static torch::autograd::variable_list backward_t( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output) { + torch::autograd::variable_list saved_variables = ctx->get_saved_variables(); + torch::Tensor table_tensor = saved_variables[0]; + torch::Tensor table_info_tensor = saved_variables[1]; + torch::Tensor em_x_tensor = saved_variables[2]; + torch::Tensor em_tensor = saved_variables[3]; + torch::Tensor descriptor_tensor = saved_variables[4]; - return {at::Tensor(), at::Tensor(), dy_dem_tensor, at::Tensor()}; + torch::Tensor dz_dy_dem_x_tensor = + grad_output[0].defined() + ? grad_output[0].contiguous().view( + {em_tensor.size(0), em_tensor.size(1), em_tensor.size(2)}) + : torch::zeros_like(em_tensor); + torch::Tensor dz_dy_tensor = torch::empty_like(descriptor_tensor); + TabulateFusionSeTTebdGradGradForward( + table_tensor, table_info_tensor, em_x_tensor, em_tensor, + dz_dy_dem_x_tensor, descriptor_tensor, dz_dy_tensor); + + return torch::autograd::variable_list{at::Tensor(), at::Tensor(), + at::Tensor(), at::Tensor(), + dz_dy_tensor, at::Tensor()}; } }; @@ -1155,14 +1485,12 @@ class TabulateFusionSeTTebdOp torch::Tensor descriptor_tensor = saved_variables[4]; torch::Tensor dy_tensor = grad_output[0].contiguous(); - // allocate output tensors - torch::Tensor dy_dem_x_tensor = torch::zeros_like(em_x_tensor); - // compute - TabulateFusionSeTTebdGradForward( - table_tensor, table_info_tensor, em_x_tensor, em_tensor, dy_tensor, - descriptor_tensor, dy_dem_x_tensor); + torch::autograd::variable_list dy_dem_tensors = + TabulateFusionSeTTebdGradOp::apply(table_tensor, table_info_tensor, + em_x_tensor, em_tensor, dy_tensor, + descriptor_tensor); - return {at::Tensor(), at::Tensor(), dy_dem_x_tensor, at::Tensor(), + return {at::Tensor(), at::Tensor(), dy_dem_tensors[0], at::Tensor(), at::Tensor()}; } }; diff --git a/source/tests/pt/test_tabulate_fusion_se_a.py b/source/tests/pt/test_tabulate_fusion_se_a.py index 8861f0564e..322a8ca261 100644 --- a/source/tests/pt/test_tabulate_fusion_se_a.py +++ b/source/tests/pt/test_tabulate_fusion_se_a.py @@ -1488,6 +1488,34 @@ def test_backward(self) -> None: rtol=self.prec, ) + def test_second_order_backward(self) -> None: + forward_result = torch.ops.deepmd.tabulate_fusion_se_a( + self.table_tensor, + self.table_info_tensor, + self.em_x_tensor, + self.em_tensor, + self.last_layer_size, + ) + + descriptor_tensor = forward_result[0] + dy_tensor = torch.ones_like(descriptor_tensor, requires_grad=True) + dy_dem_x, dy_dem = torch.autograd.grad( + descriptor_tensor, + (self.em_x_tensor, self.em_tensor), + grad_outputs=dy_tensor, + create_graph=True, + ) + + dz_dy_tensor = torch.autograd.grad( + dy_dem_x.sum() + dy_dem.sum(), + dy_tensor, + )[0] + + self.assertIsNotNone(dz_dy_tensor) + self.assertEqual(dz_dy_tensor.shape, descriptor_tensor.shape) + self.assertTrue(torch.isfinite(dz_dy_tensor).all()) + self.assertGreater(dz_dy_tensor.abs().max().item(), 0.0) + if __name__ == "__main__": unittest.main() diff --git a/source/tests/pt/test_tabulate_fusion_se_atten.py b/source/tests/pt/test_tabulate_fusion_se_atten.py index 09b06dcc89..4667522bd8 100644 --- a/source/tests/pt/test_tabulate_fusion_se_atten.py +++ b/source/tests/pt/test_tabulate_fusion_se_atten.py @@ -1624,6 +1624,36 @@ def test_backward(self) -> None: self.em_tensor.grad, self.expected_dy_dem, atol=self.prec, rtol=self.prec ) + def test_second_order_backward(self) -> None: + forward_result = torch.ops.deepmd.tabulate_fusion_se_atten( + self.table_tensor, + self.table_info_tensor, + self.em_x_tensor, + self.em_tensor, + self.two_embed_tensor, + self.last_layer_size, + self.is_sorted, + ) + + descriptor_tensor = forward_result[0] + dy_tensor = torch.ones_like(descriptor_tensor, requires_grad=True) + dy_dem_x, dy_dem, dy_dtwo = torch.autograd.grad( + descriptor_tensor, + (self.em_x_tensor, self.em_tensor, self.two_embed_tensor), + grad_outputs=dy_tensor, + create_graph=True, + ) + + dz_dy_tensor = torch.autograd.grad( + dy_dem_x.sum() + dy_dem.sum() + dy_dtwo.sum(), + dy_tensor, + )[0] + + self.assertIsNotNone(dz_dy_tensor) + self.assertEqual(dz_dy_tensor.shape, descriptor_tensor.shape) + self.assertTrue(torch.isfinite(dz_dy_tensor).all()) + self.assertGreater(dz_dy_tensor.abs().max().item(), 0.0) + if __name__ == "__main__": unittest.main() diff --git a/source/tests/pt/test_tabulate_fusion_se_r.py b/source/tests/pt/test_tabulate_fusion_se_r.py index 34d8dbf182..7e3c913529 100644 --- a/source/tests/pt/test_tabulate_fusion_se_r.py +++ b/source/tests/pt/test_tabulate_fusion_se_r.py @@ -1332,6 +1332,33 @@ def test_backward(self) -> None: self.em_tensor.grad, self.expected_dy_dem, atol=self.prec, rtol=self.prec ) + def test_second_order_backward(self) -> None: + forward_result = torch.ops.deepmd.tabulate_fusion_se_r( + self.table_tensor, + self.table_info_tensor, + self.em_tensor, + self.last_layer_size, + ) + + descriptor_tensor = forward_result[0] + dy_tensor = torch.ones_like(descriptor_tensor, requires_grad=True) + dy_dem = torch.autograd.grad( + descriptor_tensor, + self.em_tensor, + grad_outputs=dy_tensor, + create_graph=True, + )[0] + + dz_dy_tensor = torch.autograd.grad( + dy_dem.sum(), + dy_tensor, + )[0] + + self.assertIsNotNone(dz_dy_tensor) + self.assertEqual(dz_dy_tensor.shape, descriptor_tensor.shape) + self.assertTrue(torch.isfinite(dz_dy_tensor).all()) + self.assertGreater(dz_dy_tensor.abs().max().item(), 0.0) + if __name__ == "__main__": unittest.main() diff --git a/source/tests/pt/test_tabulate_fusion_se_t.py b/source/tests/pt/test_tabulate_fusion_se_t.py index 50654d557d..abb089c3db 100644 --- a/source/tests/pt/test_tabulate_fusion_se_t.py +++ b/source/tests/pt/test_tabulate_fusion_se_t.py @@ -1748,6 +1748,34 @@ def test_backward(self) -> None: rtol=self.prec, ) + def test_second_order_backward(self) -> None: + forward_result = torch.ops.deepmd.tabulate_fusion_se_t( + self.table_tensor, + self.table_info_tensor, + self.em_x_tensor, + self.em_tensor, + self.last_layer_size, + ) + + descriptor_tensor = forward_result[0] + dy_tensor = torch.ones_like(descriptor_tensor, requires_grad=True) + dy_dem_x, dy_dem = torch.autograd.grad( + descriptor_tensor, + (self.em_x_tensor, self.em_tensor), + grad_outputs=dy_tensor, + create_graph=True, + ) + + dz_dy_tensor = torch.autograd.grad( + dy_dem_x.sum() + dy_dem.sum(), + dy_tensor, + )[0] + + self.assertIsNotNone(dz_dy_tensor) + self.assertEqual(dz_dy_tensor.shape, descriptor_tensor.shape) + self.assertTrue(torch.isfinite(dz_dy_tensor).all()) + self.assertGreater(dz_dy_tensor.abs().max().item(), 0.0) + if __name__ == "__main__": unittest.main() diff --git a/source/tests/pt/test_tabulate_fusion_se_t_tebd.py b/source/tests/pt/test_tabulate_fusion_se_t_tebd.py index 10bf48c46f..c658c9732b 100644 --- a/source/tests/pt/test_tabulate_fusion_se_t_tebd.py +++ b/source/tests/pt/test_tabulate_fusion_se_t_tebd.py @@ -1913,6 +1913,34 @@ def test_backward(self) -> None: rtol=self.prec, ) + def test_second_order_backward(self) -> None: + forward_result = torch.ops.deepmd.tabulate_fusion_se_t_tebd( + self.table_tensor, + self.table_info_tensor, + self.em_x_tensor, + self.em_tensor, + self.last_layer_size, + ) + + descriptor_tensor = forward_result[0] + dy_tensor = torch.ones_like(descriptor_tensor, requires_grad=True) + dy_dem_x = torch.autograd.grad( + descriptor_tensor, + self.em_x_tensor, + grad_outputs=dy_tensor, + create_graph=True, + )[0] + + dz_dy_tensor = torch.autograd.grad( + dy_dem_x.sum(), + dy_tensor, + )[0] + + self.assertIsNotNone(dz_dy_tensor) + self.assertEqual(dz_dy_tensor.shape, descriptor_tensor.shape) + self.assertTrue(torch.isfinite(dz_dy_tensor).all()) + self.assertGreater(dz_dy_tensor.abs().max().item(), 0.0) + if __name__ == "__main__": unittest.main()