Skip to content

Files

Latest commit

9f9595b · Jun 2, 2020

History

History
279 lines (250 loc) · 12.5 KB

changing_intermediate_types.md

File metadata and controls

279 lines (250 loc) · 12.5 KB

index

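Let's explore what the internal codegen (`rise.codegen.*`) representation looks like for matrix multiplication (`mm_fused` and `mm`) and for a simple reduction (`reduce`). Each example shows the Rise input followed by its intermediate lowered form.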

mm_fused

// Fused matrix multiplication on 2048x2048 f32 matrices: each output element
// is produced by a single reduceSeq that multiply-accumulates the zipped rows,
// so no intermediate product array is built (contrast with @mm below, which
// first maps a multiply over the zipped rows and then reduces).
// NOTE(review): %bcol ranges over the rows of %inB, so this computes
// A * B^T unless B is passed in already transposed — confirm intent.
func @mm_fused(%outArg:memref<2048x2048xf32>, %inA:memref<2048x2048xf32>, %inB:memref<2048x2048xf32>) {
    %A = rise.in %inA : !rise.array<2048, array<2048, scalar<f32>>>
    %B = rise.in %inB : !rise.array<2048, array<2048, scalar<f32>>>

    // Outer lambda: one row of A -> one row of the result.
    %m1fun = rise.lambda (%arow : !rise.array<2048, scalar<f32>>) -> !rise.array<2048, scalar<f32>> {
        // Inner lambda: one row of B -> one scalar (dot product with %arow).
        // Its declared result type matches the scalar it rise.returns.
        %m2fun = rise.lambda (%bcol : !rise.array<2048, scalar<f32>>) -> !rise.scalar<f32> {

            //Zipping
            %zipFun = rise.zip #rise.nat<2048> #rise.scalar<f32> #rise.scalar<f32>
            %zippedArrays = rise.apply %zipFun, %arow, %bcol

            //Reduction: computes fst(%tuple) * snd(%tuple) + %acc for each zipped pair
            %reductionLambda = rise.lambda (%tuple : !rise.tuple<scalar<f32>, scalar<f32>>, %acc : !rise.scalar<f32>) -> !rise.scalar<f32> {

                %fstFun = rise.fst #rise.scalar<f32> #rise.scalar<f32>
                %sndFun = rise.snd #rise.scalar<f32> #rise.scalar<f32>

                %fst = rise.apply %fstFun, %tuple
                %snd = rise.apply %sndFun, %tuple

                %result = rise.embed(%fst, %snd, %acc) {
                       %product = mulf %fst, %snd :f32
                       %result = addf %product, %acc : f32
                       rise.return %result : f32
                }
                rise.return %result : !rise.scalar<f32>
            }

            %initializer = rise.literal #rise.lit<0.0>
            %reduceFun = rise.reduceSeq {to = "loop"}  #rise.nat<2048> #rise.tuple<scalar<f32>, scalar<f32>> #rise.scalar<f32>
            %result = rise.apply %reduceFun, %reductionLambda, %initializer, %zippedArrays

            rise.return %result : !rise.scalar<f32>
        }
        // Map %m2fun over the rows of B (#nat #in #out): each array<2048, f32>
        // row becomes one scalar, giving one array<2048, f32> result row.
        %m2 = rise.mapSeq {to = "loop"}  #rise.nat<2048> #rise.array<2048, scalar<f32>> #rise.scalar<f32>
        %result = rise.apply %m2, %m2fun, %B
        rise.return %result : !rise.array<2048, scalar<f32>>
    }
    // Map %m1fun over the rows of A: each row becomes one row of the result.
    %m1 = rise.mapSeq {to = "loop"}  #rise.nat<2048> #rise.array<2048, scalar<f32>> #rise.array<2048, scalar<f32>>
    %result = rise.apply %m1, %m1fun, %A
    rise.out %outArg <- %result
    return
}
        |       Lowering to the intermediate form (shown for debugging only; this is not the final result of the lowering pass)
        |           rise.codegen.*
        V
// Intermediate rise.codegen.* form of @mm_fused (debug dump, generic op syntax).
// The two mapSeq ops became the %arg3 / %arg4 scf.for loops and reduceSeq
// became the innermost %arg5 loop. The output element %6 is assigned the
// literal 0.0 and then used directly as the reduction accumulator, so no
// temporary buffer is allocated.
func @mm_fused(%arg0: memref<2048x2048xf32>, %arg1: memref<2048x2048xf32>, %arg2: memref<2048x2048xf32>) {
  %0 = "rise.in"(%arg1) : (memref<2048x2048xf32>) -> !rise.array<2048, array<2048, scalar<f32>>>
  %1 = "rise.in"(%arg2) : (memref<2048x2048xf32>) -> !rise.array<2048, array<2048, scalar<f32>>>
  // The output memref is viewed as a Rise array so it can be indexed like the inputs.
  %2 = "rise.codegen.cast"(%arg0) : (memref<2048x2048xf32>) -> !rise.array<2048, array<2048, scalar<f32>>>
  %c0 = constant 0 : index
  %c0_0 = constant 0 : index
  %c2048 = constant 2048 : index
  %c1 = constant 1 : index
  // Outer map: row %arg3 of A (%3) and of the output (%4).
  scf.for %arg3 = %c0_0 to %c2048 step %c1 {
    %3 = "rise.codegen.idx"(%0, %arg3) : (!rise.array<2048, array<2048, scalar<f32>>>, index) -> !rise.array<2048, scalar<f32>>
    %4 = "rise.codegen.idx"(%2, %arg3) : (!rise.array<2048, array<2048, scalar<f32>>>, index) -> !rise.array<2048, scalar<f32>>
    %c0_1 = constant 0 : index
    %c0_2 = constant 0 : index
    %c2048_3 = constant 2048 : index
    %c1_4 = constant 1 : index
    // Inner map: row %arg4 of B (%5); %6 is output element [%arg3][%arg4].
    scf.for %arg4 = %c0_2 to %c2048_3 step %c1_4 {
      %5 = "rise.codegen.idx"(%1, %arg4) : (!rise.array<2048, array<2048, scalar<f32>>>, index) -> !rise.array<2048, scalar<f32>>
      %6 = "rise.codegen.idx"(%4, %arg4) : (!rise.array<2048, scalar<f32>>, index) -> !rise.scalar<f32>
      %7 = "rise.codegen.zip"(%3, %5) : (!rise.array<2048, scalar<f32>>, !rise.array<2048, scalar<f32>>) -> !rise.array<2048, tuple<scalar<f32>, scalar<f32>>>
      %c0_5 = constant 0 : index
      // Initialize the output element %6 with the reduction's 0.0 literal.
      %8 = "rise.embed"() ( {
        %cst = constant 0.000000e+00 : f32
        "rise.return"(%cst) : (f32) -> ()
      }) : () -> !rise.scalar<f32>
      "rise.codegen.assign"(%8, %6) : (!rise.scalar<f32>, !rise.scalar<f32>) -> ()
      %c0_6 = constant 0 : index
      %c2048_7 = constant 2048 : index
      %c1_8 = constant 1 : index
      // reduceSeq: %6 = fst * snd + %6 for every zipped pair.
      scf.for %arg5 = %c0_6 to %c2048_7 step %c1_8 {
        %9 = "rise.codegen.idx"(%7, %arg5) : (!rise.array<2048, tuple<scalar<f32>, scalar<f32>>>, index) -> !rise.tuple<scalar<f32>, scalar<f32>>
        %10 = "rise.codegen.fst"(%9) : (!rise.tuple<scalar<f32>, scalar<f32>>) -> f32
        %11 = "rise.codegen.snd"(%9) : (!rise.tuple<scalar<f32>, scalar<f32>>) -> f32
        // NOTE(review): operands mix f32 (%10, %11) with the !rise.scalar<f32>
        // accumulator %6, while the region's block arguments are all f32.
        %12 = "rise.embed"(%10, %11, %6) ( {
        ^bb0(%arg6: f32, %arg7: f32, %arg8: f32):  // no predecessors
          %13 = mulf %arg6, %arg7 : f32
          %14 = addf %13, %arg8 : f32
          "rise.return"(%14) : (f32) -> ()
        }) : (f32, f32, !rise.scalar<f32>) -> !rise.scalar<f32>
        "rise.codegen.assign"(%12, %6) : (!rise.scalar<f32>, !rise.scalar<f32>) -> ()
      }
    }
  }
  return
}

mm

// Unfused matrix multiplication on 4x4 f32 matrices: for every row %arow of A
// and row %bcol of B, the zipped pairs are first multiplied element-wise with
// mapPar (producing an intermediate 4-element array) and then summed with
// reduceSeq.
// NOTE(review): %bcol ranges over the rows of %inB, so this computes
// A * B^T unless B is passed in already transposed — confirm intent.
func @mm(%outArg:memref<4x4xf32>, %inA:memref<4x4xf32>, %inB:memref<4x4xf32>) {
    %A = rise.in %inA : !rise.array<4, array<4, scalar<f32>>>
    %B = rise.in %inB : !rise.array<4, array<4, scalar<f32>>>

    // Outer lambda: one row of A -> one row of the result.
    %f1 = rise.lambda (%arow : !rise.array<4, scalar<f32>>) -> !rise.array<4, scalar<f32>> {
        // Inner lambda: one row of B -> one scalar (dot product with %arow).
        // Its declared result type matches the scalar it rise.returns.
        %f2 = rise.lambda (%bcol : !rise.array<4, scalar<f32>>) -> !rise.scalar<f32> {

            //Zipping
            %zipFun = rise.zip #rise.nat<4> #rise.scalar<f32> #rise.scalar<f32>
            %zippedArrays = rise.apply %zipFun, %arow, %bcol

            //Multiply: fst(%floatTuple) * snd(%floatTuple) for each zipped pair
            %f = rise.lambda (%floatTuple : !rise.tuple<scalar<f32>, scalar<f32>>) -> !rise.scalar<f32> {
                %fstFun = rise.fst #rise.scalar<f32> #rise.scalar<f32>
                %sndFun = rise.snd #rise.scalar<f32> #rise.scalar<f32>

                %fst = rise.apply %fstFun, %floatTuple
                %snd = rise.apply %sndFun, %floatTuple
                %result = rise.embed(%fst, %snd) {
                    %result = mulf %fst, %snd : f32
                    rise.return %result : f32
                }

                rise.return %result : !rise.scalar<f32>
            }
            %map = rise.mapPar #rise.nat<4> #rise.tuple<scalar<f32>, scalar<f32>> #rise.scalar<f32>
            %multipliedArray = rise.apply %map, %f, %zippedArrays

            //Reduction: sum the 4 products, starting from 0.0
            %reductionAdd = rise.lambda (%summand0 : !rise.scalar<f32>, %summand1 : !rise.scalar<f32>) -> !rise.scalar<f32> {
                %result = rise.embed(%summand0, %summand1) {
                       %result = addf %summand0, %summand1 : f32
                       rise.return %result : f32
                }
                rise.return %result : !rise.scalar<f32>
            }
            %initializer = rise.literal #rise.lit<0.0>
            %reduceFun = rise.reduceSeq #rise.nat<4> #rise.scalar<f32> #rise.scalar<f32>
            %result = rise.apply %reduceFun, %reductionAdd, %initializer, %multipliedArray

            rise.return %result : !rise.scalar<f32>
        }
        // Map %f2 over the rows of B (#nat #in #out): each array<4, f32> row
        // becomes one scalar, giving one array<4, f32> result row.
        %m2 = rise.mapPar #rise.nat<4> #rise.array<4, scalar<f32>> #rise.scalar<f32>
        %result = rise.apply %m2, %f2, %B
        rise.return %result : !rise.array<4, scalar<f32>>
    }
    // Map %f1 over the rows of A: each row becomes one row of the result.
    %m1 = rise.mapPar #rise.nat<4> #rise.array<4, scalar<f32>> #rise.array<4, scalar<f32>>
    %result = rise.apply %m1, %f1, %A
    rise.out %outArg <- %result
    return
}
        |       Lowering to the intermediate form (shown for debugging only; this is not the final result of the lowering pass)
        |           rise.codegen.*
        V
// Intermediate rise.codegen.* form of @mm (debug dump, generic op syntax).
// The two mapPar ops over A and B became the %arg3 / %arg4 scf.for loops.
// Inside them, the element-wise multiply writes into a temporary 4-element
// buffer (%7), which a second loop then sums into the output element (%6).
func @mm(%arg0: memref<4x4xf32>, %arg1: memref<4x4xf32>, %arg2: memref<4x4xf32>) {
  %0 = "rise.in"(%arg1) : (memref<4x4xf32>) -> !rise.array<4, array<4, scalar<f32>>>
  %1 = "rise.in"(%arg2) : (memref<4x4xf32>) -> !rise.array<4, array<4, scalar<f32>>>
  // The output memref is viewed as a Rise array so it can be indexed like the inputs.
  %2 = "rise.codegen.cast"(%arg0) : (memref<4x4xf32>) -> !rise.array<4, array<4, scalar<f32>>>
  %c0 = constant 0 : index
  %c4 = constant 4 : index
  %c1 = constant 1 : index
  // Outer map: row %arg3 of A (%3) and of the output (%4).
  scf.for %arg3 = %c0 to %c4 step %c1 {
    %3 = "rise.codegen.idx"(%0, %arg3) : (!rise.array<4, array<4, scalar<f32>>>, index) -> !rise.array<4, scalar<f32>>
    %4 = "rise.codegen.idx"(%2, %arg3) : (!rise.array<4, array<4, scalar<f32>>>, index) -> !rise.array<4, scalar<f32>>
    %c0_0 = constant 0 : index
    %c4_1 = constant 4 : index
    %c1_2 = constant 1 : index
    // Inner map: row %arg4 of B (%5); %6 is output element [%arg3][%arg4].
    scf.for %arg4 = %c0_0 to %c4_1 step %c1_2 {
      %5 = "rise.codegen.idx"(%1, %arg4) : (!rise.array<4, array<4, scalar<f32>>>, index) -> !rise.array<4, scalar<f32>>
      %6 = "rise.codegen.idx"(%4, %arg4) : (!rise.array<4, scalar<f32>>, index) -> !rise.scalar<f32>
      // Temporary buffer for the mapPar products. It is allocated once per
      // %arg4 iteration; no matching dealloc appears in this dump.
      %7 = "rise.embed"() ( {
        %10 = alloc() : memref<4xf32>
        "rise.return"(%10) : (memref<4xf32>) -> ()
      }) : () -> !rise.array<4, scalar<f32>>
      %8 = "rise.codegen.zip"(%3, %5) : (!rise.array<4, scalar<f32>>, !rise.array<4, scalar<f32>>) -> !rise.array<4, tuple<scalar<f32>, scalar<f32>>>
      %c0_3 = constant 0 : index
      %c4_4 = constant 4 : index
      %c1_5 = constant 1 : index
      // mapPar (multiply): %7[%arg5] = fst * snd of the %arg5-th zipped pair.
      scf.for %arg5 = %c0_3 to %c4_4 step %c1_5 {
        %10 = "rise.codegen.idx"(%8, %arg5) : (!rise.array<4, tuple<scalar<f32>, scalar<f32>>>, index) -> !rise.tuple<scalar<f32>, scalar<f32>>
        %11 = "rise.codegen.idx"(%7, %arg5) : (!rise.array<4, scalar<f32>>, index) -> !rise.scalar<f32>
        %12 = "rise.codegen.fst"(%10) : (!rise.tuple<scalar<f32>, scalar<f32>>) -> f32
        %13 = "rise.codegen.snd"(%10) : (!rise.tuple<scalar<f32>, scalar<f32>>) -> f32
        %14 = "rise.embed"(%12, %13) ( {
        ^bb0(%arg6: f32, %arg7: f32):  // no predecessors
          %15 = mulf %arg6, %arg7 : f32
          "rise.return"(%15) : (f32) -> ()
        }) : (f32, f32) -> !rise.scalar<f32>
        "rise.codegen.assign"(%14, %11) : (!rise.scalar<f32>, !rise.scalar<f32>) -> ()
      }
      %c0_6 = constant 0 : index
      // Initialize the output element %6 with the reduction's 0.0 literal.
      %9 = "rise.embed"() ( {
        %cst = constant 0.000000e+00 : f32
        "rise.return"(%cst) : (f32) -> ()
      }) : () -> !rise.scalar<f32>
      "rise.codegen.assign"(%9, %6) : (!rise.scalar<f32>, !rise.scalar<f32>) -> ()
      %c0_7 = constant 0 : index
      %c4_8 = constant 4 : index
      %c1_9 = constant 1 : index
      // reduceSeq (add): %6 = %7[%arg5] + %6, with %6 used as the accumulator.
      scf.for %arg5 = %c0_7 to %c4_8 step %c1_9 {
        %10 = "rise.codegen.idx"(%7, %arg5) : (!rise.array<4, scalar<f32>>, index) -> !rise.scalar<f32>
        // NOTE(review): both operands are typed !rise.scalar<f32> here while the
        // region's block arguments are f32 (contrast with the f32 operands of the
        // multiply embed above).
        %11 = "rise.embed"(%10, %6) ( {
        ^bb0(%arg6: f32, %arg7: f32):  // no predecessors
          %12 = addf %arg6, %arg7 : f32
          "rise.return"(%12) : (f32) -> ()
        }) : (!rise.scalar<f32>, !rise.scalar<f32>) -> !rise.scalar<f32>
        "rise.codegen.assign"(%11, %6) : (!rise.scalar<f32>, !rise.scalar<f32>) -> ()
      }
    }
  }
  return
}

reduce

// Sums the 1024 f32 elements of %inArg into the scalar memref %outArg, using
// reduceSeq with an addf combiner and a 0.0 initial value.
func @reduce(%outArg:memref<f32>, %inArg:memref<1024xf32>) {
    %array0 = rise.in %inArg : !rise.array<1024, scalar<f32>>

    // Binary combiner: %summand0 + %summand1.
    %reductionAdd = rise.lambda (%summand0 : !rise.scalar<f32>, %summand1 : !rise.scalar<f32>) -> !rise.scalar<f32> {
        %result = rise.embed(%summand0, %summand1) {
            %result = addf %summand0, %summand1 : f32
            rise.return %result : f32
        }
        rise.return %result : !rise.scalar<f32>
    }
    %initializer = rise.literal #rise.lit<0.0>
    // Sequential reduction over all 1024 elements.
    %reduceFun = rise.reduceSeq #rise.nat<1024> #rise.scalar<f32> #rise.scalar<f32>
    %result = rise.apply %reduceFun, %reductionAdd, %initializer, %array0
    rise.out %outArg <- %result
    return
}
        |       Lowering to the intermediate form (shown for debugging only; this is not the final result of the lowering pass)
        |           rise.codegen.*
        V
  // Intermediate rise.codegen.* form of @reduce (debug dump, generic op syntax).
  // The output scalar %1 is initialized to 0.0 and then used directly as the
  // reduceSeq accumulator inside a single scf.for loop.
  func @reduce(%arg0: memref<f32>, %arg1: memref<1024xf32>) {
    %0 = "rise.in"(%arg1) : (memref<1024xf32>) -> !rise.array<1024, scalar<f32>>
    // The output memref is viewed as a Rise scalar so it can be assigned to.
    %1 = "rise.codegen.cast"(%arg0) : (memref<f32>) -> !rise.scalar<f32>
    %c0 = constant 0 : index
    // Initialize the output with the reduction's 0.0 literal.
    %2 = "rise.embed"() ( {
      %cst = constant 0.000000e+00 : f32
      "rise.return"(%cst) : (f32) -> ()
    }) : () -> !rise.scalar<f32>
    "rise.codegen.assign"(%2, %1) : (!rise.scalar<f32>, !rise.scalar<f32>) -> ()
    %c0_0 = constant 0 : index
    %c1024 = constant 1024 : index
    %c1 = constant 1 : index
    // reduceSeq: %1 = %0[%arg2] + %1 for every element.
    scf.for %arg2 = %c0_0 to %c1024 step %c1 {
      %4 = "rise.codegen.idx"(%0, %arg2) : (!rise.array<1024, scalar<f32>>, index) -> !rise.scalar<f32>
      // NOTE(review): operands are typed !rise.scalar<f32> while the region's
      // block arguments are f32.
      %5 = "rise.embed"(%4, %1) ( {
      ^bb0(%arg3: f32, %arg4: f32):  // no predecessors
        %6 = addf %arg3, %arg4 : f32
        "rise.return"(%6) : (f32) -> ()
      }) : (!rise.scalar<f32>, !rise.scalar<f32>) -> !rise.scalar<f32>
      "rise.codegen.assign"(%5, %1) : (!rise.scalar<f32>, !rise.scalar<f32>) -> ()
    }
    return
  }