diff --git a/src/Nncase.Importer/Onnx/Normalization.cs b/src/Nncase.Importer/Onnx/Normalization.cs index 25aab015b..f4297bd84 100644 --- a/src/Nncase.Importer/Onnx/Normalization.cs +++ b/src/Nncase.Importer/Onnx/Normalization.cs @@ -38,6 +38,15 @@ private Expr VisitInstanceNormalization(in NodeProto op) return F.NN.InstanceNormalization(input, scale, bias, eps); } + private Expr VisitLayerNormalization(NodeProto op) + { + var input = GetInputExpr(op, 0); + var (scale, bias) = GetInputExprs(op, 1, 2); + var eps = GetFloatAttribute(op, "epsilon", 1e-05f); + var axis = GetIntAttribute(op, "axis", -1); + return F.NN.LayerNorm(checked((int)axis), eps, input, scale, bias); + } + private Expr VisitLpNormalization(in NodeProto op) { var input = GetInputExpr(op, 0); diff --git a/src/Nncase.Importer/Onnx/OnnxImporter.cs b/src/Nncase.Importer/Onnx/OnnxImporter.cs index b87178604..fee1027a3 100644 --- a/src/Nncase.Importer/Onnx/OnnxImporter.cs +++ b/src/Nncase.Importer/Onnx/OnnxImporter.cs @@ -142,6 +142,7 @@ private void Visit(NodeProto op) "HardSwish" => VisitHardSwish(op), "Identity" => VisitIdentity(op), "InstanceNormalization" => VisitInstanceNormalization(op), + "LayerNormalization" => VisitLayerNormalization(op), "LpNormalization" => VisitLpNormalization(op), "LeakyRelu" => VisitLeakyRelu(op), "Less" => VisitCompare(op, CompareOp.LowerThan),