From 2d15c67b2677a85361e99a35600ba998d43eded8 Mon Sep 17 00:00:00 2001
From: Jiho Chu
Date: Mon, 19 Feb 2024 14:13:35 +0900
Subject: [PATCH] Modify bn layer for mixed precision

Signed-off-by: Jiho Chu
---
 nntrainer/layers/bn_layer.cpp | 253 +++++++++++-----------------------
 1 file changed, 77 insertions(+), 176 deletions(-)

diff --git a/nntrainer/layers/bn_layer.cpp b/nntrainer/layers/bn_layer.cpp
index e3c179d1f0..c262d59fb6 100644
--- a/nntrainer/layers/bn_layer.cpp
+++ b/nntrainer/layers/bn_layer.cpp
@@ -174,98 +174,53 @@ void BatchNormalizationLayer::forwarding(RunLayerContext &context,
 
   /** use hidden_ as temporary tensor before setting the result in hidden */
   Tensor t_full = hidden_;
   Tensor &cvar = context.getTensor(wt_idx[BNParams::cvar]);
-  if (input_.getDataType() == ml::train::TensorDim::DataType::FP16) {
-#ifdef ENABLE_FP16
-    TensorDim mu_dim = mu.getDim();
-    mu_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor mu32(mu_dim, true);
-    mu32.copyData(mu);
-
-    TensorDim var_dim = var.getDim();
-    var_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor var32(var_dim, true);
-    var32.copyData(var);
-
-    TensorDim gamma_dim = gamma.getDim();
-    gamma_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor gamma32(gamma_dim, true);
-    gamma32.copyData(gamma);
-
-    TensorDim beta_dim = beta.getDim();
-    beta_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor beta32(beta_dim, true);
-    beta32.copyData(beta);
-
-    TensorDim input_dim = input_.getDim();
-    input_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor input_32(input_dim, true);
-    input_32.copyData(input_);
-
-    TensorDim hidden_dim = hidden_.getDim();
-    hidden_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor hidden_32(hidden_dim, true);
-    hidden_32.copyData(hidden_);
-    Tensor t_full32 = hidden_32;
-
-    TensorDim deviation_dim = deviation.getDim();
-    deviation_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor deviation32(deviation_dim, true);
-    deviation32.copyData(deviation);
-
-    TensorDim dim_invstd = invstd.getDim();
-    dim_invstd.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor invstd32(dim_invstd, true);
-    invstd32.copyData(invstd);
-
-    TensorDim t_reduced_dim = t_reduced.getDim();
-    t_reduced_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor t_reduced32(t_reduced_dim, true);
-    t_reduced32.copyData(t_reduced);
-
-    TensorDim cvar_dim = cvar.getDim();
-    cvar_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor cvar32(cvar_dim, true);
-    cvar32.copyData(cvar);
+
+  const auto &in_type = input_.getDataType();
+  if (in_type != mu.getDataType()) {
+    // It calculates with activation data type
+    Tensor mu_ = mu.clone(in_type);
+    Tensor var_ = var.clone(in_type);
+    Tensor gamma_ = gamma.clone(in_type);
+    Tensor beta_ = beta.clone(in_type);
+    Tensor deviation_ = deviation.clone(in_type);
+    Tensor invstd_ = invstd.clone(in_type);
+    Tensor t_reduced_ = t_reduced.clone(in_type);
+    Tensor cvar_ = cvar.clone(in_type);
 
     if (training) {
-      input_32.average(axes_to_reduce, t_reduced32);
-      input_32.subtract(t_reduced32, deviation32);
+      input_.average(axes_to_reduce, t_reduced_);
+      input_.subtract(t_reduced_, deviation_);
 
-      mu32.multiply_i(momentum);
-      mu32.add_i(t_reduced32, 1 - momentum);
+      mu_.multiply_i(momentum);
+      mu_.add_i(t_reduced_, 1 - momentum);
 
-      deviation32.pow(2.0f, t_full32);
-      t_full32.average(axes_to_reduce, cvar32);
+      deviation_.pow(2.0f, t_full);
+      t_full.average(axes_to_reduce, cvar_);
 
-      var32.multiply_i(momentum);
-      var32.add_i(cvar32, 1 - momentum);
+      var_.multiply_i(momentum);
+      var_.add_i(cvar_, 1 - momentum);
 
-      cvar32.add_i(epsilon);
-      cvar32.pow(-0.5f, invstd32);
+      cvar_.add_i(epsilon);
+      cvar_.pow(-0.5f, invstd_);
     } else {
-      input_32.subtract(mu32, deviation32);
+      input_.subtract(mu_, deviation_);
 
       /** @todo do below 2 lines only for first iteration */
-      var32.add(epsilon, invstd32);
-      invstd32.pow_i(-0.5f);
+      var_.add(epsilon, invstd_);
+      invstd_.pow_i(-0.5f);
     }
 
-    deviation32.multiply(invstd32, hidden_32);
-    hidden_32.multiply_i(gamma32);
-    hidden_32.add_i(beta32);
-
-    mu.copyData(mu32);
-    var.copyData(var32);
-    gamma.copyData(gamma32);
-    beta.copyData(beta32);
-    input_.copyData(input_32);
-    hidden_.copyData(hidden_32);
-    deviation.copyData(deviation32);
-    invstd.copyData(invstd32);
-    t_reduced.copyData(t_reduced32);
-    cvar.copyData(cvar32);
-#else
-    throw std::runtime_error("enable-fp16 is not enabled");
-#endif
+    deviation_.multiply(invstd_, hidden_);
+    hidden_.multiply_i(gamma_);
+    hidden_.add_i(beta_);
+
+    mu.copyData(mu_);
+    var.copyData(var_);
+    gamma.copyData(gamma_);
+    beta.copyData(beta_);
+    deviation.copyData(deviation_);
+    invstd.copyData(invstd_);
+    t_reduced.copyData(t_reduced_);
+    cvar.copyData(cvar_);
   } else {
     if (training) {
       input_.average(axes_to_reduce, t_reduced);
@@ -306,96 +261,53 @@ void BatchNormalizationLayer::calcDerivative(RunLayerContext &context) {
 
   Tensor &t_reduced = context.getTensor(wt_idx[BNParams::t_reduced]);
   Tensor &t_full = context.getTensor(wt_idx[BNParams::t_full]);
-  if (deriv.getDataType() == ml::train::TensorDim::DataType::FP16) {
-#ifdef ENABLE_FP16
-    TensorDim gamma_dim = gamma.getDim();
-    gamma_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor gamma32(gamma_dim, true);
-    gamma32.copyData(gamma);
-
-    TensorDim deriv_dim = deriv.getDim();
-    deriv_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor deriv32(deriv_dim, true);
-    deriv32.copyData(deriv);
-
-    TensorDim dx_dim = dx.getDim();
-    dx_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor dx32(dx_dim, true);
-    dx32.copyData(dx);
-
-    TensorDim deviation_dim = deviation.getDim();
-    deviation_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor deviation32(deviation_dim, true);
-    deviation32.copyData(deviation);
-
-    TensorDim invstd_dim = invstd.getDim();
-    invstd_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor invstd32(invstd_dim, true);
-    invstd32.copyData(invstd);
-
-    TensorDim cvar_dim = cvar.getDim();
-    cvar_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor cvar32(cvar_dim, true);
-    cvar32.copyData(cvar);
-
-    TensorDim t_reduced_dim = t_reduced.getDim();
-    t_reduced_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor t_reduced32(t_reduced_dim, true);
-    t_reduced32.copyData(t_reduced);
-
-    TensorDim t_full_dim = t_full.getDim();
-    t_full_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor t_full32(t_full_dim, true);
-    t_full32.copyData(t_full);
-
-    deviation32.multiply(deriv32, t_full32);
-    t_full32.average(axes_to_reduce, t_reduced32);
-    t_reduced32.divide_i(cvar32);
-    deviation32.multiply_i(t_reduced32);
+
+  const auto &deriv_type = deriv.getDataType();
+
+  if (deriv_type != gamma.getDataType()) {
+    Tensor gamma_ = gamma.clone(deriv_type);
+    Tensor deviation_ = deviation.clone(deriv_type);
+    Tensor invstd_ = invstd.clone(deriv_type);
+    Tensor cvar_ = cvar.clone(deriv_type);
+    Tensor t_reduced_ = t_reduced.clone(deriv_type);
+    Tensor t_full_ = t_full.clone(deriv_type);
+
+    deviation_.multiply(deriv, t_full_);
+    t_full_.average(axes_to_reduce, t_reduced_);
+    t_reduced_.divide_i(cvar_);
+    deviation_.multiply_i(t_reduced_);
 
     if (context.getTrainable()) {
       /**
       * This calculates dgamma tensor.
       */
       Tensor &dgamma = context.getWeightGrad(wt_idx[BNParams::gamma]);
-      TensorDim dgamma_dim = dgamma.getDim();
-      dgamma_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-      Tensor dgamma32(dgamma_dim, true);
-      dgamma32.copyData(dgamma);
-
-      t_full32.multiply_i(invstd32);
-      t_full32.sum(axes_to_reduce, dgamma32);
-      dgamma.copyData(dgamma32);
+      Tensor dgamma_ = dgamma.clone(deriv_type);
+      t_full_.multiply_i(invstd_);
+      t_full_.sum(axes_to_reduce, dgamma_);
+      dgamma.copyData(dgamma_);
 
       /**
       * This implementation depends on the pre-calculated dbeta calculated.
       */
       Tensor &dbeta = context.getWeightGrad(wt_idx[BNParams::beta]);
-      TensorDim dbeta_dim = dbeta.getDim();
-      dbeta_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-      Tensor dbeta32(dbeta_dim, true);
-      dbeta32.copyData(dbeta);
-      dbeta32.divide(divider, t_reduced32);
+      Tensor dbeta_ = dbeta.clone(deriv_type);
+      dbeta_.divide(divider, t_reduced_);
     } else {
-      deriv32.average(axes_to_reduce, t_reduced32);
+      deriv.average(axes_to_reduce, t_reduced_);
     }
 
-    deriv32.subtract(t_reduced32, dx32);
-    dx32.subtract_i(deviation32);
-
-    invstd32.multiply_i(gamma32);
-    dx32.multiply_i(invstd32);
-
-    gamma.copyData(gamma32);
-    dx.copyData(dx32);
-    deviation.copyData(deviation32);
-    invstd.copyData(invstd32);
-    cvar.copyData(cvar32);
-    t_reduced.copyData(t_reduced32);
-    t_full.copyData(t_full32);
-#else
-    throw std::runtime_error("enable-fp16 is not enabled");
-#endif
+    deriv.subtract(t_reduced_, dx);
+    dx.subtract_i(deviation_);
+
+    invstd_.multiply_i(gamma_);
+
+    gamma.copyData(gamma_);
+    deviation.copyData(deviation_);
+    invstd.copyData(invstd_);
+    cvar.copyData(cvar_);
+    t_reduced.copyData(t_reduced_);
+    t_full.copyData(t_full_);
   } else {
     deviation.multiply(deriv, t_full);
     t_full.average(axes_to_reduce, t_reduced);
@@ -431,25 +343,14 @@ void BatchNormalizationLayer::calcGradient(RunLayerContext &context) {
 
   /** dgamma is calculated in calcDerivative. dbeta is calculated here */
   Tensor &dbeta = context.getWeightGrad(wt_idx[BNParams::beta]);
   const Tensor &deriv = context.getIncomingDerivative(SINGLE_INOUT_IDX);
-  if (deriv.getDataType() == ml::train::TensorDim::DataType::FP16) {
-#ifdef ENABLE_FP16
-    TensorDim dbeta_dim = dbeta.getDim();
-    dbeta_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor dbeta32(dbeta_dim, true);
-    dbeta32.copyData(dbeta);
-
-    TensorDim deriv_dim = deriv.getDim();
-    deriv_dim.setDataType(ml::train::TensorDim::DataType::FP32);
-    Tensor deriv32(deriv_dim, true);
-    deriv32.copyData(deriv);
-
-    deriv32.sum(axes_to_reduce, dbeta32);
-    dbeta.copyData(dbeta32);
-#else
-    throw std::runtime_error("enable-fp16 is not enabled");
-#endif
-  } else {
+
+  const auto &deriv_type = deriv.getDataType();
+  if (deriv_type == dbeta.getDataType()) {
     deriv.sum(axes_to_reduce, dbeta);
+  } else {
+    Tensor dbeta_ = dbeta.clone(deriv_type);
+    deriv.sum(axes_to_reduce, dbeta_);
+    dbeta.copyData(dbeta_);
   }
 }
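
Note on the pattern this patch applies: instead of hand-building FP32 shadow tensors (TensorDim + copyData) under ENABLE_FP16, the layer now clones each parameter tensor into the incoming tensor's data type, runs the batch-norm arithmetic in that type, and copies the results back into the original-precision tensors. The sketch below is a minimal, self-contained illustration of that clone/compute/copy-back flow for the running-mean update in forwarding(). It is not nntrainer code: it uses plain std::vector, with double standing in for the FP32 parameters and float for the FP16 activations, and the variable names and toy data are illustrative only.

  // Standalone sketch of the clone/compute/copy-back pattern from the patch.
  // "double" stands in for the FP32 parameter dtype, "float" for the FP16
  // activation dtype; the real code uses Tensor::clone() and copyData().
  #include <cstddef>
  #include <iostream>
  #include <vector>

  int main() {
    // FP32 "master" running mean kept by the layer (one value per channel).
    std::vector<double> mu{0.0, 0.0};
    const float momentum = 0.9f;

    // FP16-like activations for one mini-batch (2 channels, 3 samples each).
    std::vector<std::vector<float>> input{{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}};

    // 1. Clone the parameter into the activation dtype.
    std::vector<float> mu_(mu.begin(), mu.end());

    // 2. Do the arithmetic in the activation dtype, as forwarding() now does.
    for (std::size_t c = 0; c < input.size(); ++c) {
      float batch_mean = 0.f;
      for (float v : input[c])
        batch_mean += v;
      batch_mean /= input[c].size();
      // mu_ = mu_ * momentum + batch_mean * (1 - momentum)
      mu_[c] = mu_[c] * momentum + batch_mean * (1.f - momentum);
    }

    // 3. Copy the results back into the master-precision tensor.
    for (std::size_t c = 0; c < mu.size(); ++c)
      mu[c] = mu_[c];

    for (double m : mu)
      std::cout << m << '\n'; // prints 0.2 and 0.5 for the toy data above
    return 0;
  }

The same pattern appears in all three functions touched by the patch (forwarding, calcDerivative, calcGradient): clone weights or gradients into the activation/derivative dtype, compute, then copyData back, which removes the FP16-specific branch and the ENABLE_FP16 guard from each function.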