diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 8513285524..9728499a04 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,6 +1,9 @@ *.cmake @nsmithtt @sdjordjevicTT *.fbs @tapspatel @nsmithtt *.txt @nsmithtt @sdjordjevicTT +*.md @nsmithtt +/docs/ @nsmithtt +/env/ @nsmithtt /.github/ @vmilosevic @tapspatel /include/ttmlir/Conversion/TTIRToTTNN/ @sdjordjevicTT @svuckovicTT @mtopalovicTT @rpavlovicTT @jserbedzijaTT @jnie-TT /include/ttmlir/Conversion/TTNNToEmitC/ @svuckovicTT @rpavlovicTT @sdjordjevicTT @mtopalovicTT @jserbedzijaTT @@ -26,3 +29,4 @@ /test/ttmlir/Silicon/TTNN/optimizer/ @nobradovictt @odjuricicTT /test/unittests/Optimizer @nobradovictt @odjuricicTT /tools/explorer/ @odjuricicTT @nobradovictt @vprajapati-tt +/tools/ @svuckovicTT @mtopalovicTT diff --git a/.github/Dockerfile.base b/.github/Dockerfile.base index c0a01e6d69..e6fc33757c 100644 --- a/.github/Dockerfile.base +++ b/.github/Dockerfile.base @@ -28,7 +28,11 @@ RUN apt-get update && apt-get install -y \ graphviz \ patchelf \ libyaml-cpp-dev \ - libboost-all-dev + libboost-all-dev \ + curl \ + jq \ + sudo \ + gh # Install clang 17 RUN wget https://apt.llvm.org/llvm.sh && \ diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index ade377c06a..68db5d1cff 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -40,6 +40,56 @@ jobs: echo "DOCKER_CI_IMAGE $DOCKER_CI_IMAGE" echo "docker-image=$DOCKER_CI_IMAGE" >> "$GITHUB_OUTPUT" + lint: + needs: build-image + timeout-minutes: 120 + strategy: + fail-fast: false + name: Lint (clang-tidy) + runs-on: ubuntu-latest + container: + image: ${{ needs.build-image.outputs.docker-image }} + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set reusable strings + id: strings + shell: bash + run: | + echo "work-dir=$(pwd)" >> "$GITHUB_OUTPUT" + echo "build-output-dir=$(pwd)/build" >> "$GITHUB_OUTPUT" + echo "install-output-dir=$(pwd)/install" >> "$GITHUB_OUTPUT" + + - name: Git safe dir + run: git config --global --add safe.directory ${{ steps.strings.outputs.work-dir }} + + - name: Configure CMake + shell: bash + run: | + source env/activate + cmake -G Ninja \ + -B ${{ steps.strings.outputs.build-output-dir }} \ + -DCMAKE_CXX_COMPILER=clang++-17 \ + -DCMAKE_C_COMPILER=clang-17 \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX=${{ steps.strings.outputs.install-output-dir }} \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DTTMLIR_ENABLE_RUNTIME=ON \ + -DTTMLIR_ENABLE_RUNTIME_TESTS=ON \ + -DTTMLIR_ENABLE_STABLEHLO=ON \ + -DTTMLIR_ENABLE_OP_MODEL=ON \ + -S ${{ steps.strings.outputs.work-dir }} + + - name: Lint + id: lint + shell: bash + run: | + source env/activate + cmake --build ${{ steps.strings.outputs.build-output-dir }} -- clang-tidy + build-ttmlir: needs: build-image timeout-minutes: 120 @@ -47,8 +97,9 @@ jobs: fail-fast: false matrix: build: [ - {runs-on: ubuntu-latest, enable_perf: OFF, name: "run", ttrt_flags: ""}, - {runs-on: ubuntu-latest, enable_perf: ON, name: "perf", ttrt_flags: ""}, + {runs-on: ubuntu-latest, enable_perf: OFF, enable_op_model: OFF, name: "run", ttrt_flags: ""}, + {runs-on: ubuntu-latest, enable_perf: ON, enable_op_model: OFF, name: "perf", ttrt_flags: ""}, + {runs-on: ubuntu-latest, enable_perf: OFF, enable_op_model: ON, name: "op_model" , ttrt_flags: ""} ] name: Build tt-mlir @@ -66,11 +117,22 @@ jobs: - name: Set reusable strings id: strings shell: bash + env: + job-name: "Build tt-mlir (${{ matrix.build.runs-on }}, ${{ matrix.build.enable_perf }}, ${{ matrix.build.enable_op_model }}, ${{ matrix.build.name }})" run: | echo "work-dir=$(pwd)" >> "$GITHUB_OUTPUT" echo "build-output-dir=$(pwd)/build" >> "$GITHUB_OUTPUT" echo "install-output-dir=$(pwd)/install" >> "$GITHUB_OUTPUT" + # Github job context unfortunately doesn't contain job_id, this is the workaround how to fetch it using GH API + echo "Expected job name: ${{ env.job-name }}" + JOB_ID=$(curl -s -H "Authorization: token ${{ secrets.GH_TOKEN }}" \ + "https://api.github.com/repos/${{ github.repository }}/actions/runs/${{ github.run_id }}/attempts/${{ github.run_attempt }}/jobs" | \ + jq -r '.jobs[] | select(.name | contains("${{ env.job-name }}")) | .id ') + echo "Current job id: $JOB_ID" + echo "job-id=$JOB_ID" >> "$GITHUB_OUTPUT" + echo "test_report_path=report_$JOB_ID.xml" >> "$GITHUB_OUTPUT" + - name: Git safe dir run: git config --global --add safe.directory ${{ steps.strings.outputs.work-dir }} @@ -78,7 +140,7 @@ jobs: uses: hendrikmuhs/ccache-action@v1.2 with: create-symlink: true - key: ${{ matrix.build.runs-on }}-run-ON-perf-${{ matrix.build.enable_perf }}-${{ env.SDK_VERSION }} + key: ${{ matrix.build.runs-on }}-run-ON-perf-${{ matrix.build.enable_perf }}-op_model-${{ matrix.build.enable_op_model }}-${{ env.SDK_VERSION }} # Build project @@ -97,6 +159,7 @@ jobs: -DTTMLIR_ENABLE_RUNTIME_TESTS=ON \ -DTT_RUNTIME_ENABLE_PERF_TRACE=${{ matrix.build.enable_perf }} \ -DTTMLIR_ENABLE_STABLEHLO=ON \ + -DTTMLIR_ENABLE_OP_MODEL=${{ matrix.build.enable_op_model }} \ -S ${{ steps.strings.outputs.work-dir }} - name: Build @@ -106,14 +169,6 @@ jobs: cmake --build ${{ steps.strings.outputs.build-output-dir }} cmake --install ${{ steps.strings.outputs.build-output-dir }} --component Test - - name: Lint - id: lint - shell: bash - if: matrix.build.enable_perf == 'OFF' - run: | - source env/activate - cmake --build ${{ steps.strings.outputs.build-output-dir }} -- clang-tidy - - name: Unique-ify clang-tidy fixes shell: bash if: failure() && steps.lint.outcome == 'failure' @@ -143,18 +198,19 @@ jobs: run: | source env/activate cmake --build ${{ steps.strings.outputs.build-output-dir }} -- check-ttmlir + cp build/test/report.xml ${{ steps.strings.outputs.test_report_path }} - name: Upload Test Report uses: actions/upload-artifact@v4 with: - name: test-reports-${{ matrix.build.runs-on }}-perf-${{ matrix.build.enable_perf }} - path: build/test/report.xml + name: test-reports-${{ matrix.build.runs-on }}-perf-${{ matrix.build.enable_perf }}-op_model-${{ matrix.build.enable_op_model }} + path: ${{ steps.strings.outputs.test_report_path }} - name: Show Test Report uses: mikepenz/action-junit-report@v4 if: success() || failure() with: - report_paths: build/test/report.xml + report_paths: ${{ steps.strings.outputs.test_report_path }} check_name: MLIR Tests # Build and upload ttrt @@ -201,7 +257,7 @@ jobs: run-tests: - timeout-minutes: 30 + timeout-minutes: 45 needs: - build-image - build-ttmlir @@ -214,6 +270,7 @@ jobs: {runs-on: n300, enable_perf: OFF, name: "run", ttrt_flags: "--non-zero"}, {runs-on: n300, enable_perf: ON, name: "perf"}, ] + name: "run-tests (${{ matrix.build.runs-on }}, ${{ matrix.build.enable_perf }}, ${{ matrix.build.name }})" runs-on: - in-service @@ -237,11 +294,23 @@ jobs: - name: Set reusable strings id: strings shell: bash + env: + job-name: "run-tests (${{ matrix.build.runs-on }}, ${{ matrix.build.enable_perf }}, ${{ matrix.build.name }})" run: | echo "work-dir=$(pwd)" >> "$GITHUB_OUTPUT" echo "build-output-dir=$(pwd)/build" >> "$GITHUB_OUTPUT" echo "install-output-dir=$(pwd)/install" >> "$GITHUB_OUTPUT" + # Github job context unfortunately doesn't contain job_id, this is the workaround how to fetch it using GH API + echo "Expected job name: ${{ env.job-name }}" + JOB_ID=$(curl -s -H "Authorization: token ${{ secrets.GH_TOKEN }}" \ + "https://api.github.com/repos/${{ github.repository }}/actions/runs/${{ github.run_id }}/attempts/${{ github.run_attempt }}/jobs" | \ + jq -r '.jobs[] | select(.name | contains("${{ env.job-name }}")) | .id ') + echo "Current job id: $JOB_ID" + + echo "job-id=$JOB_ID" >> "$GITHUB_OUTPUT" + echo "test_report_path=report_$JOB_ID.xml" >> "$GITHUB_OUTPUT" + - name: Git safe dir run: git config --global --add safe.directory ${{ steps.strings.outputs.work-dir }} @@ -303,19 +372,27 @@ jobs: run: | source env/activate ttrt ${{ matrix.build.name }} ${{ matrix.build.ttrt_flags }} ${{ steps.strings.outputs.build-output-dir }}/test/ttmlir/Silicon/TTNN/perf_unit + cp ttrt_report.xml ${{ steps.strings.outputs.test_report_path }} - - name: Upload ttrt test report + - name: Upload ttrt test report json if: always() uses: actions/upload-artifact@v4 with: name: ${{ matrix.build.runs-on }}_${{ matrix.build.name }}_results.json path: ${{ matrix.build.name }}_results.json + - name: Upload Test Report xml + uses: actions/upload-artifact@v4 + if: success() || failure() + with: + name: test-reports-${{ matrix.build.runs-on }}-${{ matrix.test_group_id }} + path: ${{ steps.strings.outputs.test_report_path }} + - name: Show Test Report uses: mikepenz/action-junit-report@v4 if: success() || failure() with: - report_paths: ttrt_report.xml + report_paths: ${{ steps.strings.outputs.test_report_path }} check_name: TTRT ${{ matrix.build.runs-on }} ${{ matrix.build.name }} Tests run-ttrt-tests: @@ -346,6 +423,7 @@ jobs: - /opt/tt_metal_infra/provisioning/provisioning_env:/opt/tt_metal_infra/provisioning/provisioning_env steps: + - uses: actions/checkout@v4 with: fetch-depth: 0 @@ -353,11 +431,22 @@ jobs: - name: Set reusable strings id: strings shell: bash + env: + job-name: "${{ github.job }} (${{ matrix.build.runs-on }}, ${{ matrix.build.enable_perf }}, ${{ matrix.build.name }})" run: | echo "work-dir=$(pwd)" >> "$GITHUB_OUTPUT" echo "build-output-dir=$(pwd)/build" >> "$GITHUB_OUTPUT" echo "install-output-dir=$(pwd)/install" >> "$GITHUB_OUTPUT" + # Github job context unfortunately doesn't contain job_id, this is the workaround how to fetch it using GH API + echo "Expected job name: ${{ env.job-name }}" + JOB_ID=$(curl -s -H "Authorization: token ${{ secrets.GH_TOKEN }}" \ + "https://api.github.com/repos/${{ github.repository }}/actions/runs/${{ github.run_id }}/attempts/${{ github.run_attempt }}/jobs" | \ + jq -r '.jobs[] | select(.name | contains("${{ env.job-name }}")) | .id ') + echo "Current job id: $JOB_ID" + echo "job-id=$JOB_ID" >> "$GITHUB_OUTPUT" + echo "test_report_path=report_$JOB_ID.xml" >> "$GITHUB_OUTPUT" + - name: Git safe dir run: git config --global --add safe.directory ${{ steps.strings.outputs.work-dir }} @@ -411,31 +500,118 @@ jobs: shell: bash run: | source env/activate - pytest -ssv runtime/tools/python/test/test_read.py + pytest -ssv runtime/tools/python/test \ + --junit-xml=${{ steps.strings.outputs.test_report_path }} + + - name: Upload Test Report + uses: actions/upload-artifact@v4 + if: success() || failure() + with: + name: test-reports-${{ matrix.build.runs-on }}-${{ matrix.build.name }} + path: ${{ steps.strings.outputs.test_report_path }} + + - name: Show Test Report + uses: mikepenz/action-junit-report@v4 + if: success() || failure() + with: + report_paths: ${{ steps.strings.outputs.test_report_path }} + check_name: Run ttrt tests + + run-runtime-api-tests: - - name: ttrt query tests + timeout-minutes: 30 + needs: + - build-image + - build-ttmlir + strategy: + fail-fast: false + matrix: + build: [ + {runs-on: n150, enable_perf: OFF, name: "run"}, + ] + + runs-on: + - in-service + - ${{ matrix.build.runs-on }} + + container: + image: ${{ needs.build-image.outputs.docker-image }} + options: --device /dev/tenstorrent/0 + volumes: + - /dev/hugepages:/dev/hugepages + - /dev/hugepages-1G:/dev/hugepages-1G + - /etc/udev/rules.d:/etc/udev/rules.d + - /lib/modules:/lib/modules + - /opt/tt_metal_infra/provisioning/provisioning_env:/opt/tt_metal_infra/provisioning/provisioning_env + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set reusable strings + id: strings + shell: bash + run: | + echo "work-dir=$(pwd)" >> "$GITHUB_OUTPUT" + echo "build-output-dir=$(pwd)/build" >> "$GITHUB_OUTPUT" + echo "install-output-dir=$(pwd)/install" >> "$GITHUB_OUTPUT" + + - name: Git safe dir + run: git config --global --add safe.directory ${{ steps.strings.outputs.work-dir }} + + - name: Use build artifacts + uses: actions/download-artifact@v4 + with: + name: install-artifacts-${{ matrix.build.name }} + path: ${{ steps.strings.outputs.install-output-dir }} + + # This is needed to preserve file permissions + # https://github.com/actions/upload-artifact?tab=readme-ov-file#permission-loss + - name: 'Untar install directory' + shell: bash + working-directory: ${{ steps.strings.outputs.install-output-dir }} + run: tar xvf artifact.tar + + - name: Remove existing whls files + shell: bash + run: | + rm -f *.whl + + - name: Download ttrt run whls + uses: actions/download-artifact@v4 + with: + name: ttrt-whl-${{ matrix.build.name }} + + # Runtime tests currently require ttrt whls to be installed + - name: Install ttrt run whls shell: bash run: | source env/activate - pytest -ssv runtime/tools/python/test/test_query.py + pip show ttrt && pip uninstall -y ttrt + pip install ttrt-${{ env.version }}*.whl --force-reinstall + pip install pytest - - name: ttrt check tests + - name: Generate system descriptor shell: bash run: | source env/activate - pytest -ssv runtime/tools/python/test/test_check.py + ttrt query --save-artifacts - - name: ttrt run tests + - name: Generate tests shell: bash run: | source env/activate - pytest -ssv runtime/tools/python/test/test_run.py + export LD_LIBRARY_PATH="${TTMLIR_TOOLCHAIN_DIR}/lib:${LD_LIBRARY_PATH}" + export SYSTEM_DESC_PATH="${GITHUB_WORKSPACE}/ttrt-artifacts/system_desc.ttsys" + ln -sf ${{ steps.strings.outputs.install-output-dir }} ${{ steps.strings.outputs.build-output-dir }} + llvm-lit -sv ${{ steps.strings.outputs.build-output-dir }}/test - - name: ttrt perf tests + - name: ttnn api tests shell: bash run: | source env/activate - pytest -ssv runtime/tools/python/test/test_perf.py + pytest -ssv runtime/test/python/ttnn/test_runtime_api.py build-and-test-explorer: needs: build-image @@ -472,6 +648,7 @@ jobs: run: | echo "work-dir=$(pwd)" >> "$GITHUB_OUTPUT" echo "build-output-dir=$(pwd)/build" >> "$GITHUB_OUTPUT" + echo "install-output-dir=$(pwd)/install" >> "$GITHUB_OUTPUT" - name: Git safe dir run: git config --global --add safe.directory ${{ steps.strings.outputs.work-dir }} @@ -480,7 +657,7 @@ jobs: uses: hendrikmuhs/ccache-action@v1.2 with: create-symlink: true - key: ${{ matrix.build.runs-on }}-run-ON-perf-${{ matrix.build.enable_perf }}-${{ env.SDK_VERSION }} + key: ${{ matrix.build.runs-on }}-run-ON-perf-${{ matrix.build.enable_perf }}-op_model-${{ matrix.build.enable_op_model }}-${{ env.SDK_VERSION }} - name: Configure CMake shell: bash @@ -496,6 +673,7 @@ jobs: -DTTMLIR_ENABLE_RUNTIME_TESTS=OFF \ -DTT_RUNTIME_ENABLE_PERF_TRACE=${{ matrix.build.enable_perf }} \ -DTTMLIR_ENABLE_STABLEHLO=OFF \ + -DTTMLIR_ENABLE_OP_MODEL=${{ matrix.build.enable_op_model }} \ -S ${{ steps.strings.outputs.work-dir }} - name: Build tt-explorer @@ -509,3 +687,4 @@ jobs: run: | source env/activate pytest tools/explorer/test/run_tests.py + # collect results diff --git a/.github/workflows/issue-last-updated.yml b/.github/workflows/issue-last-updated.yml index 61a235aff2..f79d16c2c5 100644 --- a/.github/workflows/issue-last-updated.yml +++ b/.github/workflows/issue-last-updated.yml @@ -21,6 +21,7 @@ jobs: echo "project_id=PVT_kwDOA9MHEM4AjeTl" >> $GITHUB_ENV echo "field_id=PVTF_lADOA9MHEM4AjeTlzgiiU18" >> $GITHUB_ENV + - name: Get Issue ID id: get_issue_id run: | @@ -31,18 +32,94 @@ jobs: - name: Get Item ID for Issue - id: get_item_by_issue_id + id: get_item_id_by_issue_id run: | - ITEM_ID=$(curl -X POST -H "Authorization: Bearer $GITHUB_TOKEN" \ - -H "Content-Type: application/json" \ - -d '{ - "query": "query($projectId: ID!) { node(id: $projectId) { ... on ProjectV2 { items(first: 100) { nodes { id content { ... on Issue { id } } } } } } }", - "variables": { - "projectId": "'"${{ env.project_id }}"'" - } - }' \ - https://api.github.com/graphql | jq -r '.data.node.items.nodes[] | select(.content.id=="'"${{ env.issue_id }}"'") | .id') - echo "ITEM_ID=$ITEM_ID" >> $GITHUB_ENV + # Initialize variables + CURSOR=null + ITEM_ID="" + + + # Define the GraphQL query as a string + QUERY='query($projectId: ID!, $cursor: String) { + node(id: $projectId) { + ... on ProjectV2 { + items(first: 100, after: $cursor) { + nodes { + id + content { + ... on Issue { + id + } + } + } + pageInfo { + hasNextPage + endCursor + } + } + } + } + }' + + + while : ; do + # Construct JSON payload using jq for proper formatting + JSON_PAYLOAD=$(jq -n \ + --arg query "$QUERY" \ + --arg projectId "${{ env.project_id }}" \ + --arg cursor "$CURSOR" \ + '{ query: $query, variables: { projectId: $projectId, cursor: $cursor }}') + + + # Make the GraphQL request + RESPONSE=$(curl -s -X POST -H "Authorization: Bearer $GITHUB_TOKEN" \ + -H "Content-Type: application/json" \ + -d "$JSON_PAYLOAD" \ + https://api.github.com/graphql) + + + # Debug: print entire response + echo "RESPONSE: $RESPONSE" + + + # Check if the response contains `items` data + ITEMS_DATA=$(echo "$RESPONSE" | jq -r '.data.node.items.nodes' 2>/dev/null) + if [[ "$ITEMS_DATA" == "null" ]]; then + echo "Error: Items data not found. Please check your PROJECT_ID and GITHUB_TOKEN permissions." + break + fi + + + # Parse the item ID if it matches the issue_id + ITEM_ID=$(echo "$RESPONSE" | jq -r --arg issue_id "$issue_id" \ + '.data.node.items.nodes[] | select(.content.id==$issue_id) | .id') + + + # If ITEM_ID is found, output it and stop the loop + if [[ -n "$ITEM_ID" && "$ITEM_ID" != "null" ]]; then + echo "Found ITEM_ID: $ITEM_ID" + echo "ITEM_ID=$ITEM_ID" >> $GITHUB_ENV # Save ITEM_ID to environment for future steps + break + fi + + + # Extract pagination information + HAS_NEXT_PAGE=$(echo "$RESPONSE" | jq -r '.data.node.items.pageInfo.hasNextPage') + CURSOR=$(echo "$RESPONSE" | jq -r '.data.node.items.pageInfo.endCursor') + + + # If no more pages, exit loop + if [[ "$HAS_NEXT_PAGE" != "true" ]]; then + echo "Issue not found in project items." + break + fi + done + + + - name: Use Found ITEM_ID + if: env.ITEM_ID # Only runs if ITEM_ID was set + run: echo "The ITEM_ID is ${{ env.ITEM_ID }}" + - name: Update Project Field run: | diff --git a/.github/workflows/macos-build.yml b/.github/workflows/macos-build.yml index 774feed21c..367c15787e 100644 --- a/.github/workflows/macos-build.yml +++ b/.github/workflows/macos-build.yml @@ -1,8 +1,9 @@ name: Build on macos-latest on: - workflow_dispatch: - workflow_call: + schedule: + - cron: '0 4 * * *' # Runs at 04:00 UTC every day + workflow_dispatch: # Manual trigger env: SDK_VERSION: "0" diff --git a/.github/workflows/nightly-uplift.yml b/.github/workflows/nightly-uplift.yml index a0f6eb5345..54dd758aed 100644 --- a/.github/workflows/nightly-uplift.yml +++ b/.github/workflows/nightly-uplift.yml @@ -5,7 +5,7 @@ name: Nighty Uplift on: schedule: - - cron: '0 8 * * *' # Runs at 08:00 UTC every day + - cron: '0 6 * * *' # Runs at 06:00 UTC every day workflow_dispatch: # Manual trigger jobs: @@ -13,25 +13,30 @@ jobs: runs-on: ubuntu-latest env: - SUBMODULE_PATH: third_party/tt-metal - TT_METAL_VERSION: origin/main + TT_METAL_SUBMODULE_PATH: third_party/tt-metal steps: - - uses: actions/checkout@v4 + with: + submodules: recursive + fetch-depth: 0 + ref: main - - name: Set env variable + - name: Set env variable for today's date run: | echo "TODAY=$(date +'%Y-%m-%d')" >> $GITHUB_ENV - - name: Update tt-metal reference + - name: Fetch latest SHA of tt-metal submodule env: GH_TOKEN: ${{ github.token }} run: | - # Fetch the latest SHA using GitHub CLI - LATEST_SHA=$(gh api repos/tenstorrent/tt-metal/commits/main --jq '.sha') - # Update the third_party/CMakeLists.txt file with the new SHA - sed -i "s/set(TT_METAL_VERSION \".*\")/set(TT_METAL_VERSION \"${LATEST_SHA}\")/" third_party/CMakeLists.txt + LATEST_TT_METAL_VERSION=$(gh api repos/tenstorrent/tt-metal/commits/main --jq '.sha') + echo "LATEST_TT_METAL_VERSION=$LATEST_TT_METAL_VERSION" >> $GITHUB_ENV + + - name: Update tt-metal reference in third_party/CMakeLists.txt + run: | + echo "Updating tt-metal to SHA: ${{ env.LATEST_TT_METAL_VERSION }}" + sed -i "s/set(TT_METAL_VERSION \".*\")/set(TT_METAL_VERSION \"${{ env.LATEST_TT_METAL_VERSION }}\")/" third_party/CMakeLists.txt - name: Create Pull Request uses: peter-evans/create-pull-request@v7 @@ -41,9 +46,9 @@ jobs: committer: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> author: ${{ github.actor }} <${{ github.actor_id }}+${{ github.actor }}@users.noreply.github.com> base: main - commit-message: "Uplift ${{ env.SUBMODULE_PATH }} to ${{ env.SUBMODULE_VERSION }} ${{ env.TODAY }}" - title: "Uplift ${{ env.SUBMODULE_PATH }} to ${{ env.SUBMODULE_VERSION }} ${{ env.TODAY }}" - body: "This PR uplifts the ${{ env.SUBMODULE_PATH }} to the ${{ env.SUBMODULE_VERSION }}" + commit-message: "Uplift ${{ env.TT_METAL_SUBMODULE_PATH }} to ${{ env.LATEST_TT_METAL_VERSION }} ${{ env.TODAY }}" + title: "Uplift ${{ env.TT_METAL_SUBMODULE_PATH }} to ${{ env.LATEST_TT_METAL_VERSION }} ${{ env.TODAY }}" + body: "This PR uplifts the ${{ env.TT_METAL_SUBMODULE_PATH }} to the ${{ env.LATEST_TT_METAL_VERSION }}" labels: uplift delete-branch: true token: ${{ secrets.GH_TOKEN }} @@ -57,8 +62,11 @@ jobs: echo "Pull Request URL - ${{ steps.create-pr.outputs.pull-request-url }}" gh pr review ${{ steps.create-pr.outputs.pull-request-number }} --approve - - name: Enable Pull Request Automerge - if: ${{ steps.create-pr.outputs.pull-request-number }} - run: gh pr merge --squash --auto "${{ steps.create-pr.outputs.pull-request-number }}" - env: - GH_TOKEN: ${{ secrets.GH_TOKEN }} + # Note: Dissable auto-merge for now until we are more confident + # that uplift won't break the downstream projects + # + # - name: Enable Pull Request Automerge + # if: ${{ steps.create-pr.outputs.pull-request-number }} + # run: gh pr merge --squash --auto "${{ steps.create-pr.outputs.pull-request-number }}" + # env: + # GH_TOKEN: ${{ secrets.GH_TOKEN }} diff --git a/.github/workflows/on-pr.yml b/.github/workflows/on-pr.yml index 76999f97df..76a781f886 100644 --- a/.github/workflows/on-pr.yml +++ b/.github/workflows/on-pr.yml @@ -12,9 +12,6 @@ jobs: spdx: uses: ./.github/workflows/spdx.yml secrets: inherit - macos-build: - uses: ./.github/workflows/macos-build.yml - secrets: inherit build-and-test: uses: ./.github/workflows/build-and-test.yml secrets: inherit diff --git a/.github/workflows/on-push.yml b/.github/workflows/on-push.yml index 58dcdc65d6..2d961e2204 100644 --- a/.github/workflows/on-push.yml +++ b/.github/workflows/on-push.yml @@ -12,9 +12,6 @@ jobs: spdx: uses: ./.github/workflows/spdx.yml secrets: inherit - macos-build: - uses: ./.github/workflows/macos-build.yml - secrets: inherit build-and-test: uses: ./.github/workflows/build-and-test.yml secrets: inherit diff --git a/.github/workflows/produce_data.yml b/.github/workflows/produce_data.yml new file mode 100644 index 0000000000..e53ccc0f60 --- /dev/null +++ b/.github/workflows/produce_data.yml @@ -0,0 +1,28 @@ +name: "[internal] Collect workflow data" + +on: + workflow_run: + workflows: # List workflow that we want to collect data for + - "On PR" + - "On push" + - "Build on macos-latest" + - "Build and Test" + types: + - completed + +jobs: + produce-cicd-data: + runs-on: ubuntu-latest + env: + GH_TOKEN: ${{ github.token }} + steps: + - name: Collect CI/CD data + uses: tenstorrent/tt-github-actions/.github/actions/collect_data@main + if: ${{ github.event_name == 'workflow_run' }} + with: + repository: ${{ github.repository }} + run_id: ${{ github.event.workflow_run.id }} + run_attempt: ${{ github.event.workflow_run.run_attempt }} + sftp_host: ${{ secrets.SFTP_CICD_WRITER_HOSTNAME }} + sftp_user: ${{ secrets.SFTP_CICD_WRITER_USERNAME }} + ssh-private-key: ${{ secrets.SFTP_CICD_WRITER_KEY }} diff --git a/.gitignore b/.gitignore index 8663a2ff0e..274c39c1f4 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,8 @@ ttrt-artifacts/* query_results.json run_results.json ttrt_report.xml +cluster_descriptor.yaml + +# TTNN and TTMetal flatbuffers +*.ttnn +*.ttm diff --git a/CMakeLists.txt b/CMakeLists.txt index 54fcc89d47..bebff7a0fd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,6 +8,7 @@ endif() option(TT_RUNTIME_ENABLE_PERF_TRACE "Enable performance mode" OFF) option(TTMLIR_ENABLE_RUNTIME "Enable runtime" OFF) option(TTMLIR_ENABLE_STABLEHLO "Enable StableHLO support" OFF) +option(TTMLIR_ENABLE_OP_MODEL "Enable OpModel support" OFF) if (TTMLIR_ENABLE_STABLEHLO) add_compile_definitions(TTMLIR_ENABLE_STABLEHLO) @@ -20,13 +21,22 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(TTMLIR_ENABLE_BINDINGS_PYTHON ON CACHE BOOL "Enable Python bindings") +if (APPLE) + set(TTMLIR_ENABLE_OP_MODEL OFF) + message(WARNING "TTNNOpModelLib is disabled on Apple platforms. Optimizer will not get true performance.") +endif() + list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_LIST_DIR}/cmake/modules) if (TT_RUNTIME_ENABLE_PERF_TRACE) add_compile_options(-DTRACY_ENABLE=ON) endif() -add_compile_options(-Wall -Wextra -Wpedantic -Werror -Wno-unused-parameter --system-header-prefix=/opt/ttmlir-toolchain) +if (NOT DEFINED ENV{TTMLIR_TOOLCHAIN_DIR}) + message(FATAL_ERROR "TTMLIR_TOOLCHAIN_DIR environment variable not set. Please run 'source env/activate'.") +endif() + +add_compile_options(-Wall -Wextra -Wpedantic -Werror -Wno-unused-parameter --system-header-prefix=ENV{TTMLIR_TOOLCHAIN_DIR}) include(TTMLIRBuildTypes) @@ -40,10 +50,6 @@ set(Python3_EXECUTABLE $ENV{TTMLIR_VENV_DIR}/bin/python3) include(FindMLIR) include(TTMLIRVersion) -if (NOT DEFINED ENV{TTMLIR_TOOLCHAIN_DIR}) - message(FATAL_ERROR "TTMLIR_TOOLCHAIN_DIR environment variable not set. Please run 'source env/activate'.") -endif() - set(TTMLIR_TOOLCHAIN_DIR $ENV{TTMLIR_TOOLCHAIN_DIR}) set(TTMLIR_SOURCE_DIR ${PROJECT_SOURCE_DIR}) set(TTMLIR_BINARY_DIR ${PROJECT_BINARY_DIR}) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000..e39ed6a281 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,22 @@ +# Contributing guidelines for TT-Forge + +## PR Guidelines +### Community contributions +Thank you for your interest in the TT-Forge project we appreciate your support. +For all PRs we have an internal policy listed below which your PR will go through after an initial review has been done. + +The initial review will encompase the following: +* Review the PR for CI / CD Readiness. Includes making sure that the code and PR at a high level makes sense for the project +* Once approved for CI / CD readiness a Tenstorrent developer will kick off our CI/CD pipeline on your behalf. + +### Internal contributions +For internal contributions we have the following guidelines: + +* A 24 hour merge rule exists. The rule is to wait at least 24 hours since the PR was initially opened for review. This gives members of our teams that span the globe opportunity to provide feedback to PRs. + +In addition to the 24 hour rule the following prerequisites for landing PR exist: +* At least 1 reviewer signs off on the change +* Component owner sign offs (github will tell you if this hasn't been met) +* Green CI +* Wait at least 24 hours since opening the PR to give all tagged reviewers a chance to take a look. Or at least comment on the issue that they need more time to review. + * *Rebasing or further changes to the PR do not reset the 24 hour counter.* diff --git a/cmake/modules/LintTools.cmake b/cmake/modules/LintTools.cmake index 28b4b28092..7e56040110 100644 --- a/cmake/modules/LintTools.cmake +++ b/cmake/modules/LintTools.cmake @@ -1,4 +1,12 @@ # clang-tidy setup add_custom_target(clang-tidy-filter-out-external-srcs COMMAND python3 ${TTMLIR_SOURCE_DIR}/tools/scripts/filter-compile-commands.py ${TTMLIR_BINARY_DIR}/compile_commands.json "${TTMLIR_SOURCE_DIR}") -add_custom_target(clang-tidy COMMAND run-clang-tidy.py -p ${PROJECT_BINARY_DIR} -export-fixes clang-tidy-fixes.yaml -warnings-as-errors '*' -extra-arg-before=-DDISABLE_STATIC_ASSERT_TESTS -extra-arg-before=-D__cpp_structured_bindings=202400 DEPENDS clang-tidy-filter-out-external-srcs) +add_custom_target(clang-tidy COMMAND run-clang-tidy.py -p ${PROJECT_BINARY_DIR} -export-fixes clang-tidy-fixes.yaml -warnings-as-errors '*' -extra-arg-before=-DDISABLE_STATIC_ASSERT_TESTS -extra-arg-before=-D__cpp_structured_bindings=202400 + DEPENDS + clang-tidy-filter-out-external-srcs + mlir-headers + mlir-generic-headers + tt-metal-download + tt-metal-configure + FBS_GENERATION +) add_custom_target(clang-format COMMAND git-clang-format) diff --git a/docs/src/SUMMARY.md b/docs/src/SUMMARY.md index beeb35883a..41ca83528c 100644 --- a/docs/src/SUMMARY.md +++ b/docs/src/SUMMARY.md @@ -5,7 +5,7 @@ # User Guide - [Building](./build.md) - - [Internal Build Notes / IRD](./internal-build.md) + - [Docker Notes](./docker-notes.md) - [Tools](./tools.md) - [ttmlir-opt](./ttmlir-opt.md) - [ttmlir-translate](./ttmlir-translate.md) diff --git a/docs/src/dialects-overview.md b/docs/src/dialects-overview.md index e886fb90c1..0dbf5fbed1 100644 --- a/docs/src/dialects-overview.md +++ b/docs/src/dialects-overview.md @@ -3,7 +3,7 @@ Here is a brief overview of the dialects in the project, please refer to the individual dialect documentation for more details.: -- `tt`: Common types such as, `tt.tile`, `tt.layout`, `tt.grid`, etc. and enums such as, data formats, memory spaces, iterator types etc. +- `tt`: Common types such as, `tt.tile`, `tt.metal_layout`, `tt.grid`, etc. and enums such as, data formats, memory spaces, iterator types etc. - `ttir`: A high level dialect that models the tensor compute graph on tenstorrent devices. Accepts `tosa` and `linalg` input. - `ttir.generic`: Generically describe compute work. - `ttir.to_layout`: Convert between different tensor memory layouts and transfer between different memory spaces. diff --git a/docs/src/internal-build.md b/docs/src/docker-notes.md similarity index 72% rename from docs/src/internal-build.md rename to docs/src/docker-notes.md index 11d2fb8642..1674bf2efc 100644 --- a/docs/src/internal-build.md +++ b/docs/src/docker-notes.md @@ -1,21 +1,11 @@ -# Internal Build Notes / IRD - -- When building the runtime we must use Ubuntu 22.04 docker image - - When making an IRD reservation use `--docker-image - yyz-gitlab.local.tenstorrent.com:5005/tenstorrent/infra/ird-ubuntu-22-04-amd64:latest` -- You'll have to manaully install a newer version of cmake, at least 3.22, the easiest way to do this is to `pip install cmake` and make sure this one is in your path -- You'll want LLVM installation to persist IRD reservations, you can achieve this by: - - mkdir /localdev/$USER/ttmlir-toolchain - - When requesting an IRD use `--volumes /localdev/$USER/ttmlir-toolchain:/opt/ttmlir-toolchain` - -## Working with Docker Images +# Working with Docker Images Components: - Dockerfile - Workflow for building Docker image - Project build using Docker image -### Overview +## Overview We use docker images to prepare project enviroment, install dependancies, tooling and prebuild toolchain. Project builds four docker images: @@ -29,11 +19,11 @@ Base image starts with a supported base image (Ubuntu 22.04) and installs depend During the CI Docker build, the project is built and tests are run to ensure that everything is set up correctly. If any dependencies are missing, the Docker build will fail. -### Building the Docker Image using GitHub Actions +## Building the Docker Image using GitHub Actions The GitHub Actions workflow [Build and Publish Docker Image](.github/workflows/build-image.yml) builds the Docker images and uploads them to GitHub Packages at https://github.com/orgs/tenstorrent/packages?repo_name=tt-mlir. We use the git SHA we build from as the tag. -### Building the Docker Image Locally +## Building the Docker Image Locally To test the changes and build the image locally, use the following command: ```bash @@ -43,7 +33,7 @@ docker build -f .github/Dockerfile.ird -build-args FROM_IMAGE=base -t ghcr.io/te docker build -f .github/Dockerfile.ird -build-args FROM_IMAGE=ci -t ghcr.io/tenstorrent/tt-mlir/tt-mlir-ird-ubuntu-22-04:latest . ``` -### Using the Image in GitHub Actions Jobs +## Using the Image in GitHub Actions Jobs The GitHub Actions workflow [Build in Docker](.github/workflows/docker-build.yml) uses a Docker container for building: ```yaml diff --git a/docs/src/specs/device.md b/docs/src/specs/device.md index ae72fe638c..64bc91cfa9 100644 --- a/docs/src/specs/device.md +++ b/docs/src/specs/device.md @@ -135,7 +135,7 @@ the logical device grid: ```mlir tensor<16x3x64x128xf32, - #tt.layout<(d0, d1, d2, d3) -> (d0, d1 * 64 + d2, d3), + #tt.metal_layout<(d0, d1, d2, d3) -> (d0, d1 * 64 + d2, d3), undef, <2x2x4>, memref<8x3x1x!tt.tile<32 x 32, bfp_bf8>, #tt.memory_space> @@ -170,7 +170,7 @@ the logical device grid: ```mlir tensor<256x1024xf32, - #tt.layout<(d0, d1) -> (d0, d1), + #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <4x16>, memref<2x2x!tt.tile<32 x 32, bfp_bf8>, #tt.memory_space> @@ -205,7 +205,7 @@ We can consider the following tensor to map onto this grid: ```mlir tensor<64x256x1024xf32, - #tt.layout<(d0, d1) -> (d0, d1), + #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <2x4x16>, memref<32x2x2x!tt.tile<32 x 32, bfp_bf8>, #tt.memory_space> diff --git a/docs/src/specs/tensor-layout.md b/docs/src/specs/tensor-layout.md index d523f51ed2..52c6931895 100644 --- a/docs/src/specs/tensor-layout.md +++ b/docs/src/specs/tensor-layout.md @@ -33,7 +33,7 @@ been used by the TT dialect to encode the tensor's layout. This looks like: ```mlir tensor<2x3x64x128xf32, - #tt.layout< + #tt.metal_layout< (d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <1x1>, @@ -76,7 +76,7 @@ topics: ### Dimension Collapsing -Probably the most important concept in `tt.layout` is dimension collapsing. +Probably the most important concept in `tt.metal_layout` is dimension collapsing. This is captured by the affine map `linear` property which provides a mapping from tensor dim space to a reduced physical dimensional space. This single-handedly touches on most of the tensor layout goals mentioned at the @@ -106,7 +106,7 @@ to get our remapped offset: This remapped offset `(262, 100)` corresponds to the row and column index of the collapsed physical memory. -By default, the dim range `[0, -1)` is collapsed, but the `tt.layout` contructor +By default, the dim range `[0, -1)` is collapsed, but the `tt.metal_layout` contructor can actually take a programmable range called `collapseIntervals`. `collapseIntervals` is a list of pairs, where each pair is a dim range interval, left inclusive, right exclusive. Let's consider a few examples: @@ -137,7 +137,7 @@ Let's consider the original example again, but on a larger grid than `1x1`, say ```mlir tensor<2x3x64x128xf32, - #tt.layout< + #tt.metal_layout< (d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <2x4>, @@ -173,7 +173,7 @@ Here's a few more example mlir snippets: ```mlir tensor<8x300xf32, - #tt.layout<(d0, d1) -> (d0, d1), + #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x2>, memref<8x150xf32, #tt.memory_space> @@ -181,7 +181,7 @@ tensor<8x300xf32, > tensor<8x96x32xf32, - #tt.layout<(d0, d1, d2) -> (d0 * 96 + d1, d2), + #tt.metal_layout<(d0, d1, d2) -> (d0 * 96 + d1, d2), undef, <2x1>, memref<384x32xf32, #tt.memory_space> @@ -189,7 +189,7 @@ tensor<8x96x32xf32, > tensor<8x96x32xf32, - #tt.layout<(d0, d1, d2) -> (d0 * 96 + d1, d1, d2), + #tt.metal_layout<(d0, d1, d2) -> (d0 * 96 + d1, d1, d2), undef, <2x1x2>, memref<384x96x16xf32, #tt.memory_space> @@ -197,7 +197,7 @@ tensor<8x96x32xf32, > tensor<5x3x2x2x7x32x32xf32, - #tt.layout< + #tt.metal_layout< (d0, d1, d2, d3, d4, d5, d6) -> (d0 * 2688 + d1 * 896 + d2 * 448 + d3 * 224 + d4 * 32 + d5, d4, d5, d6), undef, @@ -226,7 +226,7 @@ A tilized tensor is one with a memref that has a tile element type. Given some tensor with scalar layout: ```mlir tensor<3x64x128xf32, - #tt.layout< + #tt.metal_layout< (d0, d1, d2) -> (d0 * 64 + d1, d2), undef, <3x2>, @@ -238,7 +238,7 @@ tensor<3x64x128xf32, After tilizing we'll have: ```mlir tensor<3x64x128xf32, - #tt.layout< + #tt.metal_layout< (d0, d1, d2) -> (d0 * 64 + d1, d2), undef, <3x2>, @@ -256,7 +256,7 @@ intact. Padding can be a bit of an overloaded term, but in this context it refers to an out of bounds area in the physical memory allocation that has no real tensor data in it. The contents of this area is tracked by `oob_val` and the padding -area can be automatically derived from the attributes of `tt.layout`. +area can be automatically derived from the attributes of `tt.metal_layout`. Padding is a necessary evil that arises when a tensor is not evenly divisible by a grid shape or tile shape. It can also arise due to minimum Noc addressing @@ -265,7 +265,7 @@ requirements. Example of non-divisible grid: ```mlir tensor<53x63xf32, - #tt.layout< + #tt.metal_layout< (d0, d1) -> (d0, d1), undef, <3x2>, @@ -284,7 +284,7 @@ cores and 1 scalar column of padding on the last column of cores. Taking the above example a step further, we could tilize it: ```mlir tensor<53x63xf32, - #tt.layout< + #tt.metal_layout< (d0, d1) -> (d0, d1), undef, <3x2>, @@ -308,7 +308,7 @@ stride between dimensions. Consider tensor (w/ batch dim `2`): ```mlir tensor<2x8x32xf32, - #tt.layout< + #tt.metal_layout< (d0, d1, d2) -> (d0 * 8 + d1, d2), undef, <1x2>, @@ -356,7 +356,7 @@ consider the following example with a 3d grid and `collapseIntervals=[(1, -1)]`. ```mlir tensor<2x3x64x128xf32, - #tt.layout<(d0, d1, d2, d3) -> (d0, d1 * 64 + d2, d3), + #tt.metal_layout<(d0, d1, d2, d3) -> (d0, d1 * 64 + d2, d3), undef, <2x2x4>, memref<1x3x1x!tt.tile<32 x 32, bfp_bf8>, #tt.memory_space> @@ -387,7 +387,7 @@ under the same grid primitive that also divides tensor rows and columns. ## Concerns -- `tt.layout` is deliberately flexible and tries to capture as many problematic +- `tt.metal_layout` is deliberately flexible and tries to capture as many problematic use-cases we've ran into in the past in a single, succinct representation. This flexibility will need to be further constrained by backends to avoid unsupported programming of this attribute. diff --git a/docs/src/ttmlir-translate.md b/docs/src/ttmlir-translate.md index c82f7ee8f0..ba9c69b3c5 100644 --- a/docs/src/ttmlir-translate.md +++ b/docs/src/ttmlir-translate.md @@ -5,15 +5,15 @@ The `ttmlir-translate` translation utility. Unlike `ttmlir-opt` tool which is us ```bash # First, let's run `ttmlir-opt` to convert to proper dialect -./build/bin/ttmlir-opt --ttir-load-system-desc --ttir-layout --convert-ttir-to-ttnn --convert-ttnn-to-emitc test/ttmlir/Dialect/TTNN/simple_multiply.mlir -o c.mlir +./build/bin/ttmlir-opt --ttir-to-emitc-pipeline test/ttmlir/Dialect/TTNN/simple_multiply.mlir -o c.mlir # Now run `ttmlir-translate` to produce C++ code -./build/bin/ttmlir-translate -mlir-to-cpp c.mlir -allow-unregistered-dialect +./build/bin/ttmlir-translate --mlir-to-cpp c.mlir ``` Bonus: These two commands can be piped, to avoid writing a `mlir` file to disk, like so: ```bash -./build/bin/ttmlir-opt --ttir-load-system-desc --ttir-layout --convert-ttir-to-ttnn --convert-ttnn-to-emitc test/ttmlir/Dialect/TTNN/simple_multiply.mlir | ./build/bin/ttmlir-translate -mlir-to-cpp -allow-unregistered-dialect +./build/bin/ttmlir-opt --ttir-to-emitc-pipeline test/ttmlir/Dialect/TTNN/simple_multiply.mlir | ./build/bin/ttmlir-translate -mlir-to-cpp ``` ## Generate flatbuffer file from MLIR diff --git a/env/CMakeLists.txt b/env/CMakeLists.txt index 0f3c26736b..1e26e8ca43 100644 --- a/env/CMakeLists.txt +++ b/env/CMakeLists.txt @@ -54,7 +54,7 @@ ExternalProject_Add( -DLLVM_INSTALL_GTEST=ON -DLLVM_LINK_LLVM_DYLIB=ON -DMLIR_BUILD_MLIR_C_DYLIB=ON - -DMLIR_LINK_MLIR_DYLIB=ON + -DMLIR_LINK_MLIR_DYLIB=OFF -DMLIR_BUILD_MLIR_C_DYLIB=ON # ====================== -DCMAKE_BUILD_TYPE=MinSizeRel diff --git a/include/ttmlir-c/TTAttrs.h b/include/ttmlir-c/TTAttrs.h index fbbe8de4bd..263cd1d8e4 100644 --- a/include/ttmlir-c/TTAttrs.h +++ b/include/ttmlir-c/TTAttrs.h @@ -50,9 +50,9 @@ MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTSystemDescAttrGet( size_t chipCoordsSize, MlirAttribute *chipChannels, size_t chipChannelsSize); -MLIR_CAPI_EXPORTED MlirAttribute -ttmlirTTLayoutAttrGet(MlirContext ctx, MlirAffineMap linear, unsigned oobVal, - MlirAttribute grid, MlirType memref, unsigned memLayout); +MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTMetalLayoutAttrGet( + MlirContext ctx, MlirAffineMap linear, unsigned oobVal, MlirAttribute grid, + MlirType memref, unsigned memLayout); MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTMemorySpaceAttrGet(MlirContext ctx, uint32_t memorySpace); @@ -84,6 +84,9 @@ MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTChipPhysicalCoresAttrGet( MlirAttribute *dram, size_t dramSize, MlirAttribute *eth, size_t ethSize, MlirAttribute *eth_inactive, size_t eth_inactiveSize); +MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTCoreCoordAttrGet(MlirContext ctx, + int64_t y, int64_t x); + #ifdef __cplusplus } #endif diff --git a/include/ttmlir-c/TTNNAttrs.h b/include/ttmlir-c/TTNNAttrs.h index a7f5a8170d..ea3e333c2d 100644 --- a/include/ttmlir-c/TTNNAttrs.h +++ b/include/ttmlir-c/TTNNAttrs.h @@ -5,6 +5,7 @@ #ifndef TTMLIR_C_TTNNATTRS_H #define TTMLIR_C_TTNNATTRS_H +#include "mlir-c/AffineMap.h" #include "ttmlir-c/Dialects.h" #ifdef __cplusplus @@ -44,6 +45,10 @@ MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTNNMeshShapeAttrGet(MlirContext ctx, int64_t y, int64_t x); +MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTNNTTNNLayoutAttrGet( + MlirContext ctx, MlirAffineMap linear, MlirAttribute grid, MlirType memref, + unsigned memLayout); + #ifdef __cplusplus } #endif diff --git a/include/ttmlir/Bindings/Python/TTMLIRModule.h b/include/ttmlir/Bindings/Python/TTMLIRModule.h index 5f2d4e134d..d36529e676 100644 --- a/include/ttmlir/Bindings/Python/TTMLIRModule.h +++ b/include/ttmlir/Bindings/Python/TTMLIRModule.h @@ -60,6 +60,7 @@ void populateTTIRModule(py::module &m); void populateTTKernelModule(py::module &m); void populateTTNNModule(py::module &m); void populateOverridesModule(py::module &m); +void populateOptimizerOverridesModule(py::module &m); void populatePassesModule(py::module &m); } // namespace mlir::ttmlir::python diff --git a/include/ttmlir/Conversion/CMakeLists.txt b/include/ttmlir/Conversion/CMakeLists.txt index 891fa56080..ba6b267836 100644 --- a/include/ttmlir/Conversion/CMakeLists.txt +++ b/include/ttmlir/Conversion/CMakeLists.txt @@ -5,7 +5,9 @@ include_directories(${TTMLIR_SOURCE_DIR}/include) set(LLVM_TARGET_DEFINITIONS Passes.td) if (TTMLIR_ENABLE_STABLEHLO) mlir_tablegen(Passes.h.inc -gen-pass-decls -name TTMLIRConversion -DTTMLIR_ENABLE_STABLEHLO) +add_dependencies(mlir-headers PassesIncGen) else() mlir_tablegen(Passes.h.inc -gen-pass-decls -name TTMLIRConversion) endif() add_public_tablegen_target(TTMLIRConversionPassIncGen) +add_dependencies(mlir-headers TTMLIRConversionPassIncGen) diff --git a/include/ttmlir/Conversion/TosaToTTIR/TosaToTTIR.h b/include/ttmlir/Conversion/TosaToTTIR/TosaToTTIR.h index acd5373c90..5f1feb08b2 100644 --- a/include/ttmlir/Conversion/TosaToTTIR/TosaToTTIR.h +++ b/include/ttmlir/Conversion/TosaToTTIR/TosaToTTIR.h @@ -7,11 +7,15 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir::tt { +void populateTosaToTTIRPatterns(MLIRContext *ctx, RewritePatternSet &patterns, + TypeConverter &typeConverter); + std::unique_ptr> createConvertTosaToTTIRPass(); } // namespace mlir::tt -#endif +#endif // TTMLIR_CONVERSION_TOSATOTTIR_TOSATOTTIR_H diff --git a/include/ttmlir/Dialect/TT/IR/TTOpsEnums.td b/include/ttmlir/Dialect/TT/IR/TTOpsEnums.td index b82c71c3f6..aee19f63c6 100644 --- a/include/ttmlir/Dialect/TT/IR/TTOpsEnums.td +++ b/include/ttmlir/Dialect/TT/IR/TTOpsEnums.td @@ -137,6 +137,7 @@ def TT_OperandConstraintSingleBank : I32BitEnumAttrCaseBit<"SingleBank", 7, "sin def TT_OperandConstraintHeightSharded : I32BitEnumAttrCaseBit<"HeightSharded", 8, "height_sharded">; def TT_OperandConstraintWidthSharded : I32BitEnumAttrCaseBit<"WidthSharded", 9, "width_sharded">; def TT_OperandConstraintBlockSharded : I32BitEnumAttrCaseBit<"BlockSharded", 10, "block_sharded">; +def TT_OperandConstraintSystemScalar : I32BitEnumAttrCaseGroup<"SystemScalar", [TT_OperandConstraintSystem, TT_OperandConstraintScalar], "system_scalar">; def TT_OperandConstraintAnyLayout : I32BitEnumAttrCaseGroup<"AnyLayout", [TT_OperandConstraintNone, TT_OperandConstraintInterleaved, TT_OperandConstraintSingleBank, TT_OperandConstraintHeightSharded, TT_OperandConstraintWidthSharded, TT_OperandConstraintBlockSharded], "any_layout">; def TT_OperandConstraintAny : I32BitEnumAttrCaseGroup<"Any", [TT_OperandConstraintSystem, TT_OperandConstraintDRAM, TT_OperandConstraintL1, TT_OperandConstraintScalar, TT_OperandConstraintTile, TT_OperandConstraintAnyLayout], "any">; def TT_OperandConstraintAnyDevice : I32BitEnumAttrCaseGroup<"AnyDevice", [TT_OperandConstraintDRAM, TT_OperandConstraintL1, TT_OperandConstraintScalar, TT_OperandConstraintTile, TT_OperandConstraintAnyLayout], "any_device">; @@ -155,6 +156,7 @@ def TT_OperandConstraint : I32BitEnumAttr<"OperandConstraint", "TT Operand Const TT_OperandConstraintHeightSharded, TT_OperandConstraintWidthSharded, TT_OperandConstraintBlockSharded, + TT_OperandConstraintSystemScalar, TT_OperandConstraintAnyLayout, TT_OperandConstraintAny, TT_OperandConstraintAnyDevice, @@ -189,6 +191,54 @@ def TT_BufferAccess : I32BitEnumAttr<"BufferAccess", "TT Buffer Access", let cppNamespace = "::mlir::tt"; } +def TT_ReduceType_Sum : I32EnumAttrCase<"Sum", 0, "sum">; +def TT_ReduceType_Mean : I32EnumAttrCase<"Mean", 1, "mean">; +def TT_ReduceType_Max : I32EnumAttrCase<"Max", 2, "max">; +def TT_ReduceType_Min : I32EnumAttrCase<"Min", 3, "min">; +def TT_ReduceType_Std : I32EnumAttrCase<"Std", 4, "std">; +def TT_ReduceType_Var : I32EnumAttrCase<"Var", 5, "var">; + +def TT_ReduceType: I32EnumAttr<"ReduceType", "TT Reduce Type", + [ + TT_ReduceType_Sum, + TT_ReduceType_Mean, + TT_ReduceType_Max, + TT_ReduceType_Min, + TT_ReduceType_Std, + TT_ReduceType_Var, + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tt"; +} + +def TT_MeshShardDirection_FullToShard : I32EnumAttrCase<"FullToShard", 0, "full_to_shard">; +def TT_MeshShardDirection_ShardToFull : I32EnumAttrCase<"ShardToFull", 1, "shard_to_full">; + +def TT_MeshShardDirection: I32EnumAttr<"MeshShardDirection", "TT MeshShardDirection", + [ + TT_MeshShardDirection_FullToShard, + TT_MeshShardDirection_ShardToFull, + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tt"; +} + +def TT_MeshShardType_Manual : I32EnumAttrCase<"Manual", 0, "manual">; +def TT_MeshShardType_Replicate : I32EnumAttrCase<"Replicate", 1, "replicate">; +def TT_MeshShardType_Maximal : I32EnumAttrCase<"Maximal", 2, "maximal">; +def TT_MeshShardType_Devices : I32EnumAttrCase<"Devices", 3, "devices">; + +def TT_MeshShardType: I32EnumAttr<"MeshShardType", "TT MeshShardType", + [ + TT_MeshShardType_Manual, + TT_MeshShardType_Replicate, + TT_MeshShardType_Maximal, + TT_MeshShardType_Devices, + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tt"; +} + def TT_CPURoleHost : I32EnumAttrCase<"Host", 0, "host">; def TT_CPURoleDevice : I32EnumAttrCase<"Device", 1, "device">; diff --git a/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td b/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td index d9ff13164e..d5dc22e28d 100644 --- a/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td +++ b/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td @@ -214,7 +214,7 @@ def TT_SystemDescAttr : TT_Attr<"SystemDesc", "system_desc"> { }]; } -def TT_LayoutAttr : TT_Attr<"Layout", "layout"> { +def TT_MetalLayoutAttr : TT_Attr<"MetalLayout", "metal_layout"> { let summary = "Tensor layout attribute"; let description = [{ The tensor layout attribute captures how tensor data is sharded across a grid of devices, cores, and @@ -241,7 +241,7 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> { Examples: ```mlir tensor<8x300xf32, - #tt.layout<(d0, d1) -> (d0, d1), + #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x2>, memref<8x150xf32, #tt.memory_space> @@ -249,7 +249,7 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> { > tensor<8x96x32xf32, - #tt.layout<(d0, d1, d2) -> (d0 * 96 + d1, d2), + #tt.metal_layout<(d0, d1, d2) -> (d0 * 96 + d1, d2), undef, <2x1>, memref<384x32xf32, #tt.memory_space> @@ -257,7 +257,7 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> { > tensor<8x96x32xf32, - #tt.layout<(d0, d1, d2) -> (d0 * 96 + d1, d1, d2), + #tt.metal_layout<(d0, d1, d2) -> (d0 * 96 + d1, d1, d2), undef, <2x1x2>, memref<384x96x16xf32, #tt.memory_space> @@ -265,7 +265,7 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> { > tensor<5x3x2x2x7x32x32xf32, - #tt.layout< + #tt.metal_layout< (d0, d1, d2, d3, d4, d5, d6) -> (d0 * 2688 + d1 * 896 + d2 * 448 + d3 * 224 + d4 * 32 + d5, d4, d5, d6), undef, @@ -284,7 +284,7 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> { let assemblyFormat = "`<` $linear`,` $oob_val`,` $grid`,` $memref (`,` $mem_layout^)? `>`"; let extraClassDeclaration = [{ - static LayoutAttr get(::mlir::MLIRContext *context, + static MetalLayoutAttr get(::mlir::MLIRContext *context, ArrayRef tensorShape, Type elementType, MemorySpace memorySpace = MemorySpace::System, @@ -292,28 +292,28 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> { ArrayRef> collapseIntervals = {{0, -1}}, OOBVal oobVal = OOBVal::Undef, TensorMemoryLayout memLayout = TensorMemoryLayout::None); - static LayoutAttr get(::mlir::MLIRContext *context, + static MetalLayoutAttr get(::mlir::MLIRContext *context, RankedTensorType ty, MemorySpace memorySpace = MemorySpace::System, GridAttr grid = {}, ArrayRef> collapseIntervals = {{0, -1}}, OOBVal oobVal = OOBVal::Undef, TensorMemoryLayout memLayout = TensorMemoryLayout::None); - static LayoutAttr get(::mlir::MLIRContext *context, + static MetalLayoutAttr get(::mlir::MLIRContext *context, RankedTensorType ty, MemorySpace memorySpace, GridAttr grid, Type elementType, TensorMemoryLayout memLayout = TensorMemoryLayout::None); - LayoutAttr withGrid(::mlir::MLIRContext *context, ArrayRef tensorShape, GridAttr grid, ArrayRef> collapseIntervals = {{0, -1}}); - LayoutAttr withGrid(::mlir::MLIRContext *context, + MetalLayoutAttr withGrid(::mlir::MLIRContext *context, ArrayRef tensorShape, GridAttr grid, ArrayRef> collapseIntervals = {{0, -1}}); + MetalLayoutAttr withGrid(::mlir::MLIRContext *context, RankedTensorType ty, GridAttr grid, ArrayRef> collapseIntervals = {{0, -1}}); - LayoutAttr withElementType(::mlir::MLIRContext *context, Type elementType); - LayoutAttr withMemorySpace(::mlir::MLIRContext *context, MemorySpace memorySpace); - LayoutAttr withMemoryLayout(::mlir::MLIRContext *context, TensorMemoryLayout memLayout); - LayoutAttr withShardShape(::mlir::MLIRContext *context, llvm::SmallVector shardShape); + MetalLayoutAttr withElementType(::mlir::MLIRContext *context, Type elementType); + MetalLayoutAttr withMemorySpace(::mlir::MLIRContext *context, MemorySpace memorySpace); + MetalLayoutAttr withMemoryLayout(::mlir::MLIRContext *context, TensorMemoryLayout memLayout); + MetalLayoutAttr withShardShape(::mlir::MLIRContext *context, llvm::SmallVector shardShape); uint64_t getMemrefSizeBytes() const; MemorySpace getMemorySpace() const; @@ -400,7 +400,7 @@ def TT_DeviceAttr : TT_Attr<"Device", "device", []> { // - DeviceL1: This ends up being exactly the shard size // - DeviceDRAM: Is more nuanced because the whole tensor size gets paged and interleaved between all dram channels, // due to paging and rounding the footprint ends up being close to: the_whole_tensor / num_dram_channels - uint64_t getLayoutSizeBytes(ArrayRef tensorShape, LayoutAttr layout, MemorySpace memorySpace) const; + uint64_t getLayoutSizeBytes(ArrayRef tensorShape, MetalLayoutAttr layout, MemorySpace memorySpace) const; // Returns the footprint size in bytes of the tensor distributed across the given memory space. // Forwards to getLayoutSizeBytes, see comment there for more info. @@ -443,6 +443,20 @@ def TT_ArgumentAllocationAttr : TT_Attr<"ArgumentAllocation", "arg_alloc", []> { let assemblyFormat = "`<` $address `,` $size `,` $memorySpace `>`"; } +def TT_ReduceTypeAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def TT_ReduceTypeArrayAttr : TypedArrayAttrBase; + +def TT_MeshShardDirectionAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def TT_MeshShardTypeAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + //===----------------------------------------------------------------------===// // TT type definitions //===----------------------------------------------------------------------===// diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIRBase.td b/include/ttmlir/Dialect/TTIR/IR/TTIRBase.td index 57f3dc37d3..b71541b8dc 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIRBase.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIRBase.td @@ -6,6 +6,7 @@ #define TTMLIR_TTMLIR_DIALECT_TTIR_TTIRDIALECT_TD include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" //===----------------------------------------------------------------------===// // TTIR dialect definition. @@ -38,6 +39,6 @@ def TTIR_Dialect : Dialect { //===----------------------------------------------------------------------===// class TTIR_Op traits = []> : - Op; + Op; #endif diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index 489fd2faa9..69510f93a4 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -114,8 +114,8 @@ def TTIR_ToLayoutOp : TTIR_Op<"to_layout", [DestinationStyleOpInterface, TTIROpI - Some combination of the above ```llvm - #layout = #tt.layout<8192x128x1, undef, <1x1>, memref<64x128xf32, #system>> - #layout1 = #tt.layout<8192x128x1, undef, <1x1>, memref<64x128xf32, #l1_>> + #layout = #tt.metal_layout<8192x128x1, undef, <1x1>, memref<64x128xf32, #system>> + #layout1 = #tt.metal_layout<8192x128x1, undef, <1x1>, memref<64x128xf32, #l1_>> %1 = "ttir.to_layout"(%arg0, %0) : (tensor<64x128xf32, #layout>, tensor<64x128xf32, #layout1>) -> tensor<64x128xf32, #layout1> ``` }]; @@ -172,8 +172,12 @@ def TTIR_DeallocOp : TTIR_Op<"dealloc"> { // TTIR top level named ops //===----------------------------------------------------------------------===// +def TwoOperands : ParamNativeOpTrait<"NOperands", "2">; +def ThreeOperands : ParamNativeOpTrait<"NOperands", "3">; +def FourOperands : ParamNativeOpTrait<"NOperands", "4">; + class TTIR_ElementwiseOp traits = []> : - TTIR_DPSOp { + TTIR_DPSOp { let description = [{ Base class for elementwise operations. Elementwise operations can take inputs with different shape, @@ -187,7 +191,7 @@ class TTIR_ElementwiseOp traits = []> : } class TTIR_ElementwiseTernaryOp traits = []> : - TTIR_ElementwiseOp { + TTIR_ElementwiseOp { let summary = "Eltwise ternary op."; let description = [{ Eltwise ternary op. @@ -210,7 +214,7 @@ def TTIR_WhereOp: TTIR_ElementwiseTernaryOp<"where"> { } class TTIR_ElementwiseUnaryOp traits = []> : - TTIR_ElementwiseOp { + TTIR_ElementwiseOp { let summary = "Eltwise unary op."; let description = [{ Eltwise unary op. @@ -288,6 +292,20 @@ def TTIR_NegOp: TTIR_ElementwiseUnaryOp<"neg"> { }]; } +def TTIR_TanOp: TTIR_ElementwiseUnaryOp<"tan"> { + let summary = "Eltwise tan op."; + let description = [{ + Eltwise tan operation. + }]; +} + +def TTIR_TanhOp: TTIR_ElementwiseUnaryOp<"tanh"> { + let summary = "Eltwise tanh op."; + let description = [{ + Eltwise tanh operation. + }]; +} + def TTIR_ReciprocalOp : TTIR_ElementwiseUnaryOp<"reciprocal"> { let summary = "Eltwise reciprocal."; let description = [{ @@ -424,7 +442,7 @@ def TTIR_LeakyReluOp : TTIR_ElementwiseUnaryWithFloatParameterOp<"leaky_relu"> { } class TTIR_ElementwiseBinaryOp traits = []> : - TTIR_ElementwiseOp { + TTIR_ElementwiseOp { let summary = "Eltwise binary op."; let description = [{ Eltwise binary op. @@ -502,18 +520,6 @@ def TTIR_LogicalXorOp : TTIR_ElementwiseBinaryOp<"logical_xor"> { }]; } -def TTIR_MaximumOp : TTIR_ElementwiseBinaryOp<"maximum"> { - let summary = "Eltwise maximum OP."; - let description = [{ - Calculates maximum of input tensors' values element-wise and stores result in output tensor. - - Example: - %lhs: [[3, 2, 7], [1, 4, 4]] - %rhs: [[1, 4, 2], [1, 2, 3]] - "ttir.maximum"(%lhs, %rhs, %out) -> %out: [[3, 4, 7], [1, 4, 4]] - }]; -} - def TTIR_MinimumOp : TTIR_ElementwiseBinaryOp<"minimum"> { let summary = "Eltwise minimum OP."; let description = [{ @@ -701,34 +707,56 @@ def TTIR_ConcatOp : TTIR_DPSOp<"concat"> { let hasVerifier = 1; } -def TTIR_BroadcastOp : TTIR_DPSOp<"broadcast"> { - let summary = "Broadcast operation."; - let description = [{ - Broadcast op. - }]; +def TTIR_UpdateCacheOp : TTIR_DPSOp<"update_cache"> { + let summary = "Update static cache tensor."; + let description = [{ + Updates the `cache` tensor in-place with values from `input` at `update_index` and `batch_offset`. + }]; - let arguments = (ins AnyRankedTensor:$input, - AnyRankedTensor:$output, - I64ArrayAttr:$dimension, - TT_OperandConstraintArrayAttr:$operand_constraints); + let arguments = (ins AnyRankedTensor:$cache, + AnyRankedTensor:$input, + AnyRankedTensor:$update_index, + I32Attr:$batch_offset, + TT_OperandConstraintArrayAttr:$operand_constraints); - let results = (outs AnyRankedTensor:$result); + let results = (outs AnyRankedTensor:$result); - let extraClassDeclaration = [{ - MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } - }]; + let extraClassDeclaration = [{ + MutableOperandRange getDpsInitsMutable() { return getCacheMutable(); } + }]; + + let hasVerifier = 1; } -// CCL ops -def TTIR_AllGatherOp : TTIR_DPSOp<"all_gather"> { - let summary = "All gather operation."; +def TTIR_FillCacheOp : TTIR_DPSOp<"fill_cache"> { + let summary = "Fill static cache tensor."; + let description = [{ + Fills the `cache` tensor in-place with values from `input` at `batch_offset`. + }]; + + let arguments = (ins AnyRankedTensor:$cache, + AnyRankedTensor:$input, + I32Attr:$batch_offset, + TT_OperandConstraintArrayAttr:$operand_constraints); + + let results = (outs AnyRankedTensor:$result); + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + MutableOperandRange getDpsInitsMutable() { return getCacheMutable(); } + }]; +} + +def TTIR_BroadcastOp : TTIR_DPSOp<"broadcast"> { + let summary = "Broadcast operation."; let description = [{ - All gather op. + Broadcast op. }]; let arguments = (ins AnyRankedTensor:$input, AnyRankedTensor:$output, - SI32Attr:$dim, + I64ArrayAttr:$dimension, TT_OperandConstraintArrayAttr:$operand_constraints); let results = (outs AnyRankedTensor:$result); @@ -736,8 +764,6 @@ def TTIR_AllGatherOp : TTIR_DPSOp<"all_gather"> { let extraClassDeclaration = [{ MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } }]; - - let hasVerifier = 1; } def TTIR_Conv2dOp : TTIR_DPSOp<"conv2d"> { @@ -899,6 +925,8 @@ def TTIR_ReshapeOp: TTIR_DPSOp<"reshape"> { }]; let hasVerifier = 1; + + let hasFolder = 1; } def TTIR_SliceOp: TTIR_DPSOp<"slice"> { @@ -925,6 +953,33 @@ def TTIR_SliceOp: TTIR_DPSOp<"slice"> { let hasVerifier = 1; } +def TTIR_SelectOp: TTIR_DPSOp<"select"> { + let summary = "Select op."; + let description = [{ + Extracts a sub-tensor (slice) from the input tensor along a specified dimension in few steps defined by the + `begin`, `length`, and `stride` attributes. + The `begin` specifies the start index for the selected dimension of the tensor. + The `length` specifies the number of elements to extract from the input tensor along the selected dimension. + The `stride` specifies the step size for the start index. The default value is 0. 0 means no stride. + }]; + + let arguments = (ins AnyRankedTensor:$input, + AnyRankedTensor:$output, + SI32Attr:$dim, + SI32Attr:$begin, + SI32Attr:$length, + DefaultValuedOptionalAttr:$stride, + TT_OperandConstraintArrayAttr:$operand_constraints); + + let results = (outs AnyRankedTensor:$result); + + let extraClassDeclaration = [{ + MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } + }]; + + let hasVerifier = 1; +} + // ANCHOR: decomposing_an_op_index_ttir def TTIR_IndexOp: TTIR_DPSOp<"index"> { let summary = "Index op."; @@ -1021,6 +1076,48 @@ def TTIR_ClampOp : TTIR_DPSOp<"clamp"> { let hasVerifier = 1; } +def TTIR_ArangeOp : TTIR_Op<"arange"> { + let summary = "Arange operation."; + let description = [{ + Tensor arange operation. + + Produces a tensor with values from `start` to `end` (exclusive) with a step size of `step`, along the dimension specified by `arange_dimension`. + + Examples: + %0 = "ttir.arange"() {start = 0 : i64, end = 5 : i64 step = 1 : i64, arange_dimension = 0 : i64} : () -> tensor<5xi64> + // %0: [0, 1, 2, 3, 4] + + %1 = "ttir.arange"() {start = 0 : i64, end = 10 : i64, step = 2 : i64, arange_dimension = 0 : i64} : () -> tensor<5xf32> + // %1: [0.0, 2.0, 4.0, 6.0, 8.0] + + %2 = "ttir.arange"() {start = 0 : i64, end = 5 : i64, step = 1 : i64, arange_dimension = 0 : i64} : () -> tensor<5x3xi64> + // %2: [ + [0, 0, 0], + [1, 1, 1], + [2, 2, 2], + [3, 3, 3], + [4, 4, 4] + ] + + %3 = "ttir.arange"() {start = 0 : i64, end = 3 : i64, step = 1 : i64, arange_dimension = 1 : i64} : () -> tensor<5x3xi64> + // %3: [ + [0, 1, 2], + [0, 1, 2], + [0, 1, 2], + [0, 1, 2], + [0, 1, 2] + ] + }]; + + let arguments = (ins SI64Attr:$start, + SI64Attr:$end, + SI64Attr:$step, + I64Attr:$arange_dimension); + + let results = (outs AnyRankedTensor:$result); + let hasVerifier = 1; +} + def TTIR_ConstantOp : TTIR_Op<"constant", [ConstantLike, AllShapesMatch<["value", "result"]>]> { let summary = "Constant op."; @@ -1064,6 +1161,34 @@ def TTIR_FillOp : TTIR_DPSOp<"fill", [AllShapesMatch<["value", "result"]>]> { }]; } +def TTIR_LinearOp : TTIR_DPSOp<"linear"> { + let summary = "Linear transformation of inputs."; + let description = [{ + Produces the matmul of tensors `a` and `b` with optional addition with `bias`. + + Example: + %a = tensor.empty() : () -> tensor<10x64x32xbf16> + %b = tensor.empty() : () -> tensor<32x128xbf16> + %bias = tensor.empty() : () -> tensor<128xbf16> + %output = tensor.empty() : () -> tensor<10x64x128xbf16> + %0 = "ttir.linear"(%a, %b, %bias, %output) : (tensor<10x64x32xbf16>, tensor<32x128xbf16>, tensor<128xbf16>, tensor<10x64x128xbf16>) -> tensor<10x64x128xbf16> + }]; + + let arguments = (ins AnyRankedTensor:$a, + AnyRankedTensor:$b, + Optional:$bias, + AnyRankedTensor:$output, + TT_OperandConstraintArrayAttr:$operand_constraints); + + let results = (outs AnyRankedTensor:$result); + + let extraClassDeclaration = [{ + MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } + }]; + + let hasVerifier = 1; +} + // ANCHOR: adding_an_op_matmul_ttir def TTIR_MatmulOp : TTIR_DPSOp<"matmul"> { let summary = "Matrix multiply operation."; @@ -1099,11 +1224,10 @@ class TTIR_GenericElementwiseUnaryOp traits = []> : void buildGenericRegion(::mlir::OpBuilder &opBuilder, ::mlir::Block* block); std::pair<::mlir::ArrayAttr, ::mlir::ArrayAttr> getIndexingMaps(Builder &builder) { - assert(getNumOperands() == 2 && "Input and output operand must have the same rank"); - assert(sameRank(getOperands()) && - "Elementwise unary op must have only one input and one output operand."); + assert(sameRank(getOperation()->getOperands()) && + "Input and output operand must have the same rank"); - auto rank = mlir::cast(getOperand(0).getType()).getRank(); + auto rank = mlir::cast(getOperation()->getOperand(0).getType()).getRank(); SmallVector indexingMaps(2, builder.getMultiDimIdentityMap(rank)); SmallVector iteratorTypes( @@ -1112,19 +1236,6 @@ class TTIR_GenericElementwiseUnaryOp traits = []> : return {builder.getAffineMapArrayAttr(indexingMaps), builder.getArrayAttr(iteratorTypes)}; } - - static bool sameRank(mlir::OperandRange operands) { - if (operands.empty()) { - return true; - } - auto rank = mlir::cast(operands[0].getType()).getRank(); - for (auto operand : operands) { - if (mlir::cast(operand.getType()).getRank() != rank) { - return false; - } - } - return true; - } }]; } @@ -1144,29 +1255,16 @@ class TTIR_GenericElementwiseBinaryOp traits = []> void buildGenericRegion(::mlir::OpBuilder &opBuilder, ::mlir::Block* block); std::pair<::mlir::ArrayAttr, ::mlir::ArrayAttr> getIndexingMaps(Builder &builder) { - assert(sameRank(getOperands()) && + assert(sameRank(getOperation()->getOperands()) && "For now all operands must have the same rank"); - auto rank = mlir::cast(getOperand(0).getType()).getRank(); - SmallVector indexingMaps(getNumOperands(), + auto rank = mlir::cast(getOperation()->getOperand(0).getType()).getRank(); + SmallVector indexingMaps(getOperation()->getNumOperands(), builder.getMultiDimIdentityMap(rank)); SmallVector iteratorTypes( rank, builder.getAttr(IteratorType::Parallel)); return {builder.getAffineMapArrayAttr(indexingMaps), builder.getArrayAttr(iteratorTypes)}; } - - static bool sameRank(mlir::OperandRange operands) { - if (operands.empty()) { - return true; - } - auto rank = mlir::cast(operands[0].getType()).getRank(); - for (auto operand : operands) { - if (mlir::cast(operand.getType()).getRank() != rank) { - return false; - } - } - return true; - } }]; } @@ -1191,6 +1289,53 @@ def TTIR_DivOp : TTIR_GenericElementwiseBinaryOp<"div"> { }]; } +def TTIR_MaximumOp : TTIR_GenericElementwiseBinaryOp<"maximum"> { + let summary = "Eltwise maximum."; + let description = [{ + Calculates maximum of input tensors' values element-wise and stores result in output tensor. + + Example: + %lhs: [[3, 2, 7], [1, 4, 4]] + %rhs: [[1, 4, 2], [1, 2, 3]] + "ttir.maximum"(%lhs, %rhs, %out) -> %out: [[3, 4, 7], [1, 4, 4]] + }]; +} + +//===----------------------------------------------------------------------===// + +def TTIR_ScatterOp: TTIR_DPSOp<"scatter"> { + let summary = "Scatter operation"; + let description = [{ + Produces a 'result' tensor which are equal to `input` tensor except that + several slices specified by `scatter_indices` are updated with the values + `updates`. + }]; + + let arguments = (ins AnyRankedTensor:$input, + AnyRankedTensor:$scatter_indices, + AnyRankedTensor:$update, + DenseI32ArrayAttr:$update_window_dims, + DenseI32ArrayAttr:$inserted_window_dims, + DenseI32ArrayAttr:$input_batching_dims, + DenseI32ArrayAttr:$scatter_indices_batching_dims, + DenseI32ArrayAttr:$scatter_dims_to_operand_dims, + I32Attr:$index_vector_dim, + BoolAttr:$indices_are_sorted, + BoolAttr:$unique_indices, + AnyRankedTensor:$output, + TT_OperandConstraintArrayAttr:$operand_constraints); + + let regions = (region SizedRegion<1>:$update_computation); + + let results = (outs AnyRankedTensor:$result); + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } + }]; +} + //===----------------------------------------------------------------------===// // TTIR region ops (ops that may appear inside of ttir.generic region) //===----------------------------------------------------------------------===// @@ -1220,4 +1365,102 @@ def TTIR_YieldOp : TTIR_Op<"yield", [Pure, ReturnLike, Terminator]> { let arguments = (ins Variadic:$values); } +//===----------------------------------------------------------------------===// +// TTIR ccl ops +//===----------------------------------------------------------------------===// + +def TTIR_AllGatherOp : TTIR_DPSOp<"all_gather"> { + let summary = "All gather operation."; + let description = [{ + All gather op. + }]; + + let arguments = (ins AnyRankedTensor:$input, + AnyRankedTensor:$output, + SI32Attr:$dim, + TT_OperandConstraintArrayAttr:$operand_constraints); + + let results = (outs AnyRankedTensor:$result); + + let extraClassDeclaration = [{ + MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } + }]; + + let hasVerifier = 1; +} + +def TTIR_AllReduceOp : TTIR_DPSOp<"all_reduce"> { + let summary = "AllReduce operation."; + let description = [{ + AllReduce op. + }]; + + let arguments = (ins + Variadic:$inputs, + AnyRankedTensor:$output, + I64ElementsAttr:$replica_groups, + SI32Attr:$dim, + OptionalAttr:$channel_handle, + UnitAttr:$use_global_device_ids, + TT_ReduceTypeAttr:$reduce_type, + TT_OperandConstraintArrayAttr:$operand_constraints + ); + + let results = (outs Variadic:$results); + + let extraClassDeclaration = [{ + MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } + }]; + + let hasVerifier = 1; +} + +def TTIR_MeshShardOp : TTIR_DPSOp<"mesh_shard"> { + let summary = "Mesh shard operation."; + let description = [{ + MeshShard op shards the inputs (FullToShard) or concatnates the outputs (ShardToFull) for ccl ops. + + shard_direction attribute determines whether to shard or concat. + + shard_type attribute determines how to shard or concat. + manual: no sharding + replicate: all devices have identical data + maximal: only one device contains full data + devices: shard_shape determines sharded dimensions + + For example, on 2x4 mesh hardware, following op shards arg0 to 8 slices, row divided by 2 + and col divided by 4. + + %1 = "ttir.mesh_shard"(%arg0, %0) < + {... shard_direction = #tt.shard_direction, + shard_shape = #tt.grid<2x4>, + shard_type = #tt.shard_type}> : (tensor<8192x784xf32>, ...) -> tensor<4096x196xf32> + + On the other hand, this op concatnates %4 to single tensor by concatnating + one of the top row tensor with one of the bottom row tensor. + + %6 = "ttir.mesh_shard"(%4, %5) < + {..., shard_direction = #tt.shard_direction, + shard_shape = #tt.grid<2x1>, + shard_type = #tt.shard_type}> : (tensor<4096x16384xf32>, ...) -> tensor<8192x16384xf32> + }]; + + let arguments = (ins + AnyRankedTensor:$input, + AnyRankedTensor:$output, + TT_MeshShardTypeAttr:$shard_type, + TT_MeshShardDirectionAttr:$shard_direction, + TT_GridAttr:$shard_shape, + TT_OperandConstraintArrayAttr:$operand_constraints + ); + + let results = (outs AnyRankedTensor:$result); + + let extraClassDeclaration = [{ + MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } + }]; + + let hasVerifier = 1; +} + #endif diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h b/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h index 1d88e8a657..01b6772972 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h @@ -12,7 +12,7 @@ namespace mlir { namespace tt { namespace ttir { namespace detail { -mlir::LogicalResult verifyElementwiseOp(mlir::Operation *op); +mlir::LogicalResult verifyBroadcastable(mlir::Operation *op); } // namespace detail } // namespace ttir } // namespace tt diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td b/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td index cbc0056737..a130332f0d 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td @@ -64,11 +64,13 @@ def TTIROpInterface : OpInterface<"TTIROp"> { ]; } -def TTIR_ElementwiseOpInterface : OpInterface<"ElementwiseOp"> { +def TTIR_Broadcastable : OpInterface<"Broadcastable"> { let cppNamespace = "::mlir::tt::ttir"; + let dependentTraits = [AttrSizedOperandSegments]; + let verify = [{ - return detail::verifyElementwiseOp($_op); + return detail::verifyBroadcastable($_op); }]; } @@ -105,6 +107,20 @@ def TTIR_GenericRegionOpInterface : OpInterface<"GenericRegionOp"> { /*methodBody=*/"", /*defaultImplementation=*/"" >, + StaticInterfaceMethod< + /*desc=*/[{ + Return if the given operands have the same rank. + }], + /*retTy=*/"bool", + /*methodName=*/"sameRank", + /*args=*/(ins "::mlir::OperandRange":$operands), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return llvm::all_equal(llvm::map_range(operands, [](Value operand) { + return mlir::cast(operand.getType()).getRank(); + })); + }] + > ]; } diff --git a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td index 63ccb0d28a..b6269f7153 100644 --- a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td +++ b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td @@ -112,4 +112,17 @@ def TTIRLoadSystemDesc: Pass<"ttir-load-system-desc", "::mlir::ModuleOp"> { ]; } +def TTIRBroadcastFold: Pass<"ttir-broadcast-fold", "::mlir::ModuleOp"> { + let summary = "Broadcast operation is folded to all the consumers."; + let description = [{ + This pass walks through the graph and folds all broadcast instructions since broadcast is supported implicitly by backend ops. + Example: + %1 = "ttir.broadcast"(%arg0) (tensor<1xf32>) -> tensor<512xf32> + %2 = "ttir.maximum"(%1, %arg1) (tensor<512xf32>, tensor<512xf32>) -> tensor<512xf32> + + This above broadcast is folded as: + %1 = "ttir.maximum"(%arg0, %arg1) (tensor<1xf32>, tensor<512xf32>) -> tensor<512xf32> + }]; +} + #endif diff --git a/include/ttmlir/Dialect/TTKernel/IR/TTKernelOps.td b/include/ttmlir/Dialect/TTKernel/IR/TTKernelOps.td index 4b6da4b683..ed70d7da68 100644 --- a/include/ttmlir/Dialect/TTKernel/IR/TTKernelOps.td +++ b/include/ttmlir/Dialect/TTKernel/IR/TTKernelOps.td @@ -180,6 +180,15 @@ def TTKernel_MulOp : TTKernel_Op<"mul"> { let arguments = (ins I32:$dst_index); } +def TTKernel_MaxOp : TTKernel_Op<"max"> { + let summary = "Max operation"; + let description = [{ + Max operation + }]; + + let arguments = (ins I32:$dst_index); +} + def TTKernel_MatmulOp : TTKernel_Op<"matmul"> { let summary = "Matmul operation"; let description = [{ @@ -333,6 +342,31 @@ def TTKernel_ReduceTileOp : TTKernel_Op<"reduce_tile"> { TTKernel_ReduceDimAttr:$reduce_dim); } +//===----------------------------------------------------------------------===// +// TTKernel SFPU operations +//===----------------------------------------------------------------------===// + +def TTKernel_MaxTilesInitOp : TTKernel_Op<"max_tile_init"> { + let summary = "Short init function"; + let description = [{ + Must be run before max_tile. + }]; + + let arguments = (ins); +} + +def TTKernel_MaxTilesOp : TTKernel_Op<"max_tile"> { + let summary = "Max operation"; + let description = [{ + Performs element-wise computation of maximum operation + DST[dst0_index] <- max(DST[dst0_index], DST[dst1_index]) + on DST register operands. The DST register buffer must be in + acquired state via *tile_regs_acquire* call. + }]; + + let arguments = (ins I32:$dst0_index, I32:$dst1_index); +} + //===----------------------------------------------------------------------===// // TTKernel CB operations //===----------------------------------------------------------------------===// @@ -503,6 +537,68 @@ def TTKernel_NocAsyncWriteBarrierOp : TTKernel_Op<"noc_async_write_barrier"> { }]; } +//===----------------------------------------------------------------------===// +// TTKernel Multicast NoC operations +//===----------------------------------------------------------------------===// + +def TTKernel_GetNocMulticastAddrOp : TTKernel_Op<"get_noc_multicast_addr"> { + let summary = "GetNocMulticastAddr"; + let description = [{ + GetNocMulticastAddr + }]; + + let arguments = (ins I32:$noc_x_start, I32:$noc_y_start, I32:$noc_x_end, I32:$noc_y_end, I32:$addr, Optional:$noc); + let results = (outs TTKernel_NocAddr:$mcastNocAddr); +} + +def TTKernel_NocAsyncWriteMulticastOnePacketOp : TTKernel_Op<"noc_async_write_multicast_one_packet"> { + let summary = "NocAsyncWriteMulticastOnePacket"; + let description = [{ + NocAsyncWriteMulticastOnePacket + this issues only a single packet with size <= NOC_MAX_BURST_SIZE (ie maximum packet size) + }]; + + let arguments = (ins I32:$srcLocalL1Addr, TTKernel_NocAddr:$dstNocAddrMulticast, I32:$size, I32:$num_dests, OptionalAttr:$linked, OptionalAttr:$multicast_path_reserve, Optional:$noc); +} + +def TTKernel_NocAsyncWriteMulticastOp : TTKernel_Op<"noc_async_write_multicast"> { + let summary = "NocAsyncWriteMulticast"; + let description = [{ + Initiates an asynchronous write from a source address in L1 memory on the + Tensix core executing this function call to a rectangular destination grid. + The destinations are specified using a uint64_t encoding referencing an + on-chip grid of nodes located at NOC coordinate range + (x_start,y_start,x_end,y_end) and a local address created using + *get_noc_multicast_addr* function. Also, *see noc_async_write_barrier*. + + The destination nodes can only be a set of Tensix cores + L1 memory address. + The destination nodes must form a rectangular grid. The destination L1 + memory address must be the same on all destination nodes. + + With this API, the multicast sender cannot be part of the multicast + destinations. If the multicast sender has to be in the multicast + destinations (i.e. must perform a local L1 write), the other API variant + *noc_async_write_multicast_loopback_src* can be used. + + Note: The number of destinations needs to be non-zero. Besides that, + there is no restriction on the number of destinations, i.e. the + multicast destinations can span the full chip. However, as mentioned + previously, the multicast source cannot be part of the destinations. So, the + maximum number of destinations is 119. + }]; + + let arguments = (ins I32:$srcLocalL1Addr, TTKernel_NocAddr:$dstNocAddrMulticast, I32:$size, I32:$num_dests, OptionalAttr:$linked, OptionalAttr:$multicast_path_reserve, Optional:$noc); +} + +def TTKernel_NocAsyncWriteMulticastLoopbackSrcOp : TTKernel_Op<"noc_async_write_multicast_loopback_src"> { + let summary = "NocAsyncWriteMulticastLoopbackSrc"; + let description = [{ + NocAsyncWriteMulticastLoopbackSrc + }]; + + let arguments = (ins I32:$srcLocalL1Addr, TTKernel_NocAddr:$dstNocAddrMulticast, I32:$size, I32:$num_dests, OptionalAttr:$linked, OptionalAttr:$multicast_path_reserve, Optional:$noc); +} + //===----------------------------------------------------------------------===// // TTKernel Misc operations //===----------------------------------------------------------------------===// diff --git a/include/ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h b/include/ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h index 3c57ca66b7..b8aee2e4ea 100644 --- a/include/ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h +++ b/include/ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h @@ -58,7 +58,7 @@ class L1ChainConfig { std::unordered_set &memReconfigEdges); bool isEmpty() { return opL1MemSpecs.empty(); } - void addOpL1MemSpec(OpL1MemSpec &&spec) { + void addOpL1MemSpec(OpL1MemSpec spec) { assert(state == L1ChainState::InBuild); l1ChainedOps.insert(spec.op); opL1MemSpecs.push_back(std::move(spec)); diff --git a/include/ttmlir/Dialect/TTNN/Analysis/L1InterleavedPolicy.h b/include/ttmlir/Dialect/TTNN/Analysis/L1InterleavedPolicy.h index f453e9a1d3..2392cd7c9c 100644 --- a/include/ttmlir/Dialect/TTNN/Analysis/L1InterleavedPolicy.h +++ b/include/ttmlir/Dialect/TTNN/Analysis/L1InterleavedPolicy.h @@ -8,10 +8,43 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h" #include "ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysisPolicy.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" namespace mlir::tt::ttnn { class L1InterleavedPolicy : public MemoryLayoutAnalysisPolicy { +public: + struct OpMemSpec { + TTNNLayoutAttr layout; + // Minimum L1 memory usage required for scheduling the op + // given the layouts of all the ops that are already scheduled. + // + uint64_t requiredL1Usage; + }; + + // This struct is holding information about the greedily choosen + // configuration of the @baseOp: 1) layouts and 2) precedence. + // + // The @layouts represents the mapping between the op and its choosen + // layout. All the ops that are included in the @layouts map must be + // either @baseOp or its operand with legal L1 Interleaved output layout + // at the moment of analyzing the @baseOp. + // + // The @precedence represents the order of the op's operands in which they + // should be scheduled. Only op's operands that are included in the @layouts + // map are included in the @precedence. + // + struct OpConfig { + Operation *baseOp; + llvm::DenseMap layouts; + llvm::SmallVector precedence; + }; + + struct L1Usage { + size_t outputL1Usage; + size_t requiredL1Usage; + }; + public: L1InterleavedPolicy( Operation *rootOp, std::vector &l1ChainConfigs, @@ -22,7 +55,71 @@ class L1InterleavedPolicy : public MemoryLayoutAnalysisPolicy { : MemoryLayoutAnalysisPolicy(rootOp, l1ChainConfigs, legalLayouts, schedule, usableL1CacheSize) {} + /** + * Retrieve the greedy OpConfig for the given base operation + * and its opsL1Usage map. + * + * @param baseOp The base operation for which the greedy configuration is + * being determined. + * @param opsL1Usage A map between the operation and its output L1 usage. All + * operations included in the opsL1Usage map must be either the baseOp or its + * operand with a legal L1 Interleaved output layout at the time of analyzing + * the baseOp. + * @return The greedy OpConfig for the baseOp. + */ + OpConfig getGreedyConfig(Operation *baseOp, + llvm::DenseMap &opsL1Usage); + void run() final; + +private: + // Check if the op is analyzable. Op is analyzable if it has at least one + // legal layout. + bool isAnalyzable(Operation *op); + + // Fetch op's DRAM layout from legalLayouts. + bool hasDRAMBufferType(Operation *op); + TTNNLayoutAttr getDRAMLayout(Operation *op); + + // Fetch op's L1 Interleaved layout from legalLayouts. + bool hasL1BufferType(Operation *op); + TTNNLayoutAttr getL1InterleavedLayout(Operation *op); + + size_t getAvailableL1CacheSize() const { + // Figure out this const based on exec data, but will be replaced + // with API. + // + constexpr float tensorL1UsageCap = 0.75; + return tensorL1UsageCap * usableL1CacheSize; + } + + // Precedence schedule map for each operation. It contains the order + // in which operands need to be executed for each op. + llvm::DenseMap> precedenceMap; + + llvm::DenseSet visitedOps; + void buildSchedule(mlir::Operation *op, func::FuncOp &func) { + + // Schedule all the precedents of the current operation + // + visitedOps.insert(op); + for (Operation *precedent : precedenceMap[op]) { + if (!visitedOps.count(precedent)) { + buildSchedule(precedent, func); + } + } + + (*schedule)[func].push_back(op); + } + + void constructSchedule(func::FuncOp &func) { + func->walk([&](Operation *op) { + if (op->hasTrait()) { + Operation *outputOp = op->getOperand(0).getDefiningOp(); + buildSchedule(outputOp, func); + } + }); + } }; } // namespace mlir::tt::ttnn diff --git a/include/ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.h b/include/ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.h index e8b6038154..bc6284c3a0 100644 --- a/include/ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.h +++ b/include/ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.h @@ -6,10 +6,10 @@ #define TTMLIR_DIALECT_TTNN_ANALYSIS_MEMORYLAYOUTANALYSIS_H #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h" #include "ttmlir/Dialect/TTNN/Analysis/Edge.h" #include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h" #include "ttmlir/Dialect/TTNN/Analysis/TTNNAnalysis.h" +#include "ttmlir/Dialect/TTNN/Utils/MemoryLayoutAnalysisParams.h" namespace mlir::tt::ttnn { diff --git a/include/ttmlir/Dialect/TTNN/IR/CMakeLists.txt b/include/ttmlir/Dialect/TTNN/IR/CMakeLists.txt index cfd65fe8db..fbf68f69dd 100644 --- a/include/ttmlir/Dialect/TTNN/IR/CMakeLists.txt +++ b/include/ttmlir/Dialect/TTNN/IR/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_doc(TTNNBase TTNNDialect src/autogen/md/Dialect/ -gen-dialect-doc) add_mlir_doc(TTNNOps TTNNOp src/autogen/md/Dialect/ -gen-op-doc) add_mlir_interface(TTNNOpModelInterface) +add_mlir_interface(TTNNWorkaroundInterface) set(LLVM_TARGET_DEFINITIONS TTNNOpsEnums.td) mlir_tablegen(TTNNOpsEnums.h.inc -gen-enum-decls) diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNBase.td b/include/ttmlir/Dialect/TTNN/IR/TTNNBase.td index b1821c8f1b..ea77d6795b 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNBase.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNBase.td @@ -6,7 +6,9 @@ #define TTMLIR_TTMLIR_DIALECT_TTNN_TTNNDIALECT_TD include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" include "ttmlir/Dialect/TTNN/IR/TTNNOpModelInterface.td" +include "ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.td" //===----------------------------------------------------------------------===// // TTNN dialect definition. @@ -44,6 +46,9 @@ def TTNN_Dialect : Dialect { //===----------------------------------------------------------------------===// class TTNN_Op traits = []> : - Op; + Op; + +class TTNN_InplaceOp traits = []> : + Op, TTNN_OpModelInterface, TTNN_WorkaroundInterface])>; #endif diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.h b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.h index e66fab65a3..457c7722bb 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.h +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.h @@ -18,6 +18,7 @@ #include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOpModelInterface.h.inc" +#include "ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.h" #define GET_OP_CLASSES #include "ttmlir/Dialect/TTNN/IR/TTNNOps.h.inc" diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td index 910ed7dfd9..ed914cb555 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td @@ -174,6 +174,17 @@ def TTNN_AbsOp : TTNN_ElementwiseUnaryOp<"abs"> { let description = [{ Eltwise absolute operation. }]; + + let extraClassDeclaration = [{ + MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); } + wa::TTNNOperandsWorkarounds getOperandsWorkarounds() { + wa::TTNNOperandWorkarounds tileLayoutWorkaround = wa::TTNNOperandWorkarounds(Layout::Tile); + return wa::TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds() + .addInputOperandWorkaround(tileLayoutWorkaround) + .addInputOperandWorkaround(tileLayoutWorkaround) + .addOutputOperandWorkaround(tileLayoutWorkaround); + } + }]; } def TTNN_CbrtOp : TTNN_ElementwiseUnaryOp<"cbrt"> { @@ -251,6 +262,20 @@ def TTNN_NegOp : TTNN_ElementwiseUnaryOp<"neg"> { }]; } +def TTNN_TanOp: TTNN_ElementwiseUnaryOp<"tan"> { + let summary = "Eltwise tan op."; + let description = [{ + Eltwise tan operation. + }]; +} + +def TTNN_TanhOp: TTNN_ElementwiseUnaryOp<"tanh"> { + let summary = "Eltwise tanh op."; + let description = [{ + Eltwise tanh operation. + }]; +} + def TTNN_ReciprocalOp : TTNN_ElementwiseUnaryOp<"reciprocal"> { let summary = "Eltwise reciprocal."; let description = [{ @@ -325,7 +350,7 @@ def TTNN_Expm1Op: TTNN_ElementwiseUnaryOp<"expm1"> { }]; } -class TTIR_ElementwiseUnaryWithFloatParameterOp traits = []> : +class TTNN_ElementwiseUnaryWithFloatParameterOp traits = []> : TTNN_ElementwiseUnaryOp { let summary = "Eltwise unary op with the float parameter."; let description = [{ @@ -345,7 +370,7 @@ class TTIR_ElementwiseUnaryWithFloatParameterOp tra ]; } -def TTIR_LeakyReluOp : TTIR_ElementwiseUnaryWithFloatParameterOp<"leaky_relu"> { +def TTNN_LeakyReluOp : TTNN_ElementwiseUnaryWithFloatParameterOp<"leaky_relu"> { let summary = "Eltwise leaky relu operation."; let description = [{ The Leaky ReLU (Rectified Linear Unit) operation computes an element-wise @@ -552,6 +577,33 @@ def TTNN_EmbeddingOp : TTNN_NamedDPSOp<"embedding"> { let hasVerifier = 1; } +def TTNN_UpdateCacheOp : TTNN_InplaceOp<"update_cache"> { + let summary = "Update static cache tensor."; + let description = [{ + Updates the `cache` tensor in-place with values from `input` at `update_index` and `batch_offset`. + }]; + + let arguments = (ins Arg:$cache, + AnyRankedTensor:$input, + AnyRankedTensor:$update_index, + I32Attr:$batch_offset); + + let hasVerifier = 1; +} + +def TTNN_FillCacheOp : TTNN_InplaceOp<"fill_cache"> { + let summary = "Fill static cache tensor."; + let description = [{ + Fills the `cache` tensor in-place with values from `input` at `batch_offset`. + }]; + + let arguments = (ins Arg:$cache, + AnyRankedTensor:$input, + I32Attr:$batch_offset); + + let hasVerifier = 1; +} + def TTNN_SoftmaxOp : TTNN_Op<"softmax"> { let summary = "Softmax op."; let description = [{ @@ -636,6 +688,34 @@ def TTNN_SliceOp: TTNN_NamedDPSOp<"slice"> { let hasVerifier = 1; } +def TTNN_LinearOp : TTNN_NamedDPSOp<"linear"> { + let summary = "Linear transformation of inputs."; + + let description = [{ + Produces the matmul of tensors `a` and `b` with optional addition with `bias`. + + Example: + // %a = [[1., 2.]], [2., 1.]] + // %b = [[0., 1.], [1., 0.]] + // %bias = [[1.]] + "ttnn.linear"(%a, %b, %bias, %result) : (tensor<2x2xf16>, tensor<2x2xf16>, tensor<1xf16>, tensor<2x2xf16>) -> tensor<2x2xf16> + // %result = [[3., 2.], [2., 3.]] + }]; + + let arguments = (ins AnyRankedTensor:$a, + AnyRankedTensor:$b, + Optional:$bias, + AnyRankedTensor:$output); + let results = (outs AnyRankedTensor:$result); + + let extraClassDeclaration = [{ + MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } + }]; + + let hasVerifier = 1; +} + + // ANCHOR: adding_an_op_matmul_ttnn def TTNN_MatmulOp : TTNN_NamedDPSOp<"matmul"> { let arguments = (ins AnyRankedTensor:$a, @@ -741,9 +821,7 @@ def TTNN_ClampOp : TTNN_Op<"clamp"> { let hasVerifier = 1; } -// Note: NoMemoryEffect is used to indicate that operation can be removed if it is not used. -// Removal of this operation is done by the dead code elimination pass (RemoveDeadValuesPass). -def TTNN_EmptyOp : TTNN_Op<"empty", [NoMemoryEffect]> { +def TTNN_EmptyOp : TTNN_Op<"empty"> { let summary = "Empty op."; let description = [{ Tensor empty operation @@ -756,9 +834,43 @@ def TTNN_EmptyOp : TTNN_Op<"empty", [NoMemoryEffect]> { OptionalAttr:$memory_config); let results = (outs AnyRankedTensor:$result); + let extraClassDeclaration = [{ + wa::TTNNOperandsWorkarounds getOperandsWorkarounds() { + wa::TTNNOperandWorkarounds rowMajorLayoutWorkaround = wa::TTNNOperandWorkarounds(Layout::RowMajor); + return wa::TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds() + .addOutputOperandWorkaround(rowMajorLayoutWorkaround); + } + }]; + let hasVerifier = 1; } +def TTNN_ArangeOp : TTNN_Op<"arange"> { + let summary = "Arange operation."; + let description = [{ + Tensor arange operation. + + Produces a (1, 1, 1, N)-shaped tensor with values from `start` to `end` (exclusive) with a step size of `step`. + + Examples: + %0 = "ttnn.arange"() {start = 0 : i64, end = 5 : i64 step = 1 : i64} : () -> tensor<1x1x1x5xi64> + // %0: [[[[0, 1, 2, 3, 4]]]] + + %1 = "ttnn.arange"() {start = 0 : i64, end = 10 : i64, step = 2 : i64} : () -> tensor<1x1x1x5xf32> + // %1: [[[[0.0, 2.0, 4.0, 6.0, 8.0]]]] + }]; + + let arguments = (ins I64Attr:$start, + I64Attr:$end, + I64Attr:$step, + OptionalAttr:$dtype, + Optional:$device, + OptionalAttr:$memory_config); + + let results = (outs AnyRankedTensor:$result); + let hasVerifier = 1; +} + def TTNN_FullOp : TTNN_Op<"full"> { let summary = "Full op."; let description = [{ @@ -806,6 +918,13 @@ def TTNN_AllGatherOp: TTNN_Op<"all_gather"> { let hasVerifier = 1; } +def TTNN_ScatterOp: TTNN_ElementwiseBinaryOp<"scatter"> { + let summary = "Scatter op."; + let description = [{ + Embeds the values of the 'update' tensor into 'input' at the given index and puts the value in the 'output' tensor. + }]; +} + def TTNN_ReduceScatterOp: TTNN_Op<"reduce_scatter"> { let summary = "Reduce scatter op."; let description = [{ diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h b/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h index 944157846d..790c49228c 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h @@ -14,14 +14,21 @@ namespace mlir::tt::ttnn { -inline bool isSystemBufferType(mlir::tt::ttnn::BufferType bufferType) { - return bufferType == mlir::tt::ttnn::BufferType::SystemMemory; +inline bool isSystemBufferType(BufferType bufferType) { + return bufferType == BufferType::SystemMemory; } -inline bool isDeviceBufferType(mlir::tt::ttnn::BufferType bufferType) { - return bufferType == mlir::tt::ttnn::BufferType::L1 || - bufferType == mlir::tt::ttnn::BufferType::DRAM || - bufferType == mlir::tt::ttnn::BufferType::L1Small; +inline bool isDeviceBufferType(BufferType bufferType) { + return bufferType == BufferType::L1 || bufferType == BufferType::DRAM || + bufferType == BufferType::L1Small; +} + +inline bool isL1BufferType(BufferType bufferType) { + return bufferType == BufferType::L1; +} + +inline bool isDRAMBufferType(BufferType bufferType) { + return bufferType == BufferType::DRAM; } inline bool isShardedMemoryLayout(TensorMemoryLayout layout) { @@ -29,6 +36,7 @@ inline bool isShardedMemoryLayout(TensorMemoryLayout layout) { layout == TensorMemoryLayout::WidthSharded || layout == TensorMemoryLayout::BlockSharded; } + } // namespace mlir::tt::ttnn #define GET_ATTRDEF_CLASSES diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td index eebe601100..e483b07bf2 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td @@ -81,9 +81,9 @@ def TTNN_MemoryConfigAttr : TTNN_Attr<"MemoryConfig", "memory_config"> { TTNN memory config attribute }]; - let parameters = (ins AttrParameter<"TensorMemoryLayoutAttr", "">:$tensorMemoryLayout, - AttrParameter<"BufferTypeAttr", "">:$bufferType, - AttrParameter<"ShardSpecAttr", "">:$shardSpec); + let parameters = (ins AttrParameter<"BufferTypeAttr", "">:$bufferType, + AttrParameter<"ShardSpecAttr", "">:$shardSpec, + OptionalParameter<"TensorMemoryLayoutAttr">:$tensorMemoryLayout); let assemblyFormat = "`<` params `>`"; @@ -92,6 +92,9 @@ def TTNN_MemoryConfigAttr : TTNN_Attr<"MemoryConfig", "memory_config"> { { return this->getShardSpec().getShardShape().getShape(); } + + MemoryConfigAttr withBufferType(::mlir::MLIRContext *context, BufferType bufferType); + MemoryConfigAttr withMemoryLayout(::mlir::MLIRContext *context, TensorMemoryLayout memLayout); }]; } @@ -109,12 +112,19 @@ def TTNN_TTNNLayoutAttr: TTNN_Attr<"TTNNLayout", "ttnn_layout"> { let summary = "Tensor encoding attribute used for types in ttnn"; let description = [{ Layout attribute in ttnn. This attribute is used to encode different information about tensor memory layout. + Here is how tensor will look like after layout tensor<32x32x64xf32, #ttnn.ttnn_layout> + Lets break down what each parameter means: + - linear: An affine map that defines how the logical tensor dimensions map to physical space. + - grid: The grid shape (of tensix cores) where tensor is divided onto. + - memref: A memref is used to describe shard size and memory space. Shard size is calculated by dividing the tensor size by grid size. + - mem_layout: The layout of the tensor in memory. For tensor on host it should be None. For tensor on device + it can be interleaved or sharded. }]; let parameters = (ins AttrParameter<"AffineMap", "An affine map that defines how the logical tensor dimensions map to a grid shape.">:$linear, AttrParameter<"GridAttr", "The grid shape that this tensor is divided onto.">:$grid, AttrParameter<"MemRefType", "A memref that describes the physical footprint allocation of the shard. It must also have a shape with rank equal to grid.">:$memref, - DefaultValuedParameter<"TensorMemoryLayout", "TensorMemoryLayout::None", "The layout of the tensor in memory.">:$mem_layout); + OptionalParameter<"TensorMemoryLayoutAttr", "TTNN tensor memory layout">:$mem_layout); let assemblyFormat = "`<` $linear`,` $grid`,` $memref (`,` $mem_layout^)? `>`"; let extraClassDeclaration = [{ static TTNNLayoutAttr get(::mlir::MLIRContext *context, @@ -122,38 +132,44 @@ def TTNN_TTNNLayoutAttr: TTNN_Attr<"TTNNLayout", "ttnn_layout"> { Type elementType, BufferType bufferType, GridAttr grid, - TensorMemoryLayout memoryLayout, + TensorMemoryLayoutAttr memoryLayoutAttr = nullptr, + ArrayRef> collapseIntervals = {{0, -1}}); + + TTNNLayoutAttr withGrid(::mlir::MLIRContext *context, ArrayRef tensorShape, GridAttr grid, ArrayRef> collapseIntervals = {{0, -1}}); + TTNNLayoutAttr withGrid(::mlir::MLIRContext *context, + RankedTensorType ty, + GridAttr grid, ArrayRef> collapseIntervals = {{0, -1}}); - uint64_t getShardSizeInBytes() const; - BufferType getBufferType() const; - TTNNLayoutAttr withGrid(::mlir::MLIRContext *context, ArrayRef tensorShape, GridAttr grid, ArrayRef> collapseIntervals = {{0, -1}}); - TTNNLayoutAttr withGrid(::mlir::MLIRContext *context, - RankedTensorType ty, - GridAttr grid, - ArrayRef> collapseIntervals = {{0, -1}}); - TTNNLayoutAttr withElementType(::mlir::MLIRContext *context, Type elementType); - TTNNLayoutAttr withBufferType(::mlir::MLIRContext *context, BufferType bufferType); - TTNNLayoutAttr withMemoryLayout(::mlir::MLIRContext *context, TensorMemoryLayout memLayout); - TTNNLayoutAttr withShardShape(::mlir::MLIRContext *context, llvm::SmallVector shardShape); - - bool isSystemBufferType() const { return ::mlir::tt::ttnn::isSystemBufferType(getBufferType()); } - bool isDeviceBufferType() const { return ::mlir::tt::ttnn::isDeviceBufferType(getBufferType()); } - bool hasShardedTensorMemoryLayout() const; - bool hasShardedL1TensorMemoryLayout() const; - bool hasInterleavedL1TensorMemoryLayout() const; - bool isTiled() const; - Type getElementType() const; - Type getScalarElementType() const; - DataType getDataTypeFromMemRef() const; - uint64_t getElementSizeBytes() const; - int64_t getTensorSizeInBytes(ArrayRef tensorShape, ::mlir::tt::DeviceAttr device) const; - llvm::SmallVector getStride(ArrayRef logicalShape) const; - llvm::SmallVector getPhysicalShape(ArrayRef logicalShape) const; - llvm::SmallVector getShardShape(bool convertTileToScalar = true) const; - AffineMap replaceMemoryMapSymbolsWithShardShape(AffineMap physicalMemoryMap) const; - AffineMap projectOnto(AffineMap linearMap, AffineMap physicalMemoryMap) const; - AffineMap getIdentityTileLinearMap() const; - llvm::SmallVector getTiledShape(ArrayRef logicalTensorShape) const; + TTNNLayoutAttr withElementType(::mlir::MLIRContext *context, Type elementType); + TTNNLayoutAttr withBufferType(::mlir::MLIRContext *context, BufferType bufferType); + TTNNLayoutAttr withMemoryLayout(::mlir::MLIRContext *context, TensorMemoryLayoutAttr memLayoutAttr); + TTNNLayoutAttr withMemoryLayout(::mlir::MLIRContext *context, TensorMemoryLayout memLayout); + TTNNLayoutAttr withShardShape(::mlir::MLIRContext *context, llvm::SmallVector shardShape); + + bool isSystemBufferType() const { return ::mlir::tt::ttnn::isSystemBufferType(getBufferType()); } + bool isDeviceBufferType() const { return ::mlir::tt::ttnn::isDeviceBufferType(getBufferType()); } + bool isTiled() const; + bool hasShardedTensorMemoryLayout() const; + bool hasShardedL1TensorMemoryLayout() const; + bool hasInterleavedL1TensorMemoryLayout() const; + bool hasInterleavedDRAMTensorMemoryLayout() const; + bool hasDRAMBufferType() const; + bool hasL1BufferType() const; + Layout getLayout() const; + std::optional getMemLayoutOpt() const; + Type getElementType() const; + Type getScalarElementType() const; + uint64_t getShardSizeInBytes() const; + BufferType getBufferType() const; + DataType getDataType() const; + uint64_t getElementSizeBytes() const; + int64_t getTensorSizeInBytes(ArrayRef tensorShape, ::mlir::tt::DeviceAttr device) const; + llvm::SmallVector getStride(ArrayRef logicalShape) const; + llvm::SmallVector getShardShape() const; + llvm::SmallVector getScalarShardShape() const; + AffineMap getIdentityTileLinearMap() const; + llvm::SmallVector getTiledShape(ArrayRef logicalTensorShape) const; + AffineMap replaceMemoryMapSymbolsWithShardShape(AffineMap physicalMemoryMap) const; }]; } diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOpsEnums.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOpsEnums.td index 1b580a3a8b..0dfe811965 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOpsEnums.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOpsEnums.td @@ -21,7 +21,6 @@ def TTNN_Layout : I32EnumAttr<"Layout", "TTNN Layout", let cppNamespace = "::mlir::tt::ttnn"; } -def TTNN_TensorMemoryLayout_None : I32EnumAttrCase<"None", 0, "none">; def TTNN_TensorMemoryLayout_Interleaved : I32EnumAttrCase<"Interleaved", 1, "interleaved">; def TTNN_TensorMemoryLayout_SingleBank : I32EnumAttrCase<"SingleBank", 2, "single_bank">; def TTNN_TensorMemoryLayout_HeightSharded : I32EnumAttrCase<"HeightSharded", 3, "height_sharded">; @@ -30,7 +29,6 @@ def TTNN_TensorMemoryLayout_BlockSharded : I32EnumAttrCase<"BlockSharded", 5, "b def TTNN_TensorMemoryLayout : I32EnumAttr<"TensorMemoryLayout", "TTNN Tensor Memory Layout", [ - TTNN_TensorMemoryLayout_None, TTNN_TensorMemoryLayout_Interleaved, TTNN_TensorMemoryLayout_SingleBank, TTNN_TensorMemoryLayout_HeightSharded, diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.h b/include/ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.h new file mode 100644 index 0000000000..a6cdd5c1d7 --- /dev/null +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.h @@ -0,0 +1,18 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 +#ifndef TTMLIR_DIALECT_TTNN_IR_TTNNWORKAROUNDINTERFACE_H +#define TTMLIR_DIALECT_TTNN_IR_TTNNWORKAROUNDINTERFACE_H + +#include "ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h" + +#include "mlir/IR/Operation.h" + +namespace mlir::tt::ttnn::wa { +// Verifies the TTNNWorkaroundInterface +mlir::LogicalResult verifyTTNNWorkaroundInterface(mlir::Operation *op); +} // namespace mlir::tt::ttnn::wa + +#include "ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.h.inc" + +#endif diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.td b/include/ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.td new file mode 100644 index 0000000000..c1ce55cd99 --- /dev/null +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.td @@ -0,0 +1,47 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_TTMLIR_DIALECT_TTNN_IR_TTNN_WORKAROUND_INTERFACE_TD +#define TTMLIR_TTMLIR_DIALECT_TTNN_IR_TTNN_WORKAROUND_INTERFACE_TD + +include "mlir/IR/OpBase.td" + +// This interface is used to specify workarounds for TTNN operations. +def TTNN_WorkaroundInterface : OpInterface<"TTNNWorkaroundInterface"> { + let cppNamespace = "::mlir::tt::ttnn::wa"; + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Returns the workarounds associated with each operand and result of this operation. + If the operation is a Destination-Passing Style (DPS) operation, the same workarounds + must apply to both the DPS initial operands and the operation results. These constraints + are verified through the interface verifier. + + For example, consider the following ttnn operations: + %0 = "ttnn.empty"() : () -> tensor<1x1xf32> + %1 = "ttnn.abs"(%arg0, %0) : (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32> + + In this example: + - The ttnn.abs operation has two input operand workarounds. + - It has one output operand workaround. + - The output workaround must match the workaround for the second input operand, + ensuring consistency as required by the DPS pattern. + }], + /*retTy=*/"TTNNOperandsWorkarounds", + /*methodName=*/"getOperandsWorkarounds", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + // Return default empty workarounds for all input and output operands + return TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds(this->getOperation()); + }] + >, + ]; + + let verify = [{ + return verifyTTNNWorkaroundInterface($_op); + }]; +} + +#endif diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h b/include/ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h new file mode 100644 index 0000000000..4122b0ca03 --- /dev/null +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h @@ -0,0 +1,175 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_DIALECT_TTNN_IR_TTNNWORKAROUNDS_H +#define TTMLIR_DIALECT_TTNN_IR_TTNNWORKAROUNDS_H + +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" + +#include "mlir/IR/Operation.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" + +#include + +namespace mlir::tt::ttnn::wa { +using TensorLayoutWorkaround = std::optional; +using TensorBufferTypeWorkaround = std::optional; +using TensorMemoryLayoutWorkaround = std::optional; + +// Struct that encapsulates operand workarounds. +// It contains tensor layout, tensor buffer type and tensor memory layout +// workarounds. +struct TTNNOperandWorkarounds { + // Tensor layout workaround. + TensorLayoutWorkaround tensorLayoutWorkaround; + + // Tensor buffer type workaround. + TensorBufferTypeWorkaround tensorBufferTypeWorkaround; + + // Tensor memory layout workaround. + TensorMemoryLayoutWorkaround tensorMemoryLayoutWorkaround; + + TTNNOperandWorkarounds() = default; + + // Constructor that takes tensor layout, tensor buffer type and tensor memory. + TTNNOperandWorkarounds( + TensorLayoutWorkaround tensorLayoutWorkaround, + TensorBufferTypeWorkaround tensorBufferTypeWorkaround, + TensorMemoryLayoutWorkaround tensorMemoryLayoutWorkaround) + : tensorLayoutWorkaround(tensorLayoutWorkaround), + tensorBufferTypeWorkaround(tensorBufferTypeWorkaround), + tensorMemoryLayoutWorkaround(tensorMemoryLayoutWorkaround) {} + + // Constructor that takes tensor layout workaround and sets the other + // workarounds to nullopt. + TTNNOperandWorkarounds(TensorLayoutWorkaround tensorLayoutWorkaround) + : TTNNOperandWorkarounds(tensorLayoutWorkaround, std::nullopt, + std::nullopt) {} + + // Constructor that takes tensor buffer type workaround and sets the other + // workarounds to nullopt. + TTNNOperandWorkarounds(TensorBufferTypeWorkaround tensorBufferTypeWorkaround) + : TTNNOperandWorkarounds(std::nullopt, tensorBufferTypeWorkaround, + std::nullopt) {} + + // Constructor that takes tensor memory layout workaround and sets the other + // workarounds to nullopt. + TTNNOperandWorkarounds( + TensorMemoryLayoutWorkaround tensorMemoryLayoutWorkaround) + : TTNNOperandWorkarounds(std::nullopt, std::nullopt, + tensorMemoryLayoutWorkaround) {} + + // Operand workarounds factory methods. + static TTNNOperandWorkarounds createEmptyTTNNOperandWorkarounds(); + + // Equality operator. + bool operator==(const TTNNOperandWorkarounds &rhs) const { + return tensorLayoutWorkaround == rhs.tensorLayoutWorkaround && + tensorBufferTypeWorkaround == rhs.tensorBufferTypeWorkaround && + tensorMemoryLayoutWorkaround == rhs.tensorMemoryLayoutWorkaround; + } + + // Inequality operator. + bool operator!=(const TTNNOperandWorkarounds &rhs) const { + return !(*this == rhs); + } + + // Returns true if any of the workarounds is set. + bool hasAnyWorkaround() const { + return tensorLayoutWorkaround || tensorBufferTypeWorkaround || + tensorMemoryLayoutWorkaround; + } +}; + +// Struct that encapsulates the result of applying the workarounds. +// It contains the target tensor layout, buffer type and tensor memory layout +// results and a flag indicating whether the workarounds were applied. +struct WorkaroundResult { + // Target tensor layout. + std::pair targetTensorLayoutResult; + + // Target tensor buffer type. + std::pair targetTensorBufferTypeResult; + + // Target tensor memory layout. Can be nullopt for tensors on host. + std::pair, bool> + targetTensorMemoryLayoutResult; + + // Returns true if any of the workarounds were applied. + bool modified() const { + return targetTensorLayoutResult.second || + targetTensorBufferTypeResult.second || + targetTensorMemoryLayoutResult.second; + } +}; + +// Apply the operand workarounds to the layout attribute that contains +// tensor layout, buffer type and tensor memory layout arguments. +// Returns the result of applying the workarounds. +WorkaroundResult applyWorkarounds(const TTNNOperandWorkarounds &workaround, + const TTNNLayoutAttr &inputLayoutAttr); + +// Class that encapsulates operands workarounds. +// It contains input and output workarounds for operands. +class TTNNOperandsWorkarounds { +public: + // Returns input operand workarounds. + llvm::ArrayRef getInputOperandWorkarounds() const { + return inputOperandWorkarounds; + } + + // Returns output operand workarounds. + llvm::ArrayRef getOutputOperandWorkarounds() const { + return outputOperandWorkarounds; + } + + // Adds input operand workaround. + TTNNOperandsWorkarounds & + addInputOperandWorkaround(TTNNOperandWorkarounds inputOperandWorkaround) { + inputOperandWorkarounds.emplace_back(inputOperandWorkaround); + return *this; + } + + // Adds output operand workaround. + TTNNOperandsWorkarounds & + addOutputOperandWorkaround(TTNNOperandWorkarounds outputOperandWorkaround) { + outputOperandWorkarounds.emplace_back(outputOperandWorkaround); + return *this; + } + + // Operands workarounds factory method. + static TTNNOperandsWorkarounds + createEmptyTTNNOperandsWorkarounds(int inputSize, int outputSize); + + // Operands workarounds factory method. + static TTNNOperandsWorkarounds createEmptyTTNNOperandsWorkarounds() { + return createEmptyTTNNOperandsWorkarounds(0, 0); + } + + // Operands workarounds factory method. + static TTNNOperandsWorkarounds + createEmptyTTNNOperandsWorkarounds(Operation *op); + +private: + // Default constructor with no workarounds. + TTNNOperandsWorkarounds() {} + + // Constructor that takes input and output workarounds for operands. + TTNNOperandsWorkarounds( + llvm::SmallVector inputOperandWorkarounds, + llvm::SmallVector outputOperandWorkarounds) + : inputOperandWorkarounds(std::move(inputOperandWorkarounds)), + outputOperandWorkarounds(std::move(outputOperandWorkarounds)) {} + + // Workarounds for input operands. + llvm::SmallVector inputOperandWorkarounds; + + // Workarounds for output operands. + llvm::SmallVector outputOperandWorkarounds; +}; + +} // namespace mlir::tt::ttnn::wa + +#endif diff --git a/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h b/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h index 3466f02d75..900e127fb3 100644 --- a/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h +++ b/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h @@ -5,8 +5,8 @@ #ifndef TTMLIR_DIALECT_TTNN_PIPELINES_TTNNPIPELINES_H #define TTMLIR_DIALECT_TTNN_PIPELINES_TTNNPIPELINES_H -#include "ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h" -#include "ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h" +#include "ttmlir/Dialect/TTNN/Utils/MemoryLayoutAnalysisParams.h" +#include "ttmlir/Dialect/TTNN/Utils/PassOverrides.h" #include "mlir/Pass/PassOptions.h" @@ -20,7 +20,7 @@ struct TTIRToTTNNBackendPipelineOptions // configuration for max performance. If this option is false, skip running // Optimizer pass, thus leaving all ops on default configuration. Option optimizerPassEnabled{ - *this, "enable-optimizer", + *this, OptionNames::optimizerPassEnabled, llvm::cl::desc("Determine and set max valid grid for Op execution."), llvm::cl::init(false)}; @@ -38,7 +38,7 @@ struct TTIRToTTNNBackendPipelineOptions // Option, InputLayoutOverrideParser> overrideInputLayout{ - *this, "insert-memreconfig", + *this, OptionNames::overrideInputLayout, llvm::cl::desc( "Manually insert memory reconfig op for specific op's operand."), llvm::cl::init(llvm::StringMap())}; @@ -72,21 +72,21 @@ struct TTIRToTTNNBackendPipelineOptions Option, OutputLayoutOverrideParser> overrideOutputLayout{ - *this, "override-output-layout", + *this, OptionNames::overrideOutputLayout, llvm::cl::desc("Override output tensor layout for specific ops."), llvm::cl::init(llvm::StringMap())}; // If this option is true, run memory layout analysis. // Option memoryLayoutAnalysisEnabled{ - *this, "memory-layout-analysis-enabled", + *this, OptionNames::memoryLayoutAnalysisEnabled, llvm::cl::desc("Enable memory layout optimization."), llvm::cl::init(false)}; // If this option is true, insert memory reconfiguration ops. // Option memReconfigEnabled{ - *this, "memreconfig-enabled", + *this, OptionNames::memReconfigEnabled, llvm::cl::desc("Memory layout reconfiguration pass."), llvm::cl::init(true)}; @@ -94,7 +94,7 @@ struct TTIRToTTNNBackendPipelineOptions // Option memoryLayoutAnalysisPolicy{ - *this, "memory-layout-analysis-policy", + *this, OptionNames::memoryLayoutAnalysisPolicy, llvm::cl::desc("Specify policy for memory layout analysis."), llvm::cl::init(MemoryLayoutAnalysisPolicyType::DFSharding)}; @@ -102,7 +102,7 @@ struct TTIRToTTNNBackendPipelineOptions // against. // Option systemDescPath{ - *this, "system-desc-path", + *this, OptionNames::systemDescPath, llvm::cl::desc( "Pass in a system descriptor flatbuffer to compile against."), llvm::cl::init("")}; @@ -111,19 +111,26 @@ struct TTIRToTTNNBackendPipelineOptions // legal layout analysis. // Option maxLegalLayouts{ - *this, "max-legal-layouts", + *this, OptionNames::maxLegalLayouts, llvm::cl::desc("Override maximum number of sharded layouts for legal " "layout analysis."), llvm::cl::init(64)}; ListOption meshShape{ - *this, "mesh-shape", llvm::cl::desc("Set the multi-device mesh shape.")}; + *this, OptionNames::meshShape, + llvm::cl::desc("Set the multi-device mesh shape.")}; Option rowMajorEnabled{ *this, "row-major-enabled", llvm::cl::desc( "Enable row major layout generation in legal layout analysis."), llvm::cl::init(false)}; + + // Option to enable/disable the workaround pass. + // + Option workaroundPassEnabled{*this, "enable-workaround-pass", + llvm::cl::desc("Enable workaround pass."), + llvm::cl::init(false)}; }; // TTIR to EmitC pipeline options. diff --git a/include/ttmlir/Dialect/TTNN/Transforms/Passes.td b/include/ttmlir/Dialect/TTNN/Transforms/Passes.td index c29d01f7e4..13253d131d 100644 --- a/include/ttmlir/Dialect/TTNN/Transforms/Passes.td +++ b/include/ttmlir/Dialect/TTNN/Transforms/Passes.td @@ -28,4 +28,51 @@ def TTNNLayout : Pass<"ttnn-layout", "::mlir::ModuleOp"> { }]; } +def TTNNWorkarounds : Pass<"ttnn-workaround", "::mlir::ModuleOp"> { + let summary = "Apply TTNN workarounds to the IR."; + let description = [{ + This pass applies necessary TTNN workarounds to the IR in order to create + a valid and functional IR that can be executed on the hardware. + }]; +} + +def TTNNCreateInputGenerators: Pass<"ttnn-create-input-gens", "::mlir::ModuleOp"> { + let summary = "Create input generators for the forward functions."; + let description = [{ + This pass creates input generators for the "forward" functions. It + additionally creates a main function to run the forward function with the + generated inputs. + + The pass is useful for EmitC path. By creating input generators before + converting to Emitc Dialect, followed by transformation to C++ code, the + resulting code won't require any edits to run. + + Given a forward function like this: + + ``` + func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> { + %0 = "ttnn.add"(%arg0, %arg1) : (tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> + return %0 : tensor<32x32xbf16> + } + ``` + + The pass will create two function like this: + + ``` + func.func @createInputsFor_add() -> (tensor<32x32xbf16>, tensor<32x32xbf16>) { + %0 = "ttnn.empty"() <{shape = #ttnn.shape<32x32>}> : () -> tensor<32x32xbf16> + %1 = "ttnn.empty"() <{shape = #ttnn.shape<32x32>}> : () -> tensor<32x32xbf16> + return %0, %1 : tensor<32x32xbf16>, tensor<32x32xbf16> + } + + func.func @main() -> i32 { + %0:2 = call @createInputsFor_add() : () -> (tensor<32x32xbf16>, tensor<32x32xbf16>) + %1 = call @add(%0#0, %0#1) : (tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> + %c0_i32 = arith.constant 0 : i32 + return %c0_i32 : i32 + } + ``` + }]; +} + #endif diff --git a/include/ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h b/include/ttmlir/Dialect/TTNN/Utils/MemoryLayoutAnalysisParams.h similarity index 71% rename from include/ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h rename to include/ttmlir/Dialect/TTNN/Utils/MemoryLayoutAnalysisParams.h index 16fafe551a..5275e2340d 100644 --- a/include/ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h +++ b/include/ttmlir/Dialect/TTNN/Utils/MemoryLayoutAnalysisParams.h @@ -2,8 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 -#ifndef TTMLIR_DIALECT_TT_UTILS_MEMORYLAYOUTANALYSISPARAMS_H -#define TTMLIR_DIALECT_TT_UTILS_MEMORYLAYOUTANALYSISPARAMS_H +#ifndef TTMLIR_DIALECT_TTNN_UTILS_MEMORYLAYOUTANALYSISPARAMS_H +#define TTMLIR_DIALECT_TTNN_UTILS_MEMORYLAYOUTANALYSISPARAMS_H #include #include @@ -27,21 +27,26 @@ struct MemoryLayoutAnalysisPolicyTypeParser return false; } - static void print(llvm::raw_ostream &os, - const MemoryLayoutAnalysisPolicyType &value) { - llvm::StringRef policy; + static std::string toString(const MemoryLayoutAnalysisPolicyType &value) { + std::string res; switch (value) { case MemoryLayoutAnalysisPolicyType::DFSharding: - policy = "DFSharding"; + res += "DFSharding"; break; case MemoryLayoutAnalysisPolicyType::L1Interleaved: - policy = "L1Interleaved"; + res += "L1Interleaved"; break; } - os << "memory-layout-analysis-policy=" << policy << "\n"; + return res; + } + + static void print(llvm::raw_ostream &os, + const MemoryLayoutAnalysisPolicyType &value) { + os << "memory-layout-analysis-policy=" + << MemoryLayoutAnalysisPolicyTypeParser::toString(value) << "\n"; } }; } // namespace mlir::tt -#endif // TTMLIR_DIALECT_TT_UTILS_MEMORYLAYOUTANALYSISPARAMS_H +#endif // TTMLIR_DIALECT_TTNN_UTILS_MEMORYLAYOUTANALYSISPARAMS_H diff --git a/include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h b/include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h index 2a71386f85..b13d375647 100644 --- a/include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h +++ b/include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h @@ -5,58 +5,109 @@ #ifndef TTMLIR_DIALECT_TTNN_UTILS_OPTIMIZEROVERRIDES_H #define TTMLIR_DIALECT_TTNN_UTILS_OPTIMIZEROVERRIDES_H -#include +#include +#include +#include -#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" -#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" +#include "ttmlir/Dialect/TTNN/Utils/MemoryLayoutAnalysisParams.h" +#include "ttmlir/Dialect/TTNN/Utils/PassOverrides.h" namespace mlir::tt::ttnn { -struct OutputLayoutOverrideParams { - std::optional> grid; - std::optional bufferType; - std::optional - tensorMemoryLayout; // INTERLEAVED / SHARDED etc... - std::optional memoryLayout; // ROW_MAJOR / TILE - std::optional dataType; - - // Check if all layout parameters that are generated in LegalLayoutAnalysis - // are overridden. DataType is the only that is not. - bool fullLayoutOverride() const { - return grid.has_value() && bufferType.has_value() && - tensorMemoryLayout.has_value() && memoryLayout.has_value(); - } -}; - -struct InputLayoutOverrideParams { - SmallVector operandIdxes; -}; - -struct OutputLayoutOverrideParser - : public llvm::cl::parser> { +class OptimizerOverridesHandler { public: - OutputLayoutOverrideParser(llvm::cl::Option &opt) - : llvm::cl::parser>(opt) {} + OptimizerOverridesHandler() {}; + ~OptimizerOverridesHandler() {}; - bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg, - llvm::StringMap &value); + // Setters for the overrides + // These are used to enable/disable the optimizer passes + void setEnableOptimizer(bool); + // These are used to enable/disable the memory configurations + void setMemoryReconfig(bool); + void setEnableMemoryLayoutAnalysis(bool); + void setEnableMemoryLayoutAnalysisPolicy(bool); + void setMemoryLayoutAnalysisPolicy(MemoryLayoutAnalysisPolicyType); + // These are used to set the input/output layout overrides + void setInputLayoutOverrides(llvm::StringMap &); + void setOutputLayoutOverrides(llvm::StringMap &); + // These are used to add system descriptor path + void setSystemDescPath(std::string); + // These are used to set the maximum number of legal layouts for grid analysis + void setMaxLegalLayouts(int64_t); + // These are used to set the mesh shape + void setMeshShape(std::vector); - static void print(llvm::raw_ostream &os, - const llvm::StringMap &value); -}; + // Getters for the overrides + // These are used to get the current state of the optimizer passes + bool getEnableOptimizer() const; + // These are used to get the current state of the memory configurations + bool getMemoryReconfig() const; + bool getEnableMemoryLayoutAnalysis() const; + bool getEnableMemoryLayoutAnalysisPolicy() const; + MemoryLayoutAnalysisPolicyType getMemoryLayoutAnalysisPolicy() const; + // These are used to get the current input/output layout overrides + llvm::StringMap getInputLayoutOverrides() const; + llvm::StringMap getOutputLayoutOverrides() const; + // These are used to get the current system descriptor path + std::string getSystemDescPath() const; + // These are used to get the current maximum number of legal layouts for grid + // analysis + int64_t getMaxLegalLayouts() const; + // These are used to get the current mesh shape + std::vector getMeshShape() const; -struct InputLayoutOverrideParser - : public llvm::cl::parser> { -public: - InputLayoutOverrideParser(llvm::cl::Option &opt) - : llvm::cl::parser>(opt) {} + // Method that converts the overrides to a string + std::string toString() const; + + // Fill input/output layout overrides maps. + // This is used from tt-forge frontend where we define and compile the models. + void addInputLayoutOverride(StringRef, InputLayoutOverrideParams); + void addInputLayoutOverride(StringRef, SmallVector &); + void addOutputLayoutOverride(StringRef, OutputLayoutOverrideParams); + void addOutputLayoutOverride(StringRef, SmallVector &, BufferType, + TensorMemoryLayout, tt::ttnn::Layout, + tt::DataType); + + // Wrapper methods we use to expose the adders to the python bindings + std::unordered_map + getInputLayoutOverridesPybindWrapper() const; + std::unordered_map + getOutputLayoutOverridesPybindWrapper() const; + + // Wrapper methods we use to expose the adders to the python bindings + void addInputLayoutOverridePybindWrapper(std::string, std::vector &); + void addOutputLayoutOverridePybindWrapper(std::string, std::vector &, + BufferType, TensorMemoryLayout, + tt::ttnn::Layout, tt::DataType); + +private: + // Flags for enabling/disabling the optimizer passes + bool enableOptimizer = false; + + // Flags for enabling/disabling the memory configurations + bool enableMemoryReconfig = true; + bool enableMemoryLayoutAnalysis = false; + + // Input layout overrides + llvm::StringMap inputLayoutOverrides; + + // Output layout overrides + llvm::StringMap outputLayoutOverrides; + + // Memory layout analysis policy + bool enableMemoryLayoutAnalysisPolicy = false; + MemoryLayoutAnalysisPolicyType memoryLayoutAnalysisPolicy; + + // System descriptor path + std::string systemDescPath; + + // Maximum number of legal layouts for grid analysis + int64_t maxLegalLayouts = 0; - bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg, - llvm::StringMap &value); + // Mesh shape + std::vector meshShape; - static void print(llvm::raw_ostream &os, - const llvm::StringMap &value); -}; +}; // class OptimizerOverridesHandler } // namespace mlir::tt::ttnn diff --git a/include/ttmlir/Dialect/TTNN/Utils/PassOverrides.h b/include/ttmlir/Dialect/TTNN/Utils/PassOverrides.h new file mode 100644 index 0000000000..cd2d3585f8 --- /dev/null +++ b/include/ttmlir/Dialect/TTNN/Utils/PassOverrides.h @@ -0,0 +1,167 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_DIALECT_TTNN_UTILS_PASSOVERRIDES_H +#define TTMLIR_DIALECT_TTNN_UTILS_PASSOVERRIDES_H + +#include + +#include + +#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h" + +namespace mlir::tt::ttnn { + +struct OptionNames { + + static constexpr StringRef optimizerPassEnabled = "enable-optimizer"; + static constexpr StringRef overrideInputLayout = "insert-memreconfig"; + static constexpr StringRef overrideOutputLayout = "override-output-layout"; + static constexpr StringRef memoryLayoutAnalysisEnabled = + "memory-layout-analysis-enabled"; + static constexpr StringRef memReconfigEnabled = "memreconfig-enabled"; + static constexpr StringRef memoryLayoutAnalysisPolicy = + "memory-layout-analysis-policy"; + static constexpr StringRef systemDescPath = "system-desc-path"; + static constexpr StringRef maxLegalLayouts = "max-legal-layouts"; + static constexpr StringRef meshShape = "mesh-shape"; +}; + +struct OutputLayoutOverrideParams { + std::optional> grid; + std::optional bufferType; + std::optional + tensorMemoryLayout; // INTERLEAVED / SHARDED etc... + std::optional memoryLayout; // ROW_MAJOR / TILE + std::optional dataType; + + // Check if all layout parameters that are generated in LegalLayoutAnalysis + // are overridden. DataType is the only that is not. + bool fullLayoutOverride() const { + return grid.has_value() && bufferType.has_value() && + tensorMemoryLayout.has_value() && memoryLayout.has_value(); + } + + bool operator==(const OutputLayoutOverrideParams rhs) const { + if (grid.has_value() != rhs.grid.has_value()) { + return false; + } + + if (grid.has_value() && rhs.grid.has_value()) { + if (grid.value().size() != rhs.grid.value().size()) { + return false; + } + for (std::size_t i = 0; i < grid.value().size(); i++) { + if (grid.value()[i] != rhs.grid.value()[i]) { + return false; + } + } + } + + if (bufferType.has_value() != rhs.bufferType.has_value()) { + return false; + } + + if (bufferType.has_value() && rhs.bufferType.has_value()) { + if (bufferType.value() != rhs.bufferType.value()) { + return false; + } + } + + if (tensorMemoryLayout.has_value() != rhs.tensorMemoryLayout.has_value()) { + return false; + } + + if (tensorMemoryLayout.has_value() && rhs.tensorMemoryLayout.has_value()) { + if (tensorMemoryLayout.value() != rhs.tensorMemoryLayout.value()) { + return false; + } + } + + if (memoryLayout.has_value() != rhs.memoryLayout.has_value()) { + return false; + } + + if (memoryLayout.has_value() && rhs.memoryLayout.has_value()) { + if (memoryLayout.value() != rhs.memoryLayout.value()) { + return false; + } + } + + if (dataType.has_value() != rhs.dataType.has_value()) { + return false; + } + + if (dataType.has_value() && rhs.dataType.has_value()) { + if (dataType.value() != rhs.dataType.value()) { + return false; + } + } + + return true; + } + + bool operator!=(const OutputLayoutOverrideParams &rhs) const { + return !(*this == rhs); + } +}; + +struct InputLayoutOverrideParams { + + SmallVector operandIdxes; + + bool operator==(const InputLayoutOverrideParams &rhs) const { + if (operandIdxes.size() != rhs.operandIdxes.size()) { + return false; + } + for (std::size_t i = 0; i < operandIdxes.size(); i++) { + if (operandIdxes[i] != rhs.operandIdxes[i]) { + return false; + } + } + return true; + } + + bool operator!=(const InputLayoutOverrideParams &rhs) const { + return !(*this == rhs); + } +}; + +struct OutputLayoutOverrideParser + : public llvm::cl::parser> { +public: + OutputLayoutOverrideParser(llvm::cl::Option &opt) + : llvm::cl::parser>(opt) {} + + bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg, + llvm::StringMap &value); + + static std::string + toString(const llvm::StringMap &); + + static void print(llvm::raw_ostream &os, + const llvm::StringMap &value); +}; + +struct InputLayoutOverrideParser + : public llvm::cl::parser> { +public: + InputLayoutOverrideParser(llvm::cl::Option &opt) + : llvm::cl::parser>(opt) {} + + bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg, + llvm::StringMap &value); + + static std::string + toString(const llvm::StringMap &); + + static void print(llvm::raw_ostream &os, + const llvm::StringMap &value); +}; + +} // namespace mlir::tt::ttnn + +#endif // TTMLIR_DIALECT_TTNN_UTILS_PASSOVERRIDES_H diff --git a/include/ttmlir/Dialect/TTNN/Utils/TransformUtils.h b/include/ttmlir/Dialect/TTNN/Utils/TransformUtils.h new file mode 100644 index 0000000000..2dc83388d1 --- /dev/null +++ b/include/ttmlir/Dialect/TTNN/Utils/TransformUtils.h @@ -0,0 +1,17 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_DIALECT_TTNN_UTILS_TRANSFORMUTILS_H +#define TTMLIR_DIALECT_TTNN_UTILS_TRANSFORMUTILS_H + +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" + +namespace mlir::tt::ttnn::utils { +// Get or insert device for the given operation. +mlir::Value getOrInsertDevice(mlir::PatternRewriter &rewriter, + mlir::Operation *op); +} // namespace mlir::tt::ttnn::utils + +#endif diff --git a/include/ttmlir/Dialect/TTNN/Utils/Utils.h b/include/ttmlir/Dialect/TTNN/Utils/Utils.h index a6e10c0991..2c4b7a2508 100644 --- a/include/ttmlir/Dialect/TTNN/Utils/Utils.h +++ b/include/ttmlir/Dialect/TTNN/Utils/Utils.h @@ -5,10 +5,14 @@ #ifndef TTMLIR_DIALECT_TTNN_UTILS_UTILS_H #define TTMLIR_DIALECT_TTNN_UTILS_UTILS_H +#include + #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h" +#include "mlir/IR/BuiltinTypes.h" + namespace mlir::tt::ttnn::utils { // Map tt::MemorySpace to ttnn::BufferType @@ -31,13 +35,18 @@ mlir::tt::TensorMemoryLayout toTTTensorMemoryLayout( mlir::tt::MemorySpace toTTMemorySpace(const mlir::tt::ttnn::BufferType bufferType); -DataType getDataTypeFromMemRef(mlir::MemRefType memref); - +// Get Layout from MemRefType +// Layout getLayoutFromMemRef(mlir::MemRefType memref); mlir::Type createRowMajorTypeFromDtype(::mlir::MLIRContext *context, DataType dtype); +// Helper method to create a RankedTensorType with the given encoding +RankedTensorType +createRankedTensorTypeWithEncoding(RankedTensorType tensorType, + ttnn::TTNNLayoutAttr encoding); + } // namespace mlir::tt::ttnn::utils #endif // TTMLIR_DIALECT_TTNN_UTILS_UTILS_H diff --git a/include/ttmlir/OpModel/TTNN/TTNNOpModel.h b/include/ttmlir/OpModel/TTNN/TTNNOpModel.h new file mode 100644 index 0000000000..31ac149849 --- /dev/null +++ b/include/ttmlir/OpModel/TTNN/TTNNOpModel.h @@ -0,0 +1,24 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_OPMODEL_TTNN_TTNNOPMODEL_H +#define TTMLIR_OPMODEL_TTNN_TTNNOPMODEL_H + +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" + +#include + +namespace mlir::tt::op_model::ttnn { + +struct ReluOpInterface { + static bool isLegal(const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout, + const mlir::tt::ttnn::TTNNLayoutAttr &outputLayout); + + static std::tuple + getOpL1Usage(const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout, + const mlir::tt::ttnn::TTNNLayoutAttr &outputLayout); +}; + +} // namespace mlir::tt::op_model::ttnn +#endif // TTMLIR_OPMODEL_TTNN_TTNNOPMODEL_H diff --git a/include/ttmlir/Scheduler/Scheduler.h b/include/ttmlir/Scheduler/Scheduler.h index 817271fdc9..5d41163311 100644 --- a/include/ttmlir/Scheduler/Scheduler.h +++ b/include/ttmlir/Scheduler/Scheduler.h @@ -23,6 +23,10 @@ class Scheduler { // Method to get the next set of schedulable operations llvm::SmallVector getScheduleableOps(); + // Method to check if an operation is either a TTIR op or a + // TTNN scheduleable op. + bool isTTShedulableOp(mlir::Operation *op); + // Method to check if an operation can be scheduled bool canSchedule(mlir::Operation *op); diff --git a/include/ttmlir/Target/Common/types.fbs b/include/ttmlir/Target/Common/types.fbs index 2d67ee1d1c..3e7ed425f7 100644 --- a/include/ttmlir/Target/Common/types.fbs +++ b/include/ttmlir/Target/Common/types.fbs @@ -11,67 +11,67 @@ struct Dim2dRange { } enum Arch: uint { - Grayskull = 0, - Wormhole_b0 = 1, - Blackhole = 2, + Grayskull, + Wormhole_b0, + Blackhole } enum DataType: uint16 { - Float32 = 0, - Float16 = 1, - BFloat16 = 2, - BFP_Float8 = 3, - BFP_BFloat8 = 4, - BFP_Float4 = 5, - BFP_BFloat4 = 6, - BFP_Float2 = 7, - BFP_BFloat2 = 8, - UInt32 = 9, - UInt16 = 10, - UInt8 = 11, + Float32, + Float16, + BFloat16, + BFP_Float8, + BFP_BFloat8, + BFP_Float4, + BFP_BFloat4, + BFP_Float2, + BFP_BFloat2, + UInt32, + UInt16, + UInt8, } enum OOBVal: ushort { - Undef = 0, - Zero = 1, - One = 2, - Inf = 3, - NegInf = 4, + Undef, + Zero, + One, + Inf, + NegInf, } enum MemorySpace: ushort { - System = 0, - SystemMMIO = 1, - DeviceDRAM = 2, - DeviceL1 = 3, + System, + SystemMMIO, + DeviceDRAM, + DeviceL1, } enum ChipCapability: uint32 (bit_flags) { - PCIE = 0, - HostMMIO = 1, + PCIE, + HostMMIO, } enum TensorMemoryLayout: ushort { - None = 0, - Interleaved = 1, - SingleBank = 2, - HeightSharded = 3, - WidthSharded = 4, - BlockSharded = 5, + None, + Interleaved, + SingleBank, + HeightSharded, + WidthSharded, + BlockSharded, } enum TensorLayout: ushort { - RowMajor = 0, - Tile = 1, - Invalid = 2, + RowMajor, + Tile, + Invalid, } enum BufferType: ushort { - DRAM = 0, - L1 = 1, - SystemMemory = 2, - L1Small = 3, - Trace = 4, + DRAM, + L1, + SystemMemory, + L1Small, + Trace, } // TODO (#620): Add other fields like core_ranges, shard orientation etc. @@ -197,8 +197,8 @@ table ChipPhysicalCores { enum CPURole: uint8 { - Host = 0, - Device = 1, + Host, + Device, } table CPUDesc { @@ -223,9 +223,11 @@ table EventRef { global_id: uint32; } +// Explicit non-sequential enumeration copied over from tt-metal definition of +// `enum class MathFidelity`. enum MathFidelity : uint8 { - LoFi = 0, - HiFi2 = 2, - HiFi3 = 3, - HiFi4 = 4, + LoFi = 0, + HiFi2 = 2, + HiFi3 = 3, + HiFi4 = 4, } diff --git a/include/ttmlir/Target/TTMetal/program.fbs b/include/ttmlir/Target/TTMetal/program.fbs index 4fcf966020..52451234b1 100644 --- a/include/ttmlir/Target/TTMetal/program.fbs +++ b/include/ttmlir/Target/TTMetal/program.fbs @@ -3,18 +3,18 @@ include "Common/types.fbs"; namespace tt.target.metal; enum NocIndex : ushort { - Noc0 = 0, - Noc1 = 1, + Noc0, + Noc1, } enum EthType : ushort { - Sender = 0, - Receiver = 1, + Sender, + Receiver, } enum UnpackToDestMode : uint8 { - UnpackToDestFp32 = 0, - Default = 1, + UnpackToDestFp32, + Default, } table NocConfig { @@ -45,17 +45,17 @@ table KernelSource { } enum BinaryType : ushort { - BRISC = 0, - NCRISC = 1, - TRISC0 = 2, - TRISC1 = 3, - TRISC2 = 4, - ERISC = 5, + BRISC, + NCRISC, + TRISC0, + TRISC1, + TRISC2, + ERISC, } enum CoreType : ushort { - WORKER = 0, - ETH = 1, + WORKER, + ETH, } table KernelBinary { diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs index ec493e6496..5644c970d8 100644 --- a/include/ttmlir/Target/TTNN/program.fbs +++ b/include/ttmlir/Target/TTNN/program.fbs @@ -37,6 +37,19 @@ table ToDeviceOp { out: tt.target.TensorRef; } +table UpdateCacheOp { + cache: tt.target.TensorRef; + input: tt.target.TensorRef; + update_index: tt.target.TensorRef; + batch_offset: uint32; +} + +table FillCacheOp { + cache: tt.target.TensorRef; + input: tt.target.TensorRef; + batch_offset: uint32; +} + table FromDeviceOp { in: tt.target.TensorRef; out: tt.target.TensorRef; @@ -61,47 +74,60 @@ table FullOp { out: tt.target.TensorRef; } +table ArangeOp { + start: float; + end: float; + step: float; + dtype: tt.target.DataType = null; // optional + device: tt.target.DeviceRef; // optional + memcfg: tt.target.MemoryConfigDesc; // optional + out: tt.target.TensorRef; +} + enum EltwiseOpType: uint32 { - Add = 0, - Multiply = 1, - Subtract = 2, - Relu = 3, - GreaterEqual = 4, - Sqrt = 5, - Div = 6, - Sigmoid = 7, - Reciprocal = 8, - Exp = 9, - Maximum = 10, - Abs = 11, - Neg = 12, - Rsqrt = 13, - Typecast = 14, - Equal = 15, - NotEqual = 16, - LessEqual = 17, - LessThan = 18, - GreaterThan = 19, - LogicalAnd = 20, - LogicalOr = 21, - LogicalNot = 22, - Cbrt = 23, - Minimum = 24, - Ceil = 25, - Sin = 26, - Cos = 27, - Log = 28, - Log1p = 29, - Expm1 = 30, - Sign = 31, - Remainder = 32, - IsFinite = 33, - Floor = 34, - Where = 35, - Gelu = 36, - LogicalXor = 37, - Clamp = 38, - LeakyRelu = 39, + Add, + Multiply, + Subtract, + Relu, + GreaterEqual, + Sqrt, + Div, + Sigmoid, + Reciprocal, + Exp, + Maximum, + Abs, + Neg, + Rsqrt, + Typecast, + Equal, + NotEqual, + LessEqual, + LessThan, + GreaterThan, + LogicalAnd, + LogicalOr, + LogicalNot, + Cbrt, + Minimum, + Ceil, + Sin, + Cos, + Log, + Log1p, + Expm1, + Sign, + Remainder, + IsFinite, + Floor, + Where, + Gelu, + LogicalXor, + Clamp, + LeakyRelu, + Scatter, + Tan, + Tanh } table ClampOpParams { @@ -126,9 +152,9 @@ table EltwiseOp { } enum ReductionOpType: uint32 { - Sum = 0, - Mean = 1, - Max = 2, + Sum, + Mean, + Max, } table ReductionOp { @@ -178,6 +204,13 @@ table SliceOp { step: [int64]; } +table LinearOp { + in0: tt.target.TensorRef; + in1: tt.target.TensorRef; + bias: tt.target.TensorRef; + out: tt.target.TensorRef; +} + // ANCHOR: adding_an_op_matmul_fbs table MatmulOp { in0: tt.target.TensorRef; @@ -249,6 +282,7 @@ union OpType { EmptyOp, FullOp, EltwiseOp, + LinearOp, MatmulOp, ReductionOp, EmbeddingOp, @@ -261,11 +295,15 @@ union OpType { MaxPool2dOp, DeallocateOp, AllGatherOp, + ArangeOp, + UpdateCacheOp, + FillCacheOp, } table Operation { type: OpType; debug_info: string; + loc_info: string; } table Program { diff --git a/include/ttmlir/Target/TTNN/utils.h b/include/ttmlir/Target/TTNN/utils.h index e3f642a2d9..201cc1ee3b 100644 --- a/include/ttmlir/Target/TTNN/utils.h +++ b/include/ttmlir/Target/TTNN/utils.h @@ -26,8 +26,6 @@ ::tt::target::TensorMemoryLayout toTargetTensorMemoryLayout( return ::tt::target::TensorMemoryLayout::WidthSharded; case ::mlir::tt::ttnn::TensorMemoryLayout::BlockSharded: return ::tt::target::TensorMemoryLayout::BlockSharded; - case ::mlir::tt::ttnn::TensorMemoryLayout::None: - return ::tt::target::TensorMemoryLayout::None; } llvm_unreachable("Unsupported TensorMemoryLayout"); diff --git a/include/ttmlir/Target/Utils/FuncOpToProgram.h b/include/ttmlir/Target/Utils/FuncOpToProgram.h index d9e8d98207..a28f2f5e9a 100644 --- a/include/ttmlir/Target/Utils/FuncOpToProgram.h +++ b/include/ttmlir/Target/Utils/FuncOpToProgram.h @@ -31,6 +31,13 @@ inline std::string getOpDebugString(mlir::Operation *op, return str; }; +inline std::string getOpLocInfo(mlir::Operation *op) { + std::string str; + llvm::raw_string_ostream os(str); + op->getLoc().print(os); + return str; +} + inline Value getOperandThroughDPSOps(Value value) { auto *op = value.getDefiningOp(); if (!op) { @@ -76,7 +83,8 @@ Program funcOpToProgram(FlatbufferObjectCache &cache, func::FuncOp entry, } } else { std::string debugStr = getOpDebugString(op, printFlags); - program.ops.push_back(fn(cache, op, debugStr)); + std::string locInfo = getOpLocInfo(op); + program.ops.push_back(fn(cache, op, debugStr, locInfo)); } }); diff --git a/include/ttmlir/Target/Utils/MLIRToFlatbuffer.h b/include/ttmlir/Target/Utils/MLIRToFlatbuffer.h index ac23b9bb0d..cb9439d978 100644 --- a/include/ttmlir/Target/Utils/MLIRToFlatbuffer.h +++ b/include/ttmlir/Target/Utils/MLIRToFlatbuffer.h @@ -18,8 +18,9 @@ namespace mlir::tt { flatbuffers::Offset<::tt::target::LayoutDesc> -layoutAttrToFlatbuffer(FlatbufferObjectCache &cache, LayoutAttr attr, - ArrayRef logicalShape, DeviceAttr deviceAttr); +metalLayoutAttrToFlatbuffer(FlatbufferObjectCache &cache, MetalLayoutAttr attr, + ArrayRef logicalShape, + DeviceAttr deviceAttr); flatbuffers::Offset<::tt::target::LayoutDesc> ttnnLayoutAttrToFlatbuffer( FlatbufferObjectCache &cache, ttnn::TTNNLayoutAttr attr, @@ -438,9 +439,9 @@ toFlatbuffer(FlatbufferObjectCache &cache, ElementsAttr elementsAttr) { inline flatbuffers::Offset<::tt::target::LayoutDesc> encodingToFlatbuffer(FlatbufferObjectCache &cache, Attribute attr, ArrayRef logicalShape, DeviceAttr deviceAttr) { - if (isa(attr)) { - return layoutAttrToFlatbuffer(cache, cast(attr), logicalShape, - deviceAttr); + if (isa(attr)) { + return metalLayoutAttrToFlatbuffer(cache, cast(attr), + logicalShape, deviceAttr); } assert(isa(attr) && "unsupported layout attr"); @@ -478,7 +479,11 @@ toDebugInfo(::flatbuffers::FlatBufferBuilder &fbb, std::string const &name, ModuleOp module) { std::string source; llvm::raw_string_ostream os(source); - module->print(os); + + mlir::OpPrintingFlags flags; + flags.enableDebugInfo(); // Enable the loc dumping + module->print(os, flags); + return ::tt::target::CreateMLIRDirect(fbb, name.c_str(), source.c_str()); } } // namespace mlir::tt diff --git a/include/ttmlir/Utils.h b/include/ttmlir/Utils.h index bcf836741a..49dad79e5e 100644 --- a/include/ttmlir/Utils.h +++ b/include/ttmlir/Utils.h @@ -127,6 +127,11 @@ inline MlirAttribute wrapArrayOfMlirAttributesAsAttribute( return wrap(mlir::ArrayAttr::get(unwrap(ctx), unwrappedAttributesArray)); } +// Checks if the type of the given `mlir::Value` is a ranked tensor type. +inline bool isRankedTensor(mlir::Value v) { + return mlir::isa(v.getType()); +} + } // namespace ttmlir::utils #endif diff --git a/lib/CAPI/TTAttrs.cpp b/lib/CAPI/TTAttrs.cpp index 40a3ada6fb..c329f41d56 100644 --- a/lib/CAPI/TTAttrs.cpp +++ b/lib/CAPI/TTAttrs.cpp @@ -119,15 +119,15 @@ MlirAttribute ttmlirTTSystemDescAttrGet( chipCapabilitiesUnwrapped, chipCoordsUnwrapped, chipChannelsUnwrapped)); } -MlirAttribute ttmlirTTLayoutAttrGet(MlirContext ctx, MlirAffineMap linear, - unsigned oobVal, MlirAttribute grid, - MlirType memref, unsigned memLayout) { +MlirAttribute ttmlirTTMetalLayoutAttrGet(MlirContext ctx, MlirAffineMap linear, + unsigned oobVal, MlirAttribute grid, + MlirType memref, unsigned memLayout) { mlir::AffineMap affineMap = mlir::AffineMap::getFromOpaquePointer(linear.ptr); - return wrap(LayoutAttr::get(unwrap(ctx), affineMap, - static_cast(oobVal), - mlir::cast(unwrap(grid)), - mlir::cast(unwrap(memref)), - static_cast(memLayout))); + return wrap(MetalLayoutAttr::get(unwrap(ctx), affineMap, + static_cast(oobVal), + mlir::cast(unwrap(grid)), + mlir::cast(unwrap(memref)), + static_cast(memLayout))); } MlirAttribute ttmlirTTMemorySpaceAttrGet(MlirContext ctx, @@ -219,4 +219,8 @@ MlirAttribute ttmlirTTChipPhysicalCoresAttrGet( ethVec, ethInactiveVec)); } +MlirAttribute ttmlirTTCoreCoordAttrGet(MlirContext ctx, int64_t y, int64_t x) { + return wrap(CoreCoordAttr::get(unwrap(ctx), y, x)); +} + } // namespace mlir::tt diff --git a/lib/CAPI/TTNNAttrs.cpp b/lib/CAPI/TTNNAttrs.cpp index 0fb1066cb8..467d8c0044 100644 --- a/lib/CAPI/TTNNAttrs.cpp +++ b/lib/CAPI/TTNNAttrs.cpp @@ -53,10 +53,9 @@ MlirAttribute ttmlirTTNNMemoryConfigAttrGet( MlirContext ctx, MlirAttribute tensorMemoryLayoutAttr, MlirAttribute bufferTypeAttr, MlirAttribute shardSpecAttr) { return wrap(MemoryConfigAttr::get( - unwrap(ctx), - mlir::cast(unwrap(tensorMemoryLayoutAttr)), - mlir::cast(unwrap(bufferTypeAttr)), - mlir::cast(unwrap(shardSpecAttr)))); + unwrap(ctx), mlir::cast(unwrap(bufferTypeAttr)), + mlir::cast(unwrap(shardSpecAttr)), + mlir::cast(unwrap(tensorMemoryLayoutAttr)))); } MlirAttribute ttmlirTTNNShapeAttrGet(MlirContext ctx, int64_t *shape, @@ -69,4 +68,25 @@ MlirAttribute ttmlirTTNNMeshShapeAttrGet(MlirContext ctx, int64_t y, return wrap(MeshShapeAttr::get(unwrap(ctx), y, x)); } +// Get layout TTNNLayout attribute +// +// param ctx: mlir context +// param linear Affine map for mapping tensor from logical to physical space +// param grid Grid of cores where tensor is mapped to +// param memref Memref which holds shard size, shard scalar type and memory +// param memLayout Memory layout of the tensor +MlirAttribute ttmlirTTNNTTNNLayoutAttrGet(MlirContext ctx, MlirAffineMap linear, + MlirAttribute grid, MlirType memref, + unsigned *memLayout = nullptr) { + mlir::AffineMap affineMap = mlir::AffineMap::getFromOpaquePointer(linear.ptr); + TensorMemoryLayoutAttr memLayoutAttr; + if (memLayout) { + memLayoutAttr = TensorMemoryLayoutAttr::get( + unwrap(ctx), static_cast(*memLayout)); + } + return wrap(TTNNLayoutAttr::get( + unwrap(ctx), affineMap, mlir::cast(unwrap(grid)), + mlir::cast(unwrap(memref)), memLayoutAttr)); +} + } // namespace mlir::tt::ttnn diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index c3dc3a4b71..881d6545dc 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -1,6 +1,7 @@ include_directories(${TTMLIR_TOOLCHAIN_DIR}/src/stablehlo) include_directories(${TTMLIR_TOOLCHAIN_DIR}/src/stablehlo-build) +add_subdirectory(OpModel) add_subdirectory(CAPI) add_subdirectory(Conversion) add_subdirectory(Dialect) diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index 1ec8556cff..96ef7ca017 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Traits.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/Region.h" #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" @@ -120,7 +121,9 @@ class StableHLOToTTIRReduceOpConversionPattern srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); mlir::ArrayAttr dimArg = rewriter.getArrayAttr(SmallVector( - 1, rewriter.getI32IntegerAttr(adaptor.getDimensionsAttr()[0]))); + 1, rewriter.getI32IntegerAttr(adaptor.getDimensionsAttr().size() > 0 + ? adaptor.getDimensionsAttr()[0] + : 1))); // If someone changes definition of TTIR_ReductionOp this constant will // become outdated, but I currently see no way to get this info (without @@ -279,30 +282,81 @@ class StableHLOToTTIRDotGeneralOpConversionPattern ::mlir::stablehlo::DotDimensionNumbersAttr dimensions = adaptor.getDotDimensionNumbers(); - if (dimensions.getLhsContractingDimensions().empty() || - dimensions.getRhsContractingDimensions().empty()) { - return rewriter.notifyMatchFailure(srcOp, - "Contracting dimension is missing."); + if (dimensions.getLhsContractingDimensions().size() != 1 || + dimensions.getRhsContractingDimensions().size() != 1) { + return rewriter.notifyMatchFailure( + srcOp, + "LHS and RHS must have exactly 1 contracting dimension each. " + "Received LHS contracting dims: " + + std::to_string(dimensions.getLhsContractingDimensions().size()) + + ", RHS contracting dims: " + + std::to_string(dimensions.getRhsContractingDimensions().size())); + } + + // Use negative indexing to determine if this is a valid matmul since math + // is done over the final two dimensions. + int64_t lhsContractingDim = dimensions.getLhsContractingDimensions()[0] - + srcOp.getLhs().getType().getRank(); + int64_t rhsContractingDim = dimensions.getRhsContractingDimensions()[0] - + srcOp.getRhs().getType().getRank(); + + if (lhsContractingDim != -1) { + return rewriter.notifyMatchFailure( + srcOp, "Only support contracting dimensions that correspond to valid " + "matmuls. LHS contracting dimension must be " + + std::to_string(srcOp.getLhs().getType().getRank() - 1) + + ". Got " + std::to_string(lhsContractingDim)); } - if (dimensions.getLhsContractingDimensions()[0] != 1) { + if (rhsContractingDim != -2) { return rewriter.notifyMatchFailure( - srcOp, "Only non-transposed matmul is currently supported in TTIR."); + srcOp, "Only support contracting dimensions that correspond to valid " + "matmuls. RHS contracting dimension must be " + + std::to_string(srcOp.getRhs().getType().getRank() - 2) + + ". Got " + std::to_string(rhsContractingDim)); } - if (dimensions.getRhsContractingDimensions()[0] != 0) { + if (dimensions.getLhsBatchingDimensions() != + dimensions.getRhsBatchingDimensions()) { return rewriter.notifyMatchFailure( - srcOp, "Only non-transposed matmul is currently supported in TTIR."); + srcOp, "LHS and RHS must have same batching dimensions."); } - if (!dimensions.getLhsBatchingDimensions().empty()) { + // For the RHS, all dimensions which are not the row and column dimensions + // must be 1 OR they must be equal to the corresponding dimension in the + // LHS. If the RHS has less dimensions than the LHS we will assume that the + // missing dimensions are 1. + + auto lhsShape = srcOp.getLhs().getType().getShape().vec(); + auto rhsShape = srcOp.getRhs().getType().getShape().vec(); + + if (rhsShape.size() > lhsShape.size()) { return rewriter.notifyMatchFailure( - srcOp, "Only non-transposed matmul is currently supported in TTIR."); + srcOp, "RHS must not be a higher rank than LHS."); + } + + while (rhsShape.size() < lhsShape.size()) { + rhsShape.insert(rhsShape.begin(), 1); + } + + // Need only to check dims to the left of dim -2 on the RHS + bool allOnes = true; + bool mismatchedDims = false; + for (int32_t i = rhsShape.size() - 3; i >= 0; i--) { + if (rhsShape[i] != 1) { + allOnes = false; + } + + if (rhsShape[i] != lhsShape[i]) { + mismatchedDims = true; + } } - if (!dimensions.getRhsBatchingDimensions().empty()) { + if (mismatchedDims && !allOnes) { return rewriter.notifyMatchFailure( - srcOp, "Only non-transposed matmul is currently supported in TTIR."); + srcOp, "All dimensions in the RHS that are not the row and column " + "dimensions must be 1 OR they must all be equal to the " + "corresponding dimensions in the LHS."); } return success(); @@ -799,7 +853,7 @@ class StableHLOToTTIRBroadcastInDimOpConversionPattern llvm::SmallVector broadcastedShape; auto srcType = - getTypeConverter()->convertType(srcOp.getOperand().getType()); + getTypeConverter()->convertType(adaptor.getOperand().getType()); auto inputShape = mlir::cast(srcType).getShape(); auto outputShape = mlir::cast(srcType).getShape(); @@ -945,8 +999,8 @@ class StableHLOToTTIRConcatOpConversionPattern "ConcatOp dimension is too large."); } - auto rankedTensorType = - mlir::dyn_cast(srcOp.getOperand(0).getType()); + auto rankedTensorType = mlir::dyn_cast( + adaptor.getOperands()[0].getType()); if (static_cast(adaptor.getDimension()) >= rankedTensorType.getRank()) { return rewriter.notifyMatchFailure(srcOp, @@ -1006,6 +1060,440 @@ class StableHLOToTTIROpLogicalOpConversionPattern } }; +template +LogicalResult getReduceType(SrcOpTy &srcOp, ReduceType &reduceType) { + if constexpr (!std::is_same::value) { + return failure(); + } + // Check operations in the first block and determine reduce type for now + // TODO(wooseoklee): This pattern matching mechanism may need to be updated as + // we see complicated patterns of reduce block in the future. + auto &block = srcOp.getRegion().front(); + for (Operation &op : block) { + if (isa(op)) { + reduceType = ReduceType::Sum; + return success(); + } + if (isa(op)) { + reduceType = ReduceType::Max; + return success(); + } + if (isa(op)) { + reduceType = ReduceType::Min; + return success(); + } + } + // Other reduce types are currently not supported + return failure(); +} + +// StalbeHLO spec.md defines following channel type for ccl ops +enum StableHLOChannelType { + // CHANNEL_TYPE_INVALID = 0 : Invalid primitive type to serve as + // default. + kChannelTypeInvalid = 0, + // DEVICE_TO_DEVICE = 1 : A channel for sending data between + // devices. + kChannelTypeDeviceToDevice = 1, + // DEVICE_TO_HOST = 2 : A channel for sending data from the + // device to the host. Can only be used with a Send operation. + kChannelTypeDeviceToHost = 2, + // HOST_TO_DEVICE = 3 : A channel for sending data from the host to + // the device. Can only be used with a Recv operation. + kChannelTypeHostToDevice = 3, +}; + +class StableHLOToTTIRAllReduceOpConversionPattern + : public OpConversionPattern { + + using OpConversionPattern::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(mlir::stablehlo::AllReduceOp srcOp, + mlir::stablehlo::AllReduceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + // Check legality of the operation + LogicalResult err = checkBasicLegality(srcOp, adaptor, rewriter); + if (failed(err)) { + return err; + } + + // Create the output tensor type based on inputs + auto outputType = mlir::cast( + getTypeConverter()->convertType(srcOp.getResult(0).getType())); + + // Create an empty output tensor with the computed shape + tensor::EmptyOp outputTensor = rewriter.create( + srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); + + SmallVector ttirTypes; + if (failed(this->getTypeConverter()->convertTypes(srcOp->getResultTypes(), + ttirTypes))) { + return failure(); + } + + auto ttirOperands = srcOp.getOperandsMutable(); + ttirOperands.append(ValueRange(outputTensor)); + + SmallVector srcAttrs = to_vector(srcOp->getAttrs()); + SmallVector ttirAttrs; + for (auto srcAttr : srcAttrs) { + StringAttr srcName = srcAttr.getName(); + if (srcName == "channel_handle") { + auto srcChannelHandleAttr = + dyn_cast(srcAttr.getValue()); + if (!srcChannelHandleAttr) { + return failure(); + } + + // channelType is supposed to be DEVICE_TO_DEVICE for CCL ops. + // Currently, we ensure if it is DEVICE_TO_DEVICE commmuincaiton. + // Consider preserving this information in the future if the attribute + // is non-DEVICE_TO_DEVICE values. + auto channelType = static_cast(srcChannelHandleAttr.getType()); + if (channelType != kChannelTypeDeviceToDevice) { + return failure(); + } + + IntegerAttr channelHandleAttr = rewriter.getSI32IntegerAttr( + static_cast(srcChannelHandleAttr.getHandle())); + if (!channelHandleAttr) { + return failure(); + } + ttirAttrs.push_back({srcName, channelHandleAttr}); + } else { + ttirAttrs.push_back(srcAttr); + } + } + + // Algorithm here is to search for the first non-one working dimension + auto replicaGroupsShape = adaptor.getReplicaGroups().getType().getShape(); + size_t dim = 0; + for (auto s : replicaGroupsShape) { + if (s != 1) { + break; + } + ++dim; + } + if (dim > replicaGroupsShape.size()) { + // all one shape, then select the fastest dim + dim = replicaGroupsShape.size(); + } + StringAttr dimName = StringAttr::get(this->getContext(), "dim"); + IntegerAttr dimAttr = + rewriter.getSI32IntegerAttr(static_cast(dim)); + ttirAttrs.push_back({dimName, dimAttr}); + + // Parse computation in region and add it to ttirAttrs + ReduceType reduceType; + if (failed(getReduceType(srcOp, reduceType))) { + return rewriter.notifyMatchFailure( + srcOp, "AllReduceOp cannot specify reduce type."); + } + StringAttr reduceTypeAttrName = + StringAttr::get(this->getContext(), "reduce_type"); + Attribute reduceTypeAttr = rewriter.getAttr(reduceType); + ttirAttrs.push_back({reduceTypeAttrName, reduceTypeAttr}); + + StringAttr operationConstraintAttrName = + StringAttr::get(this->getContext(), "operand_constraints"); + Attribute operationConstraintAttr = rewriter.getArrayAttr( + SmallVector(adaptor.getOperands().size() + 1, + rewriter.getAttr( + OperandConstraint::AnyDeviceTile))); + ttirAttrs.push_back({operationConstraintAttrName, operationConstraintAttr}); + + auto ttirAllReduceOp = rewriter.create( + srcOp.getLoc(), ttirTypes, ValueRange(ttirOperands.getAsOperandRange()), + ttirAttrs); + + rewriter.replaceOp(srcOp, ttirAllReduceOp); + + return success(); + } + +private: + LogicalResult + checkBasicLegality(mlir::stablehlo::AllReduceOp &srcOp, + mlir::stablehlo::AllReduceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + if (srcOp.getOperands().empty() || srcOp.getOperands().size() > 1) { + return rewriter.notifyMatchFailure( + srcOp, "AllReduceOp must have one input/output for now."); + } + + return success(); + } +}; // namespace + +class StableHLOToTTIRCustomCallOpConversionPattern + : public OpConversionPattern { + + using OpConversionPattern::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(mlir::stablehlo::CustomCallOp srcOp, + mlir::stablehlo::CustomCallOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + // Check legality of the operation + LogicalResult err = checkBasicLegality(srcOp, adaptor, rewriter); + if (failed(err)) { + return err; + } + + const std::string kShardingTarget = "Sharding"; + const std::string kSPMDFullToShardShapeTarget = "SPMDFullToShardShape"; + const std::string kSPMDShardToFullShapeTarget = "SPMDShardToFullShape"; + + auto callTargetName = adaptor.getCallTargetNameAttr(); + + // Currently stablehlo.custom_call with following functions from + // jax/openxla are supported + if (callTargetName != kShardingTarget && + callTargetName != kSPMDFullToShardShapeTarget && + callTargetName != kSPMDShardToFullShapeTarget) { + return failure(); + } + + auto shardingAttr = dyn_cast_or_null( + adaptor.getAttributes().get("mhlo.sharding")); + if (!shardingAttr) { + return failure(); + } + StringRef shardingStr = shardingAttr.getValue(); + if (!shardingStr.consume_front("{") || !shardingStr.consume_back("}")) { + return failure(); + } + SmallVector shardingStrAttrs; + shardingStr.split(shardingStrAttrs, " "); + struct ShardAttrValue shardAttrValue; + if (failed(parseShardingAttr(rewriter, shardingStrAttrs, shardAttrValue))) { + return failure(); + } + + if (callTargetName == kSPMDFullToShardShapeTarget) { + Operation *shardingOp = srcOp->getOperand(0).getDefiningOp(); + if (!shardingOp) { + return rewriter.notifyMatchFailure( + srcOp, "requires operand to be defined by an op"); + } + + // TODO(wooseoklee): a bit rough approach here to match output dim + shardingOp->getResult(0).setType(srcOp->getResult(0).getType()); + srcOp.getResult(0).replaceAllUsesWith(shardingOp->getResult(0)); + rewriter.eraseOp(srcOp); + } else if (callTargetName == kSPMDShardToFullShapeTarget) { + Operation *shardingOp = srcOp->getOperand(0).getDefiningOp(); + if (!shardingOp) { + return rewriter.notifyMatchFailure( + srcOp, "requires operand to be defined by an op"); + } + + // Create the output tensor type based on inputs + auto outputType = mlir::cast( + getTypeConverter()->convertType(srcOp->getResult(0).getType())); + + // Create an empty output tensor with the computed shape + tensor::EmptyOp outputTensor = rewriter.create( + srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); + + SmallVector outputTypes; + if (failed(this->getTypeConverter()->convertTypes(srcOp->getResultTypes(), + outputTypes))) { + return failure(); + } + + shardAttrValue.shardDirection = mlir::tt::MeshShardDirection::ShardToFull; + if (failed(createMeshShardOp(srcOp, adaptor, outputTensor, outputTypes, + shardAttrValue, rewriter))) { + return failure(); + } + } else if (callTargetName == kShardingTarget) { + if (shardAttrValue.shardType == mlir::tt::MeshShardType::Manual) { + // "manual" sharding indicates match between input/output tensor shape + // and no sharding is required. + srcOp.getResult(0).replaceAllUsesWith(srcOp->getOperand(0)); + rewriter.eraseOp(srcOp); + } else { + auto *user = *srcOp.getResult(0).user_begin(); + auto userOp = dyn_cast_or_null(user); + if (!userOp) { + return failure(); + } + + // Create the output tensor type based on inputs + auto outputType = mlir::cast( + getTypeConverter()->convertType(userOp->getResult(0).getType())); + + // Create an empty output tensor with the computed shape + tensor::EmptyOp outputTensor = rewriter.create( + srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); + + SmallVector outputTypes; + if (failed(this->getTypeConverter()->convertTypes( + userOp->getResultTypes(), outputTypes))) { + return failure(); + } + + shardAttrValue.shardDirection = + mlir::tt::MeshShardDirection::FullToShard; + if (failed(createMeshShardOp(srcOp, adaptor, outputTensor, outputTypes, + shardAttrValue, rewriter))) { + return failure(); + } + } + } + return success(); + } + +private: + struct ShardAttrValue { + mlir::tt::MeshShardDirection shardDirection; + mlir::tt::MeshShardType shardType; + bool lastTileDimReplicate; + std::vector shardShape; + }; + + // OpenXLA has its own lexer, but we will use simple string-based parser here + // This parsing is mainly based on "Sharding Attribute" section in + // https://github.com/sdasgup3/stablehlo/blob/80082431d1af0933e6202ecc8a6f8801e039235b/docs/spec.md + LogicalResult parseShardingAttr(ConversionPatternRewriter &rewriter, + SmallVector shardingStrAttrs, + struct ShardAttrValue &shardAttrValue) const { + MeshShardType shardType = mlir::tt::MeshShardType::Manual; + bool lastTileDimReplicate = false; + for (auto str : shardingStrAttrs) { + if (str.contains("replicated")) { + assert(shardType == mlir::tt::MeshShardType::Manual && + "Fail to parse sharding info."); + // replicated: all devices have whole data + shardType = mlir::tt::MeshShardType::Replicate; + shardAttrValue.shardShape.push_back(1); + } else if (str.contains("maximal")) { + assert(shardType == mlir::tt::MeshShardType::Manual && + "Fail to parse sharding info."); + // maximal: one device has whole data + shardType = mlir::tt::MeshShardType::Maximal; + shardAttrValue.shardShape.push_back(1); + } else if (str.contains("device=")) { + // maximal should followed by "device" to put data on + assert(shardType == mlir::tt::MeshShardType::Maximal && + "Fail to parse sharding info."); + int64_t d; + if (!str.consume_front("device=")) { + return failure(); + } + if (str.getAsInteger(10, d)) { + return failure(); + } + shardAttrValue.shardShape.push_back(d); + } else if (str.contains("manual")) { + assert(shardType == mlir::tt::MeshShardType::Manual && + "Fail to parse sharding info."); + // manual: already sharded, so no action is needed + assert(!lastTileDimReplicate && + "last time dim duplicate option shouldn't be set here."); + shardAttrValue.shardShape.push_back(1); + } else if (str.contains("devices=")) { + // other: "devices" detail sharding plan + assert(shardType == mlir::tt::MeshShardType::Manual && + "Fail to parse sharding info."); + shardType = mlir::tt::MeshShardType::Devices; + if (!str.consume_front("devices=")) { + return failure(); + } + auto [devicesStr, restStr] = str.split("<="); + // parse devices ex. [4,2,1] + if (!devicesStr.consume_front("[") || !devicesStr.consume_back("]")) { + return failure(); + } + SmallVector dimsStr; + devicesStr.split(dimsStr, ","); + for (auto dim : dimsStr) { + int64_t d; + if (dim.getAsInteger(10, d)) { + return failure(); + } + shardAttrValue.shardShape.push_back(d); + } + } else if (str.contains("last_tile_dim_replicate")) { + assert(shardType == mlir::tt::MeshShardType::Devices && + "Fail to parse sharding info."); + // other: replicate last tile dim + lastTileDimReplicate = true; + } + } + shardAttrValue.shardType = shardType; + shardAttrValue.lastTileDimReplicate = lastTileDimReplicate; + return success(); + } + + LogicalResult + createMeshShardOp(mlir::stablehlo::CustomCallOp &srcOp, + mlir::stablehlo::CustomCallOp::Adaptor adaptor, + tensor::EmptyOp &outputTensor, + SmallVector &outputTypes, + ShardAttrValue &shardAttrValue, + ConversionPatternRewriter &rewriter) const { + + auto meshShardOperands = srcOp.getInputsMutable(); + meshShardOperands.append(ValueRange(outputTensor)); + SmallVector meshShardAttrs; + + StringAttr shardTypeAttrName = rewriter.getStringAttr("shard_type"); + Attribute shardTypeAttr = + rewriter.getAttr(shardAttrValue.shardType); + meshShardAttrs.push_back({shardTypeAttrName, shardTypeAttr}); + + StringAttr shardDirectionAttrName = + rewriter.getStringAttr("shard_direction"); + Attribute shardDirectionAttr = + rewriter.getAttr(shardAttrValue.shardDirection); + meshShardAttrs.push_back({shardDirectionAttrName, shardDirectionAttr}); + + StringAttr shardShapeAttrName = rewriter.getStringAttr("shard_shape"); + if (shardAttrValue.lastTileDimReplicate) { + shardAttrValue.shardShape.pop_back(); + } + GridAttr shardShape = + GridAttr::get(this->getContext(), shardAttrValue.shardShape); + meshShardAttrs.push_back({shardShapeAttrName, shardShape}); + + StringAttr operationConstraintAttrName = + StringAttr::get(this->getContext(), "operand_constraints"); + Attribute operationConstraintAttr = rewriter.getArrayAttr( + SmallVector(adaptor.getOperands().size() + 1, + rewriter.getAttr( + OperandConstraint::SystemScalar))); + meshShardAttrs.push_back( + {operationConstraintAttrName, operationConstraintAttr}); + + auto meshShardOp = rewriter.create( + srcOp.getLoc(), outputTypes, + ValueRange(meshShardOperands.getAsOperandRange()), meshShardAttrs); + rewriter.replaceOp(srcOp, meshShardOp); + + return success(); + } + + LogicalResult + checkBasicLegality(mlir::stablehlo::CustomCallOp &srcOp, + mlir::stablehlo::CustomCallOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + // Expect single input/output, otherwise do not convert + if (adaptor.getInputs().size() != 1 && srcOp->getResults().size() != 1) { + return failure(); + } + + return success(); + } +}; // namespace + class StableHLOToTTIRSliceOpConversionPattern : public OpConversionPattern { @@ -1134,8 +1622,8 @@ class StableHLOToTTIRGatherOpConversionPattern auto dimensionNumbers = srcOp.getDimensionNumbers(); rewriter.replaceOpWithNewOp( - srcOp, outputType, srcOp.getOperands()[0], - srcOp.getOperands()[1], // Start indices + srcOp, outputType, adaptor.getOperands()[0], + adaptor.getOperands()[1], // Start indices Value(outputTensor), dimensionNumbers.getOffsetDims(), dimensionNumbers.getCollapsedSliceDims(), dimensionNumbers.getOperandBatchingDims(), @@ -1150,6 +1638,167 @@ class StableHLOToTTIRGatherOpConversionPattern } }; +template +class StableHLOToTTIROpIotaOpConversionPattern + : public OpConversionPattern { + + using OpConversionPattern::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(SrcIotaOp srcOp, Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + RankedTensorType outputType = mlir::cast( + this->getTypeConverter()->convertType(srcOp.getResult().getType())); + rewriter.replaceOpWithNewOp( + srcOp, outputType, 0, outputType.getDimSize(adaptor.getIotaDimension()), + 1, adaptor.getIotaDimension()); + + // Dynamic Iota has an output_shape attribute but the output shape is + // already known by the result type This is to remove the operand that will + // become dead code + for (auto operand : adaptor.getOperands()) { + if (operand.getDefiningOp()) { + rewriter.eraseOp(operand.getDefiningOp()); + } + } + + return success(); + } +}; + +class StableHLOToTTIRScatterOpConversionPattern + : public OpConversionPattern { + + using OpConversionPattern::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(mlir::stablehlo::ScatterOp srcOp, + mlir::stablehlo::ScatterOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto outputType = mlir::cast( + this->getTypeConverter()->convertType(srcOp.getResults()[0].getType())); + tensor::EmptyOp outputTensor = rewriter.create( + srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); + Value operand = srcOp.getInputs()[0]; + Value scatterIndices = srcOp.getScatterIndices(); + Value update = srcOp.getUpdates()[0]; + mlir::ArrayAttr binaryConstraints = rewriter.getArrayAttr( + SmallVector(4, rewriter.getAttr( + OperandConstraint::AnyDeviceTile))); + auto updateWindowsDims = + adaptor.getScatterDimensionNumbers().getUpdateWindowDims(); + auto insertedWindowDims = + adaptor.getScatterDimensionNumbers().getInsertedWindowDims(); + auto inputBatchingDims = + adaptor.getScatterDimensionNumbers().getInputBatchingDims(); + auto scatterIndicesBatchingDims = + adaptor.getScatterDimensionNumbers().getScatterIndicesBatchingDims(); + auto scatterDimsToOperandDims = + adaptor.getScatterDimensionNumbers().getScatterDimsToOperandDims(); + auto indexVectorDim = + adaptor.getScatterDimensionNumbers().getIndexVectorDim(); + auto indicesAreSorted = adaptor.getIndicesAreSorted(); + auto uniqueIndices = adaptor.getUniqueIndices(); + + auto newScatterOp = rewriter.create( + srcOp.getLoc(), outputType, operand, scatterIndices, update, + llvm::ArrayRef( + convertArrayRefToInt32vector(updateWindowsDims)), + llvm::ArrayRef( + convertArrayRefToInt32vector(insertedWindowDims)), + llvm::ArrayRef( + convertArrayRefToInt32vector(inputBatchingDims)), + llvm::ArrayRef( + convertArrayRefToInt32vector(scatterIndicesBatchingDims)), + llvm::ArrayRef( + convertArrayRefToInt32vector(scatterDimsToOperandDims)), + indexVectorDim, indicesAreSorted, uniqueIndices, outputTensor, + binaryConstraints); + + // Replaces with different types do not work and will fail silently, so we + // manually set the second operand, since the type changes there from i32 to + // i64. + newScatterOp.setOperand( + 1, adaptor.getScatterIndices().getDefiningOp()->getResult(0)); + + newScatterOp->getRegion(0).takeBody(adaptor.getUpdateComputation()); + changeRegionTypes(newScatterOp->getRegion(0), *getTypeConverter(), + rewriter); + + rewriter.replaceOp(srcOp, newScatterOp); + + return success(); + } + +private: + std::vector + convertArrayRefToInt32vector(const llvm::ArrayRef &source) const { + std::vector converted; + converted.reserve(source.size()); + + for (int64_t value : source) { + converted.push_back(static_cast(value)); + } + + return converted; + } + + void changeRegionTypes(mlir::Region ®ion, + const mlir::TypeConverter &typeConverter, + mlir::PatternRewriter &rewriter) const { + Block &block = *region.getBlocks().begin(); + llvm::SmallVector oldArguments( + block.getArguments().begin(), block.getArguments().end()); + llvm::SmallVector newArguments; + + // Add new arguments with updated types to the block. + for (auto arg : oldArguments) { + if (auto newType = typeConverter.convertType(arg.getType())) { + mlir::BlockArgument newArg = block.addArgument(newType, arg.getLoc()); + newArguments.push_back(newArg); + } else { + newArguments.push_back(arg); // Type didn't change + } + } + + for (auto it : llvm::zip(oldArguments, newArguments)) { + mlir::BlockArgument oldArg = std::get<0>(it); + mlir::Value newArg = std::get<1>(it); + if (oldArg != newArg) { + oldArg.replaceAllUsesWith(newArg); + } + } + + for (auto arg : oldArguments) { + if (!llvm::is_contained(newArguments, arg)) { + block.eraseArgument(arg.getArgNumber()); + } + } + } +}; + +class StableHLOToTTIRReturnOpConversionPattern + : public OpConversionPattern { + + using OpConversionPattern::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(mlir::stablehlo::ReturnOp srcOp, + mlir::stablehlo::ReturnOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + rewriter.replaceOpWithNewOp(srcOp, + srcOp.getResults()); + + return success(); + } +}; + void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx, RewritePatternSet &patterns, TypeConverter &typeConverter) { @@ -1186,6 +1835,15 @@ void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx, mlir::stablehlo::Expm1Op, mlir::tt::ttir::Expm1Op>>(typeConverter, ctx); patterns.add>(typeConverter, ctx); + patterns.add>(typeConverter, + ctx); + patterns.add>(typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add>(typeConverter, ctx); } void addElementwiseBinaryOpsConversionPatterns(MLIRContext *ctx, @@ -1283,6 +1941,13 @@ void addReshapeOpConversionPattern(MLIRContext *ctx, patterns.add(typeConverter, ctx); } +void addCCLOpsConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, + ctx); +} + void addLogicalOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns, TypeConverter &typeConverter) { @@ -1314,6 +1979,27 @@ void addGatherOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns, patterns.add(typeConverter, ctx); } +void addIotaOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add>( + typeConverter, ctx); + patterns + .add>( + typeConverter, ctx); +} + +void addScatterOpConversionPatterns(MLIRContext *ctx, + RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add(typeConverter, ctx); +} + +void addReturnOpConversionPatterns(MLIRContext *ctx, + RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add(typeConverter, ctx); +} + } // namespace namespace mlir::tt { @@ -1335,9 +2021,13 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx, addConcatOpsConversionPatterns(ctx, patterns, typeConverter); addReshapeOpConversionPattern(ctx, patterns, typeConverter); addLogicalOpConversionPattern(ctx, patterns, typeConverter); + addCCLOpsConversionPattern(ctx, patterns, typeConverter); addSliceOpConversionPattern(ctx, patterns, typeConverter); addClampOpConversionPattern(ctx, patterns, typeConverter); addGatherOpConversionPattern(ctx, patterns, typeConverter); + addIotaOpConversionPattern(ctx, patterns, typeConverter); + addScatterOpConversionPatterns(ctx, patterns, typeConverter); + addReturnOpConversionPatterns(ctx, patterns, typeConverter); } } // namespace mlir::tt diff --git a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp index 9b8c634adb..9ba4257428 100644 --- a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp +++ b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp @@ -16,6 +16,7 @@ #include "mlir/Transforms/DialectConversion.h" #include +#include using namespace mlir; using namespace mlir::tt; @@ -223,34 +224,22 @@ static std::vector generateConvKernelTransposeIndices( return generateTransposeIndices(kernelLayout, ttnnConvolutionKernelLayout); } -struct ConvolutionToConv2dPattern +struct ConvolutionDecompositionPattern : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; - constexpr static uint32_t numSpatialDims = 2; - constexpr static uint32_t SPATIAL_DIM_HEIGHT = 0; - constexpr static uint32_t SPATIAL_DIM_WIDTH = 1; - - // NHWC - static inline const std::vector conv2dLayout = { - ConvolutionDimension::BATCH, - SPATIAL_DIM_HEIGHT, - SPATIAL_DIM_WIDTH, - ConvolutionDimension::FEATURE, - }; - // OIHW - static inline const std::vector conv2dKernelLayout = { - ConvolutionKernelDimension::OUTPUT_FEATURES, - ConvolutionKernelDimension::INPUT_FEATURES, - SPATIAL_DIM_HEIGHT, - SPATIAL_DIM_WIDTH, - }; - - LogicalResult isConv2d(ttir::ConvolutionOp op) const { + LogicalResult + matchAndRewrite(ttir::ConvolutionOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override = 0; - // Conv2d will have 2 spatial dimensions +protected: + bool isNDimensional(ttir::ConvolutionOp op, uint32_t numSpatialDims) const { + return op.getConvolutionLayout().getInputSpatialDimensions().size() == + numSpatialDims; + } + bool isSupportedConv(ttir::ConvolutionOp op) const { assert(op.getConvolutionLayout().getInputSpatialDimensions().size() == op.getConvolutionLayout().getOutputSpatialDimensions().size() && "Convolution input, output, and kernel must have the same number of " @@ -260,33 +249,211 @@ struct ConvolutionToConv2dPattern "Convolution input, output, and kernel must have the same number of " "spatial dimensions"); - if (op.getConvolutionLayout().getInputSpatialDimensions().size() != - numSpatialDims) { - return failure(); - } - // Not currently supporting window reversal std::vector windowReversal(op.getWindowReversal().begin(), op.getWindowReversal().end()); for (bool reversed : windowReversal) { if (reversed) { - return failure(); + return false; } } // Not currently support batch groups if (op.getBatchGroupCount() != 1) { + return false; + } + + return true; + } +}; + +// A decompostion pattern that matches to a ttir.convolution op that does 1D +// convolution. Since that is not supported in ttnn, we reshape the inputs and +// the output to match a 2D ttir.convolution op. The expectation is that the new +// ttir.convolution op will be picked up by the ConvolutionToConv2dPattern and +// translated into ttir.conv2d op. +struct Legalize1DConvolutionPattern : public ConvolutionDecompositionPattern { +public: + using ConvolutionDecompositionPattern::ConvolutionDecompositionPattern; + constexpr static uint32_t numSpatialDims = 1; + + LogicalResult + matchAndRewrite(ttir::ConvolutionOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!(isSupportedConv(op) && isNDimensional(op, numSpatialDims))) { return failure(); } + // Not currently supporting spatial dims other than 2 for the 1D case. + if (op.getConvolutionLayout().getInputSpatialDimensions()[0] != 2) { + return failure(); + } + + // The shapes that the convolution currently operates with have are 3D, and + // we need to add another dimension for it to match the conv2d signature, so + // adding a dimension of size 1 to the end of input and output shapes. + auto outputType = + mlir::cast(adaptor.getOutput().getType()); + llvm::ArrayRef outputShape = outputType.getShape(); + llvm::SmallVector conv2dOutputShape(outputShape.begin(), + outputShape.end()); + conv2dOutputShape.push_back(1); + auto DPSConv2dOutput = rewriter.create( + op->getLoc(), conv2dOutputShape, outputType.getElementType()); + auto conv2dOutputType = + mlir::cast(DPSConv2dOutput.getType()); + + auto inputType = mlir::cast(adaptor.getInput().getType()); + llvm::ArrayRef inputShape = inputType.getShape(); + llvm::SmallVector reshapeInputShape(inputShape.begin(), + inputShape.end()); + reshapeInputShape.push_back(1); + + auto weightType = + mlir::cast(adaptor.getWeight().getType()); + llvm::ArrayRef weightShape = weightType.getShape(); + llvm::SmallVector reshapeWeightShape(weightShape.begin(), + weightShape.end()); + reshapeWeightShape.push_back(1); + + ttir::ReshapeOp reshapeInput = + createReshapeOp(op.getLoc(), adaptor.getInput(), reshapeInputShape, + op.getOperandConstraints(), rewriter); + ttir::ReshapeOp reshapeWeight = + createReshapeOp(op.getLoc(), adaptor.getWeight(), reshapeWeightShape, + op.getOperandConstraints(), rewriter); + + mlir::DenseI64ArrayAttr conv2dOpWindowsStridesAttr = + addIntegerToDenseArrayAttr(rewriter, adaptor.getWindowStridesAttr(), 1); + mlir::DenseI64ArrayAttr conv2dOpPaddingAttr = + addIntegerToDenseArrayAttr(rewriter, adaptor.getPaddingAttr(), 0); + conv2dOpPaddingAttr = + addIntegerToDenseArrayAttr(rewriter, conv2dOpPaddingAttr, 0); + mlir::DenseI64ArrayAttr conv2dOpInputDilationAttr = + addIntegerToDenseArrayAttr(rewriter, adaptor.getInputDilationAttr(), 1); + mlir::DenseI64ArrayAttr conv2dOpWeightDilationAttr = + addIntegerToDenseArrayAttr(rewriter, adaptor.getWeightDilationAttr(), + 1); + mlir::DenseBoolArrayAttr conv2dOpWindowReversalAttr = + addBooleanToDenseArrayAttr(rewriter, adaptor.getWindowReversalAttr(), + false); + + auto convolutionLayout = adaptor.getConvolutionLayoutAttr(); + + // The additional spatial dimension is added at the and (3rd in 0 indexed + // array). + llvm::SmallVector conv2dInputSpatialDimensions( + convolutionLayout.getInputSpatialDimensions().begin(), + convolutionLayout.getInputSpatialDimensions().end()); + conv2dInputSpatialDimensions.push_back(3); + + llvm::SmallVector conv2dKernelSpatialDimensions( + convolutionLayout.getKernelSpatialDimensions().begin(), + convolutionLayout.getKernelSpatialDimensions().end()); + conv2dKernelSpatialDimensions.push_back(3); + + llvm::SmallVector conv2dOutputSpatialDimensions( + convolutionLayout.getOutputSpatialDimensions().begin(), + convolutionLayout.getOutputSpatialDimensions().end()); + conv2dOutputSpatialDimensions.push_back(3); + + mlir::tt::ttir::ConvolutionOp new2dConvolutionOp = + rewriter.create( + op.getLoc(), conv2dOutputType, reshapeInput, reshapeWeight, + mlir::Value(nullptr), DPSConv2dOutput, conv2dOpWindowsStridesAttr, + conv2dOpPaddingAttr, conv2dOpInputDilationAttr, + conv2dOpWeightDilationAttr, conv2dOpWindowReversalAttr, + mlir::tt::ttir::ConvolutionLayoutAttr::get( + getContext(), convolutionLayout.getInputBatchDimension(), + convolutionLayout.getInputFeatureDimension(), + conv2dInputSpatialDimensions, + convolutionLayout.getKernelOutputFeatureDimension(), + convolutionLayout.getKernelInputFeatureDimension(), + conv2dKernelSpatialDimensions, + convolutionLayout.getOutputBatchDimension(), + convolutionLayout.getOutputFeatureDimension(), + conv2dOutputSpatialDimensions), + adaptor.getFeatureGroupCountAttr(), + adaptor.getBatchGroupCountAttr(), + rewriter.getArrayAttr( + SmallVector(adaptor.getOperands().size() + 1, + rewriter.getAttr( + OperandConstraint::AnyDeviceTile)))); + ttir::ReshapeOp reshapeOutput = + createReshapeOp(op.getLoc(), new2dConvolutionOp, outputShape, + op.getOperandConstraints(), rewriter); + + rewriter.replaceOp(op, reshapeOutput); + return success(); } +private: + ttir::ReshapeOp createReshapeOp(Location loc, Value tensor, + llvm::ArrayRef target_input_shape, + ::mlir::ArrayAttr constraints, + ConversionPatternRewriter &rewriter) const { + auto inputType = mlir::cast(tensor.getType()); + + auto DPSReshapeOutput = rewriter.create( + loc, llvm::ArrayRef(target_input_shape), + inputType.getElementType()); + llvm::SmallVector shapei32(target_input_shape.begin(), + target_input_shape.end()); + auto shape_attr = rewriter.getI32ArrayAttr(shapei32); + + return rewriter.create( + loc, + mlir::RankedTensorType::get(target_input_shape, + inputType.getElementType()), + tensor, DPSReshapeOutput, shape_attr, constraints); + } + + mlir::DenseI64ArrayAttr + addIntegerToDenseArrayAttr(ConversionPatternRewriter &rewriter, + mlir::DenseI64ArrayAttr denseArrayAttr, + uint64_t integerValue) const { + llvm::SmallVector newDenseArray(denseArrayAttr.asArrayRef()); + newDenseArray.push_back(integerValue); + return rewriter.getDenseI64ArrayAttr(newDenseArray); + } + + mlir::DenseBoolArrayAttr + addBooleanToDenseArrayAttr(ConversionPatternRewriter &rewriter, + mlir::DenseBoolArrayAttr denseArrayAttr, + bool booleanValue) const { + llvm::SmallVector newDenseArray(denseArrayAttr.asArrayRef()); + newDenseArray.push_back(booleanValue); + return rewriter.getDenseBoolArrayAttr(newDenseArray); + } +}; +struct ConvolutionToConv2dPattern : public ConvolutionDecompositionPattern { +public: + using ConvolutionDecompositionPattern::ConvolutionDecompositionPattern; + + constexpr static uint32_t numSpatialDims = 2; + constexpr static uint32_t SPATIAL_DIM_HEIGHT = 0; + constexpr static uint32_t SPATIAL_DIM_WIDTH = 1; + + // NHWC + static inline const std::vector conv2dLayout = { + ConvolutionDimension::BATCH, + SPATIAL_DIM_HEIGHT, + SPATIAL_DIM_WIDTH, + ConvolutionDimension::FEATURE, + }; + // OIHW + static inline const std::vector conv2dKernelLayout = { + ConvolutionKernelDimension::OUTPUT_FEATURES, + ConvolutionKernelDimension::INPUT_FEATURES, + SPATIAL_DIM_HEIGHT, + SPATIAL_DIM_WIDTH, + }; + LogicalResult matchAndRewrite(ttir::ConvolutionOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - - if (failed(isConv2d(op))) { + if (!(isSupportedConv(op) && isNDimensional(op, numSpatialDims))) { return failure(); } @@ -407,6 +574,13 @@ struct GatherToEmbeddingConversionPattern // collapsed slice dims of the gather op auto collapsedSliceDims = op.getCollapsedSliceDims(); + RankedTensorType operandType = + mlir::cast(op->getOperand(0).getType()); + if (!operandType.getElementType().isBF16()) { + return rewriter.notifyMatchFailure( + op, "only supports bfloat16 input tensor."); + } + if (shape.size() > 1) { auto hiddenDim = shape[shape.size() - 1]; // check if sliceSizes has more than one element @@ -775,14 +949,268 @@ class GetDimensionSizeToConstantConversionPattern } }; +// SelectOp is converted to a series of SliceOp and potentially a ConcatOp if +// the sliced dimension is sliced multiple times. For example, if the input +// tensor is +// [[[1, 2, 3], +// [4, 5, 6], +// [7, 8, 9], +// [10, 11, 12], +// [13, 14, 15], +// [16, 17, 18]], +// [[19, 20, 21], +// [22, 23, 24], +// [25, 26, 27], +// [28, 29, 30], +// [31, 32, 33], +// [34, 35, 36]]], +// shape = [2, 6, 3] +// and the SelectOp is dim=1, begin=0, length=2, stride=4, the output tensor +// will be +// [[[1, 2, 3], +// [4, 5, 6], +// [13, 14, 15], +// [16, 17, 18]], +// [[19, 20, 21], +// [22, 23, 24], +// [31, 32, 33], +// [34, 35, 36]]], +// shape = [2, 4, 3] +// In this case 2 slices are created and concatenated to form the output tensor. +// First slice has begins=[0, 0, 0], ends=[2, 2, 3], steps=[1, 1, 1], and the +// second slice has begins=[0, 4, 0], ends=[2, 6, 3], steps=[1, 1, 1]. +struct SelectToSliceConversionPattern + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ttir::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto inputType = mlir::cast(adaptor.getInput().getType()); + auto outputType = mlir::cast(op.getType()); + + auto inputShape = inputType.getShape(); + + int32_t dim = + op.getDim() < 0 ? inputType.getRank() + op.getDim() : op.getDim(); + + int32_t begin = op.getBegin(); + int32_t length = op.getLength(); + int32_t stride = op.getStride(); + + int32_t inputDimSize = inputType.getShape()[dim]; + int32_t numSlices = (inputDimSize - begin + stride - 1) / stride; + + llvm::SmallVector begins, ends, steps; + for (int32_t i = 0; i < inputType.getRank(); ++i) { + // Always slicing with step 1. + steps.push_back(1); + if (i == dim) { + // Push placeholder values for now which will be updated later. + begins.push_back(0); + ends.push_back(0); + continue; + } + + // For non-sliced dimensions, begin=0, end=dimSize, step=1. + begins.push_back(0); + ends.push_back(inputType.getDimSize(i)); + } + + // Create a slice for each slice of the input tensor. The slices are then + // concatenated. The slices are created by updating the begin and end values + // for the sliced dimension. + llvm::SmallVector slices; + for (int32_t i = 0; i < numSlices; ++i) { + int32_t newBegin = begin + i * stride; + int32_t newEnd = std::min(newBegin + length, inputDimSize); + + // Make a copy of the input shape and update the dim size. + llvm::SmallVector resultShape(inputShape); + resultShape[dim] = newEnd - newBegin; + auto resultType = + RankedTensorType::get(resultShape, inputType.getElementType()); + + auto sliceDpsResult = rewriter.create( + op.getLoc(), resultShape, inputType.getElementType()); + + begins[dim] = newBegin; + ends[dim] = newEnd; + + auto newOp = rewriter.create( + op.getLoc(), resultType, adaptor.getInput(), sliceDpsResult, + rewriter.getI32ArrayAttr(begins), rewriter.getI32ArrayAttr(ends), + rewriter.getI32ArrayAttr(steps), adaptor.getOperandConstraints()); + slices.push_back(newOp->getResult(0)); + } + + assert(!slices.empty()); + if (slices.size() > 1) { + auto concatDpsResult = rewriter.create( + op.getLoc(), outputType.getShape(), outputType.getElementType()); + auto concatOp = rewriter.create( + op.getLoc(), outputType, slices, concatDpsResult, + rewriter.getSI32IntegerAttr(dim), adaptor.getOperandConstraints()); + + rewriter.replaceOp(op, concatOp.getResult()); + } else { + rewriter.replaceOp(op, slices[0]); + } + + return success(); + } +}; + +/* + * This pattern rewrites ArangeOp by forcing the arange_dimension to be + * rightmost dimension of the output tensor. This is done by replacing the + * ArangeOp with a new one that has this property, and then transposing out last + * dimension to the dimension specified by the original ArangeOp, and also + * inserting a reshape to match the rank of the intended output and broadcasts + * to repeat the data along the other dimensions. + * + * The ArangeOp that is generated here will be equivalent to how ttnn::ArangeOp + * behaves. The reason this pass is done in TTIR rather than generated when we + * want to lower to TTNN is because in the future we will want to consteval the + * ArangeOp, but have the option to not include repeated data in the constant + * tensor and broadcast at runtime instead. Consteval will be implemented for + * the TTIR dialect only and so this explication of the TMs implicit in ArangeOp + * must be done in TTIR. + */ +struct ArangeForceLastDimensionPattern + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ttir::ArangeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + const RankedTensorType outputType = + mlir::cast(op.getResult().getType()); + + int64_t arangeDimension = adaptor.getArangeDimension(); + int64_t arangeDimensionNegative = arangeDimension - outputType.getRank(); + int64_t start = adaptor.getStart(); + int64_t end = adaptor.getEnd(); + int64_t step = adaptor.getStep(); + + int64_t arangeLength = (end - start) / step; + + ArrayRef ttnnShape = {1, 1, 1, arangeLength}; + if (ttnnShape == outputType.getShape()) { + return success(); + } + + RankedTensorType arangeOutputType = RankedTensorType::get( + SmallVector({1, 1, 1, arangeLength}), + outputType.getElementType(), outputType.getEncoding()); + + Value output = + rewriter + .create( // perform arange on the last dimension to + // match how ttnn behaves + op.getLoc(), arangeOutputType, start, end, step, 3) + .getResult(); + + std::vector outputShape = arangeOutputType.getShape().vec(); + // Must transpose the output so that the data changes along the axis defined + // by arangeDimension + if (arangeDimensionNegative != -1) { + std::vector transposeShape = outputShape; + transposeShape[arangeDimensionNegative + transposeShape.size()] = + arangeLength; + transposeShape[arangeOutputType.getRank() - 1] = 1; + RankedTensorType transposeType = RankedTensorType::get( + transposeShape, arangeOutputType.getElementType(), + arangeOutputType.getEncoding()); + + tensor::EmptyOp dpsOutput = rewriter.create( + op.getLoc(), transposeShape, transposeType.getElementType()); + + output = rewriter.create( + op.getLoc(), transposeType, output, dpsOutput, + arangeDimensionNegative + transposeShape.size(), + arangeOutputType.getRank() - 1, + rewriter.getArrayAttr(SmallVector( + 2, rewriter.getAttr( + OperandConstraint::AnyDeviceTile)))); + + outputShape = transposeShape; + } + + // Must match up the rank of the output with the rank of the intended output + // from the original arange, with the arangeDimension in the correct + // position + if (outputType.getRank() != static_cast(outputShape.size())) { + std::vector reshapeShape; + for (uint32_t i = 0; i < outputType.getRank(); i++) { + i == arangeDimension ? reshapeShape.push_back(end) + : reshapeShape.push_back(1); + } + + RankedTensorType reshapeType = RankedTensorType::get( + SmallVector(reshapeShape.begin(), reshapeShape.end()), + outputType.getElementType(), outputType.getEncoding()); + tensor::EmptyOp dpsOutput = rewriter.create( + op.getLoc(), + SmallVector(reshapeShape.begin(), reshapeShape.end()), + reshapeType.getElementType()); + output = rewriter.create( + op.getLoc(), reshapeType, output, dpsOutput, + rewriter.getI32ArrayAttr(reshapeShape), + rewriter.getArrayAttr(SmallVector( + 2, rewriter.getAttr( + OperandConstraint::AnyDeviceTile)))); + + outputShape = + std::vector(reshapeShape.begin(), reshapeShape.end()); + } + + // Must broadcast the rest of the dimensions + SmallVector broadcastDims; + for (uint32_t i = 0; i < outputShape.size(); i++) { + if (i != arangeDimension && outputShape[i] != outputType.getShape()[i]) { + outputShape[i] = outputType.getShape()[i]; + broadcastDims.push_back(rewriter.getI64IntegerAttr(i)); + } + } + if (!broadcastDims.empty()) { + RankedTensorType broadcastType = RankedTensorType::get( + outputShape, outputType.getElementType(), outputType.getEncoding()); + + tensor::EmptyOp dpsOutput = rewriter.create( + op.getLoc(), outputShape, outputType.getElementType()); + + output = rewriter.create( + op.getLoc(), broadcastType, output, dpsOutput, + rewriter.getArrayAttr(broadcastDims), + rewriter.getArrayAttr(SmallVector( + 2, rewriter.getAttr( + OperandConstraint::AnyDeviceTile)))); + + assert(mlir::cast(output.getType()).getShape() == + outputType.getShape() && + "Output shape must match the shape of the input tensor"); + } + rewriter.replaceOp(op, output); + return success(); + } +}; + void populateTTIRToTTIRDecompositionPatterns(MLIRContext *ctx, RewritePatternSet &patterns, TypeConverter &typeConverter) { patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); } } // namespace mlir::tt diff --git a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp index 76cbae96e2..e244eea8fb 100644 --- a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp +++ b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp @@ -51,6 +51,15 @@ struct TTIRToTTIRDecompositionPass target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + + // These are the ops that must satisfy some conditions after this pass + target.addDynamicallyLegalOp([&](ttir::ArangeOp op) { + auto shape = op.getResult().getType().getShape(); + return (static_cast(op.getArangeDimension()) == 3 && + shape.size() == 4 && shape[0] == 1 && shape[1] == 1 && + shape[2] == 1); + }); TypeConverter typeConverter; // All types map 1:1. diff --git a/lib/Conversion/TTIRToTTMetal/TTIRToTTMetal.cpp b/lib/Conversion/TTIRToTTMetal/TTIRToTTMetal.cpp index a3bbddc1da..00220d7e44 100644 --- a/lib/Conversion/TTIRToTTMetal/TTIRToTTMetal.cpp +++ b/lib/Conversion/TTIRToTTMetal/TTIRToTTMetal.cpp @@ -199,8 +199,8 @@ class TTIRToTTMetalLayoutRewriter : public OpRewritePattern { LogicalResult relayout(ttir::ToLayoutOp op, PatternRewriter &rewriter) const { auto inputTy = mlir::cast(op.getInput().getType()); auto outputTy = mlir::cast(op.getType()); - auto inputLayout = mlir::cast(inputTy.getEncoding()); - auto outputLayout = mlir::cast(outputTy.getEncoding()); + auto inputLayout = mlir::cast(inputTy.getEncoding()); + auto outputLayout = mlir::cast(outputTy.getEncoding()); tt::DeviceAttr device = op.getDevice(); assert(device); tt::SystemDescAttr systemDesc = op.getSystemDesc(); @@ -342,8 +342,8 @@ class TTIRToTTMetalLayoutRewriter : public OpRewritePattern { LogicalResult reformat(ttir::ToLayoutOp op, PatternRewriter &rewriter) const { auto inputTy = mlir::cast(op.getInput().getType()); auto outputTy = mlir::cast(op.getType()); - auto inputLayout = mlir::cast(inputTy.getEncoding()); - auto outputLayout = mlir::cast(outputTy.getEncoding()); + auto inputLayout = mlir::cast(inputTy.getEncoding()); + auto outputLayout = mlir::cast(outputTy.getEncoding()); bool shouldTilize = not inputLayout.isTiled() && outputLayout.isTiled(); bool shouldUntilize = inputLayout.isTiled() && not outputLayout.isTiled(); assert(shouldTilize ^ shouldUntilize); @@ -448,10 +448,10 @@ class TTIRToTTMetalLayoutRewriter : public OpRewritePattern { return failure(); } assert(inputTy.getShape() == outputTy.getShape()); - assert(mlir::isa(inputTy.getEncoding())); - assert(mlir::isa(outputTy.getEncoding())); - auto inputLayout = mlir::cast(inputTy.getEncoding()); - auto outputLayout = mlir::cast(outputTy.getEncoding()); + assert(mlir::isa(inputTy.getEncoding())); + assert(mlir::isa(outputTy.getEncoding())); + auto inputLayout = mlir::cast(inputTy.getEncoding()); + auto outputLayout = mlir::cast(outputTy.getEncoding()); auto components = op.compoundComponents(); bool isCompound = (static_cast(components.isLayoutChange) + @@ -799,6 +799,8 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern { inCB1); } else if (mlir::isa(arithOrMathOp)) { builder.create(arithOrMathOp.getLoc()); + } else if (mlir::isa(arithOrMathOp)) { + builder.create(arithOrMathOp.getLoc()); } else { llvm_unreachable("Unhandled binary op init conversion."); } @@ -905,27 +907,13 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern { assert(cbOperands.size() == 3 && "Expected two input and one output CB for binary op."); - auto inCB0TileIndex = iterators[blockArgIteratorMapping[0]]; - auto inCB0 = cbOperands[0]; - auto inCB1TileIndex = iterators[blockArgIteratorMapping[1]]; - auto inCB1 = cbOperands[1]; - auto outCB = cbOperands[2]; - auto outCBTileIndex = iterators[blockArgIteratorMapping[2]]; - - auto location = arithOrMathOp.getLoc(); - - // Perform computation C = A (*) B on tile A from inCB0 and tile B from - // inCB1 and store the result C in DST register on dstTileIndex. + // Perform computation C = A (*) B on tile A from cbOperands[0] and tile B + // from cbOperands[1] and store the result C in DST register on + // dstTileIndex. if (mlir::isa(arithOrMathOp)) { - Value dstIndex = i32(0, builder); - builder.create(location); - builder.create( - location, inCB0, inCB1, inCB0TileIndex, inCB1TileIndex, dstIndex); - builder.create(location); - builder.create(location); - builder.create(location, dstIndex, outCB, - outCBTileIndex); - builder.create(location); + convertComputeBinaryFPUOp( + arithOrMathOp, cbOperands, iterators, blockArgIteratorMapping, + builder); } else if (mlir::isa(arithOrMathOp)) { commonComputeMulOp(arithOrMathOp, cbOperands, iterators, blockArgIteratorMapping, builder); @@ -938,6 +926,10 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern { blockArgIteratorMapping, builder, operandIndicesRecip); + auto inCB0 = cbOperands[0]; + auto inCB1 = cbOperands[1]; + auto location = arithOrMathOp.getLoc(); + Value one = i32(1, builder); builder.create(location, inCB1, one); @@ -947,12 +939,96 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern { blockArgIteratorMapping, builder); builder.create(location, inCB1, one); + } else if (mlir::isa(arithOrMathOp)) { + convertComputeBinarySFPUOp( + arithOrMathOp, cbOperands, iterators, blockArgIteratorMapping, + builder); } else { llvm_unreachable("Unhandled conversion for operation which is neither " "unary nor binary."); } } + template + void convertComputeBinaryFPUOp( + Operation &arithOrMathOp, ArrayRef cbOperands, + ArrayRef iterators, + const SmallVector &blockArgIteratorMapping, + OpBuilder &builder) const { + auto inCB0TileIndex = iterators[blockArgIteratorMapping[0]]; + auto inCB0 = cbOperands[0]; + auto inCB1TileIndex = iterators[blockArgIteratorMapping[1]]; + auto inCB1 = cbOperands[1]; + auto outCB = cbOperands[2]; + auto outCBTileIndex = iterators[blockArgIteratorMapping[2]]; + + auto location = arithOrMathOp.getLoc(); + + Value dstIndex = i32(0, builder); + + // acquire DST register lock (MATH) + builder.create(location); + { + builder.create(location, inCB0, inCB1, inCB0TileIndex, + inCB1TileIndex, dstIndex); + } + builder.create(location); + // release DST register lock (MATH) + + // acquire DST register lock (PACK) + builder.create(location); + { + builder.create(location, dstIndex, outCB, + outCBTileIndex); + } + builder.create(location); + // release DST register lock (PACK) + } + + template + void convertComputeBinarySFPUOp( + Operation &arithOrMathOp, ArrayRef cbOperands, + ArrayRef iterators, + const SmallVector &blockArgIteratorMapping, + OpBuilder &builder) const { + auto inCB0TileIndex = iterators[blockArgIteratorMapping[0]]; + auto inCB0 = cbOperands[0]; + auto inCB1TileIndex = iterators[blockArgIteratorMapping[1]]; + auto inCB1 = cbOperands[1]; + auto outCB = cbOperands[2]; + auto outCBTileIndex = iterators[blockArgIteratorMapping[2]]; + + auto location = arithOrMathOp.getLoc(); + + Value dstLhsTileIndex = i32(0, builder); + Value dstRhsTileIndex = i32(1, builder); // note: rhs is always lhs+1 + + // acquire DST register lock (MATH) + builder.create(location); + { + // copy inCB0[inCB0TileIndex] and inCB1[inCB1TileIndex] to DST: + builder.create(location); + builder.create(location, inCB0, inCB0TileIndex, + dstLhsTileIndex); + builder.create(location, inCB1, inCB1TileIndex, + dstRhsTileIndex); + // SFPU operates on DST tiles: + builder.create(location, dstLhsTileIndex, + dstRhsTileIndex); + } + builder.create(location); + // release DST register lock (MATH) + + // acquire DST register lock (PACK) + builder.create(location); + { + builder.create(location, dstLhsTileIndex, outCB, + outCBTileIndex); + } + builder.create(location); + // release DST register lock (PACK) + } + void commonComputeMulOp(Operation &op, ArrayRef cbOperands, ArrayRef iterators, SmallVector blockArgIteratorMapping, @@ -1308,10 +1384,10 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern { SmallVector> calculateDataMovement(ArrayAttr iteratorTypes, const RankedTensorType &src, const RankedTensorType &dst, DeviceAttr device) const { - auto srcLayout = mlir::cast(src.getEncoding()); + auto srcLayout = mlir::cast(src.getEncoding()); assert(srcLayout.isTiled()); - auto dstLayout = mlir::cast(dst.getEncoding()); + auto dstLayout = mlir::cast(dst.getEncoding()); assert(dstLayout.isTiled()); assert(iteratorTypes.size() >= 2 && "Expected at least 2 iterator types"); diff --git a/lib/Conversion/TTIRToTTNN/CMakeLists.txt b/lib/Conversion/TTIRToTTNN/CMakeLists.txt index e31220d751..ff054f5bd7 100644 --- a/lib/Conversion/TTIRToTTNN/CMakeLists.txt +++ b/lib/Conversion/TTIRToTTNN/CMakeLists.txt @@ -11,4 +11,5 @@ add_mlir_library(TTMLIRTTIRToTTNN LINK_LIBS PUBLIC MLIRIR MLIRPass + TTMLIRTTNNUtils ) diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index 12e29a9609..db4320ff3f 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -9,6 +9,7 @@ #include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" #include "ttmlir/Dialect/TTNN/Types/Types.h" +#include "ttmlir/Dialect/TTNN/Utils/TransformUtils.h" #include "ttmlir/Dialect/TTNN/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -22,33 +23,13 @@ #include "mlir/Transforms/DialectConversion.h" #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" +#include using namespace mlir; using namespace mlir::tt; namespace { -// Gets or inserts a GetDeviceOp at the top of the current block of the given -// operation. -static Value getOrInsertDevice(ConversionPatternRewriter &rewriter, - Operation *op) { - Block *block = op->getBlock(); - for (auto &op : block->getOperations()) { - if (auto deviceOp = dyn_cast(op)) { - return deviceOp.getResult(); - } - } - - DeviceAttr deviceAttr = getCurrentScopeDevice(op); - auto currentInsertionPoint = rewriter.saveInsertionPoint(); - rewriter.setInsertionPoint(block, block->begin()); - auto deviceOp = rewriter.create( - op->getLoc(), rewriter.getType(deviceAttr), - ttnn::MeshShapeAttr::get(op->getContext(), 1, 1)); - rewriter.restoreInsertionPoint(currentInsertionPoint); - return deviceOp.getResult(); -} - class TensorEmptyConversionPattern : public OpConversionPattern { public: @@ -65,20 +46,15 @@ class TensorEmptyConversionPattern // Get the shape of the tensor, tensor layout, and data type // - mlir::MemRefType memref = layoutAttr.getMemref(); ttnn::ShapeAttr shapeAttr = ttnn::ShapeAttr::get( rewriter.getContext(), mlir::cast(op->getResult(0).getType()).getShape()); - Type elementType = memref.getElementType(); - DataType dtype = DataType::Float32; + DataType dtype = layoutAttr.getDataType(); ttnn::Layout ttnnLayoutEnum = ttnn::Layout::RowMajor; - if (llvm::isa(elementType)) { + if (layoutAttr.isTiled()) { ttnnLayoutEnum = ttnn::Layout::Tile; - auto tileType = mlir::cast(elementType); - dtype = tileType.getDataType(); } else { ttnnLayoutEnum = ttnn::Layout::RowMajor; - dtype = elementTypeToDataType(elementType); } DataTypeAttr dTypeAttr = DataTypeAttr::get(rewriter.getContext(), dtype); ttnn::LayoutAttr tensorLayoutAttr = @@ -87,8 +63,8 @@ class TensorEmptyConversionPattern // If the tensor is not going to device, we can create the op without // device-specific attributes // - ttnn::TensorMemoryLayout memLayout = layoutAttr.getMemLayout(); - if (memLayout == ttnn::TensorMemoryLayout::None) { + ttnn::TensorMemoryLayoutAttr memLayout = layoutAttr.getMemLayout(); + if (!memLayout) { rewriter.replaceOpWithNewOp( op, this->getTypeConverter()->convertType(op.getType()), nullptr, shapeAttr, dTypeAttr, tensorLayoutAttr, nullptr); @@ -100,14 +76,13 @@ class TensorEmptyConversionPattern // Create MemoryConfigAttr // - auto device = getOrInsertDevice(rewriter, op); + auto device = ::ttnn::utils::getOrInsertDevice(rewriter, op); + llvm::SmallVector shardShape = layoutAttr.getShardShape(); ttnn::MemoryConfigAttr memoryConfigAttr = ttnn::MemoryConfigAttr::get( - op.getContext(), - ttnn::TensorMemoryLayoutAttr::get(op.getContext(), memLayout), - ttnn::BufferTypeAttr::get(op.getContext(), bufferType), + op.getContext(), ttnn::BufferTypeAttr::get(op.getContext(), bufferType), ttnn::ShardSpecAttr::get( - op.getContext(), - ttnn::ShapeAttr::get(op.getContext(), memref.getShape()))); + op.getContext(), ttnn::ShapeAttr::get(op.getContext(), shardShape)), + memLayout); rewriter.replaceOpWithNewOp( op, this->getTypeConverter()->convertType(op.getType()), device, @@ -137,18 +112,15 @@ class ToLayoutOpConversionPattern auto outputLayoutAttr = mlir::cast( op.getResult().getType().getEncoding()); - auto outputMemref = outputLayoutAttr.getMemref(); - // Determine the output data type - DataType dtype = ttnn::utils::getDataTypeFromMemRef(outputMemref); + DataType dtype = outputLayoutAttr.getDataType(); DataTypeAttr outputDataType = DataTypeAttr::get(rewriter.getContext(), dtype); // Determine the output layout (tile or row major) ttnn::BufferType outputBufferType = outputLayoutAttr.getBufferType(); - ttnn::Layout outputLayoutEnum = - ttnn::utils::getLayoutFromMemRef(outputMemref); + ttnn::Layout outputLayoutEnum = outputLayoutAttr.getLayout(); bool isOutputOnHost = (outputBufferType == ttnn::BufferType::SystemMemory); @@ -176,30 +148,28 @@ class ToLayoutOpConversionPattern op.getResult().setType(result); outputLayoutAttr = mlir::cast(result.getEncoding()); - outputMemref = outputLayoutAttr.getMemref(); outputLayoutEnum = newOutputLayoutEnum; } } ttnn::LayoutAttr outputLayout = ttnn::LayoutAttr::get(rewriter.getContext(), outputLayoutEnum); + llvm::SmallVector outputShardShape = + outputLayoutAttr.getShardShape(); - // Determine output memory config attr - ttnn::TensorMemoryLayout outputTensorMemoryLayout = - outputLayoutAttr.getMemLayout(); ttnn::MemoryConfigAttr outputMemConfigAttr = ttnn::MemoryConfigAttr::get( rewriter.getContext(), - ttnn::TensorMemoryLayoutAttr::get(rewriter.getContext(), - outputTensorMemoryLayout), ttnn::BufferTypeAttr::get(rewriter.getContext(), outputBufferType), ttnn::ShardSpecAttr::get( - op.getContext(), ttnn::ShapeAttr::get(rewriter.getContext(), - outputMemref.getShape()))); + op.getContext(), + ttnn::ShapeAttr::get(rewriter.getContext(), outputShardShape)), + outputLayoutAttr.getMemLayout()); rewriter.replaceOpWithNewOp( op, this->getTypeConverter()->convertType(result), adaptor.getInput(), outputLayout, outputDataType, outputMemConfigAttr, - isOutputOnHost ? nullptr : getOrInsertDevice(rewriter, op)); + isOutputOnHost ? nullptr + : ::ttnn::utils::getOrInsertDevice(rewriter, op)); return success(); } @@ -222,15 +192,16 @@ class ToLayoutOpConversionPattern ttnn::Layout newOutputLayoutEnum) const { auto oldOutputLayoutAttr = mlir::cast(oldOutput.getEncoding()); - auto oldOutputMemref = oldOutputLayoutAttr.getMemref(); - DataType outputDtype = ttnn::utils::getDataTypeFromMemRef(oldOutputMemref); - llvm::ArrayRef oldShardShape = oldOutputMemref.getShape(); + DataType outputDtype = oldOutputLayoutAttr.getDataType(); + SmallVector oldShardShape = + oldOutputLayoutAttr.getShardShape(); size_t shardShapeSize = oldShardShape.size(); assert(shardShapeSize >= 2 && "expected at least 2D shape"); if (newOutputLayoutEnum == ttnn::Layout::RowMajor) { // Set shard shape to match convention of row major layout - auto tileType = mlir::cast(oldOutputMemref.getElementType()); + auto tileType = + mlir::cast(oldOutputLayoutAttr.getElementType()); llvm::SmallVector newShardShape(oldShardShape.begin(), oldShardShape.end()); newShardShape[shardShapeSize - 2] = @@ -358,6 +329,78 @@ class ClampOpConversionPattern : public OpConversionPattern { } }; +class UpdateCacheOpConversionPattern + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ttir::UpdateCacheOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + // The TTIR version of this op is pure. In TTNN this op is in-place. + // We need to replace uses of the result ot the TTIR op with uses + // of the cache argument. + // + // The presence of the MemWrite trait of this op should preserve + // the order of this op relative to the cache arguments uses, preserving + // program correctness. + + // This op can only work if it is the final use of the cache tensor in the + // order of execution. For now, checking that there is only one user (this + // op) of the cache tensor will suffice. + std::vector users(op.getCache().getUsers().begin(), + op.getCache().getUsers().end()); + if (users.size() != 1) { + return rewriter.notifyMatchFailure( + op, "UpdateCacheOp must have exactly one user"); + } + + rewriter.create( + op.getLoc(), adaptor.getCache(), adaptor.getInput(), + adaptor.getUpdateIndex(), adaptor.getBatchOffset()); + + rewriter.replaceOp(op, adaptor.getCache()); + return success(); + } +}; + +class FillCacheOpConversionPattern + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ttir::FillCacheOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + // The TTIR version of this op is pure. In TTNN this op is in-place. + // We need to replace uses of the result ot the TTIR op with uses + // of the cache argument. + // + // The presence of the MemWrite trait of this op should preserve + // the order of this op relative to the cache arguments uses, preserving + // program correctness. + + // This op can only work if it is the final use of the cache tensor in the + // order of execution. For now, checking that there is only one user (this + // op) of the cache tensor will suffice. + std::vector users(op.getCache().getUsers().begin(), + op.getCache().getUsers().end()); + if (users.size() != 1) { + return rewriter.notifyMatchFailure( + op, "FillCacheOp must have exactly one user"); + } + + rewriter.create(op.getLoc(), adaptor.getCache(), + adaptor.getInput(), + adaptor.getBatchOffset()); + + rewriter.replaceOp(op, adaptor.getCache()); + return success(); + } +}; + template class ElementwiseUnaryWithFloatParameterOpConversionPattern @@ -525,20 +568,17 @@ class ConstantOpConversionPattern } if (valueAttr.isSplat()) { - Value device = getOrInsertDevice(rewriter, op); + Value device = ::ttnn::utils::getOrInsertDevice(rewriter, op); float fillValue = valueAttr.getElementType().isInteger() ? getIntegerValue(valueAttr) : valueAttr.getSplatValue().convertToFloat(); - if (fillValue == 0) { - rewriter.replaceOpWithNewOp( - op, this->getTypeConverter()->convertType(op.getType()), device); - } else { - ::mlir::FloatAttr fillValueAttr = rewriter.getF32FloatAttr(fillValue); - rewriter.replaceOpWithNewOp( - op, this->getTypeConverter()->convertType(op.getType()), device, - fillValueAttr); - } + + ::mlir::FloatAttr fillValueAttr = rewriter.getF32FloatAttr(fillValue); + rewriter.replaceOpWithNewOp( + op, this->getTypeConverter()->convertType(op.getType()), device, + fillValueAttr); + } else { return rewriter.notifyMatchFailure( op, "TTNN doesn't currently support tensor creation from multiple " @@ -579,7 +619,19 @@ class ConstantOpConversionPattern } }; -} // namespace +class LinearOpConversionPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ttir::LinearOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, this->getTypeConverter()->convertType(op.getType()), adaptor.getA(), + adaptor.getB(), adaptor.getBias(), adaptor.getOutput()); + return success(); + } +}; // ANCHOR: adding_an_op_matmul_op_rewriter class MatmulOpConversionPattern : public OpConversionPattern { @@ -627,7 +679,7 @@ class Conv2dOpConversionPattern : public OpConversionPattern { matchAndRewrite(ttir::Conv2dOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto device = getOrInsertDevice(rewriter, op); + auto device = ::ttnn::utils::getOrInsertDevice(rewriter, op); auto kernel_ty = mlir::cast(adaptor.getWeight().getType()); llvm::ArrayRef kernel_shape = kernel_ty.getShape(); @@ -725,7 +777,7 @@ class MaxPool2dOpConversionPattern "TTNN max_pool2d does not support padding top/bottom/left/right " "separately"); - auto device = getOrInsertDevice(rewriter, op); + auto device = mlir::tt::ttnn::utils::getOrInsertDevice(rewriter, op); auto input_ty = mlir::cast(adaptor.getInput().getType()); llvm::ArrayRef input_shape = input_ty.getShape(); @@ -792,14 +844,8 @@ class TypecastOpConversionPattern ttnn::TTNNLayoutAttr outputLayoutAttr = mlir::cast(result.getType().getEncoding()); - mlir::MemRefType outputMemref = outputLayoutAttr.getMemref(); + DataType outputDataType = outputLayoutAttr.getDataType(); - DataType outputDataType = ttnn::utils::getDataTypeFromMemRef(outputMemref); - - if (op->getUsers().empty()) { - return rewriter.notifyMatchFailure( - op, "ttir.typecast op should have at least one use."); - } rewriter.replaceOpWithNewOp( op, this->getTypeConverter()->convertType(op.getType(0)), input, outputDataType); @@ -807,46 +853,6 @@ class TypecastOpConversionPattern } }; -class BroadcastOpConversionPattern - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - -public: - LogicalResult - matchAndRewrite(ttir::BroadcastOp srcOp, ttir::BroadcastOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - // Fold this operation into all consumer ops. It will only work with TTNN - // ops that support implicit broadcasting. We expect each Op's verify - // function to assert their arguments to verify that they can broadcast. - - if (srcOp->getUsers().empty()) { - // This broadcast chain has already been replaced. - rewriter.eraseOp(srcOp); - return success(); - } - - mlir::Value input = srcOp.getOperand(0); - - mlir::Operation *nextOp = srcOp; - while (isa(*nextOp->getUsers().begin())) { - assert(nextOp->hasOneUse() && - "Broadcast with multiple uses are not supported"); - nextOp = *nextOp->getUsers().begin(); - if (nextOp->getUsers().empty()) { - // This broadcast chain has already been replaced. - rewriter.eraseOp(srcOp); - return success(); - } - } - - rewriter.replaceAllOpUsesWith(nextOp, input); - rewriter.eraseOp(srcOp); - - return success(); - } -}; - class SubtractOpConversionPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -871,7 +877,7 @@ class SubtractOpConversionPattern // addOp(lhs, negOp(rhs)) } else { - Value device = getOrInsertDevice(rewriter, srcOp); + Value device = ::ttnn::utils::getOrInsertDevice(rewriter, srcOp); tensor::EmptyOp negEmptyOp = rewriter.create( srcOp.getLoc(), this->getTypeConverter()->convertType(rhsType), device); @@ -895,18 +901,68 @@ class AllGatherOpConversionPattern LogicalResult matchAndRewrite(ttir::AllGatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - RankedTensorType type = - mlir::cast(adaptor.getInput().getType()); - Value device = getOrInsertDevice(rewriter, op); - tensor::EmptyOp emptyOp = rewriter.create( - op.getLoc(), this->getTypeConverter()->convertType(type), device); - rewriter.replaceOpWithNewOp( - op, this->getTypeConverter()->convertType(op.getType()), emptyOp, - adaptor.getDim()); + op, this->getTypeConverter()->convertType(op.getType()), + adaptor.getInput(), adaptor.getDim()); + return success(); + } +}; + +class ArangeOpConversionPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ttir::ArangeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + RankedTensorType outputType = + mlir::cast(op.getResult().getType()); + assert(static_cast(adaptor.getArangeDimension()) == + outputType.getRank() - 1 && + "Arange dimension must be the final dimension of the output tensor " + "to convert to ttnn.arange"); + + // Get ttnn::TTNNLayoutAttr of the result type + // + ttnn::TTNNLayoutAttr layoutAttr = + mlir::cast(outputType.getEncoding()); + + DataTypeAttr dtypeAttr = rewriter.getAttr( + elementTypeToDataType(outputType.getElementType())); + Value device = mlir::tt::ttnn::utils::getOrInsertDevice(rewriter, op); + + ttnn::MemoryConfigAttr memConfigAttr = + rewriter.getAttr( + rewriter.getAttr(layoutAttr.getBufferType()), + rewriter.getAttr( + rewriter.getAttr(layoutAttr.getShardShape())), + layoutAttr.getMemLayout()); + + rewriter.replaceOpWithNewOp( + op, outputType, adaptor.getStart(), adaptor.getEnd(), adaptor.getStep(), + dtypeAttr, device, memConfigAttr); + + return success(); + } +}; + +class ScatterOpConversionPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ttir::ScatterOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // The ttnn interface has the inverse inputs of the TTIR dialect op (which + // matches torch ops). + rewriter.replaceOpWithNewOp( + op, adaptor.getUpdate(), adaptor.getInput(), adaptor.getOutput()); + return success(); } }; +} // namespace namespace mlir::tt { @@ -953,11 +1009,12 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns, ElementwiseOpConversionPattern, ElementwiseOpConversionPattern, ElementwiseOpConversionPattern, - ElementwiseUnaryWithFloatParameterOpConversionPattern, + ElementwiseOpConversionPattern, + ElementwiseOpConversionPattern, ReductionOpConversionPattern, ReductionOpConversionPattern, ReductionOpConversionPattern, - BroadcastOpConversionPattern, + ElementwiseUnaryWithFloatParameterOpConversionPattern, EmbeddingOpConversionPattern, SoftmaxOpConversionPattern, TransposeOpConversionPattern, @@ -969,11 +1026,16 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns, SqueezeOpConversionPattern, UnsqueezeOpConversionPattern, ConstantOpConversionPattern, + LinearOpConversionPattern, MatmulOpConversionPattern, Conv2dOpConversionPattern, MaxPool2dOpConversionPattern, SubtractOpConversionPattern, - AllGatherOpConversionPattern + AllGatherOpConversionPattern, + ArangeOpConversionPattern, + UpdateCacheOpConversionPattern, + FillCacheOpConversionPattern, + ScatterOpConversionPattern >(typeConverter, ctx); // ANCHOR_END: op_rewriter_pattern_set // clang-format on diff --git a/lib/Conversion/TTKernelToEmitC/CMakeLists.txt b/lib/Conversion/TTKernelToEmitC/CMakeLists.txt index 4ed57a5d41..429a694f31 100644 --- a/lib/Conversion/TTKernelToEmitC/CMakeLists.txt +++ b/lib/Conversion/TTKernelToEmitC/CMakeLists.txt @@ -1,4 +1,4 @@ -add_mlir_library(TTMLIRTTKernelToEmitC +add_mlir_conversion_library(TTMLIRTTKernelToEmitC TTKernelToEmitC.cpp ADDITIONAL_HEADER_DIRS diff --git a/lib/Conversion/TTKernelToEmitC/TTKernelToEmitC.cpp b/lib/Conversion/TTKernelToEmitC/TTKernelToEmitC.cpp index 312377eb6e..c265e89283 100644 --- a/lib/Conversion/TTKernelToEmitC/TTKernelToEmitC.cpp +++ b/lib/Conversion/TTKernelToEmitC/TTKernelToEmitC.cpp @@ -406,8 +406,10 @@ class ConvertTTKernelToEmitCPass TTMetalToEmitCOpaqueRewriter, TTMetalToEmitCOpaqueRewriter, TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, TTMetalToEmitCOpaqueRewriter, TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, TTMetalToEmitCOpaqueRewriter, TTMetalToEmitCOpaqueRewriter, TTMetalToEmitCOpaqueRewriter, @@ -419,6 +421,12 @@ class ConvertTTKernelToEmitCPass TTMetalToEmitCOpaqueRewriter, TTMetalToEmitCOpaqueRewriter, TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter< + ttkernel::NocAsyncWriteMulticastOnePacketOp>, + TTMetalToEmitCOpaqueRewriter, + TTMetalToEmitCOpaqueRewriter< + ttkernel::NocAsyncWriteMulticastLoopbackSrcOp>, TTMetalToEmitCOpaqueRewriter, TTMetalToEmitCOpaqueRewriter, TTMetalToEmitCOpaqueRewriter, @@ -473,6 +481,8 @@ class ThreadConfigHelper { builder->create(loc, "compute_kernel_api/eltwise_binary.h", /*isStandard=*/false); + builder->create(loc, "compute_kernel_api.h", // max ops + /*isStandard=*/false); builder->create(loc, "compute_kernel_api/tile_move_copy.h", /*isStandard=*/false); diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp index 9b7cf7fe84..aedad4d290 100644 --- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp +++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp @@ -9,6 +9,7 @@ #include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/FuncConversions.h" @@ -86,8 +87,6 @@ emitc::OpaqueAttr convertTensorMemoryLayout(Builder &builder, case ttnn::TensorMemoryLayout::WidthSharded: return builder.getType( "ttnn::TensorMemoryLayout::WIDTH_SHARDED"); - case ttnn::TensorMemoryLayout::None: - llvm_unreachable("Unsupported ttnn::TensorMemoryLayout"); } } @@ -618,6 +617,58 @@ class DeallocateOpConversionPattern } }; +// arith::ConstantOp conversion pattern +// +class ArithConstantOpConversionPattern + : public OpConversionPattern { + +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp constOp, arith::ConstantOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Type newTy = this->getTypeConverter()->convertType(constOp.getType()); + if (!newTy) { + return rewriter.notifyMatchFailure(constOp, "type conversion failed"); + } + + rewriter.replaceOpWithNewOp(constOp, newTy, + adaptor.getValue()); + return success(); + } +}; + +// Module Op conversion pattern +// +// This conversion pattern removes attributes from the ModuleOp. Previously, +// ttmlir-translate would complain when translating to C++ if there were any +// attributes from "unregistered" dialects. +// +class ModuleOpConversionPattern + : public TTNNToEmitCBaseOpConversionPattern { + +public: + ModuleOpConversionPattern(const TypeConverter &typeConverter, + MLIRContext *context, PatternBenefit benefit = 1) + : TTNNToEmitCBaseOpConversionPattern(typeConverter, + context, benefit) {} + + LogicalResult + matchAndRewrite(mlir::ModuleOp srcOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + rewriter.modifyOpInPlace(srcOp, [&]() { + for (const NamedAttribute &attr : srcOp->getAttrs()) { + srcOp->removeAttr(attr.getName()); + } + }); + + return success(); + } +}; + } // namespace namespace mlir::tt { @@ -639,8 +690,8 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx, // Tensor ops // patterns - .add>( - typeConverter, ctx); + .add, + DefaultOpConversionPattern>(typeConverter, ctx); // Eltwise unary ops // @@ -665,6 +716,8 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx, DefaultOpConversionPattern, DefaultOpConversionPattern, DefaultOpConversionPattern, + DefaultOpConversionPattern, + DefaultOpConversionPattern, DefaultOpConversionPattern>(typeConverter, ctx); // Eltwise binary ops @@ -684,6 +737,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx, DefaultOpConversionPattern, DefaultOpConversionPattern, DefaultOpConversionPattern, + DefaultOpConversionPattern, DefaultOpConversionPattern>(typeConverter, ctx); @@ -696,7 +750,8 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx, // Matmul ops // - patterns.add>(typeConverter, ctx); + patterns.add, + DefaultOpConversionPattern>(typeConverter, ctx); // Reduction ops // @@ -720,6 +775,21 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx, // patterns.add>(typeConverter, ctx); + + // Module op + // + patterns.add(typeConverter, ctx); + + // KV Cache ops + // + patterns.add>(typeConverter, + ctx); + patterns.add>(typeConverter, + ctx); + + // Arith ops + // + patterns.add(typeConverter, ctx); } } // namespace mlir::tt diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitCPass.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitCPass.cpp index 71a7c52b60..bd0c9044fc 100644 --- a/lib/Conversion/TTNNToEmitC/TTNNToEmitCPass.cpp +++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitCPass.cpp @@ -4,6 +4,11 @@ #include "ttmlir/Conversion/TTNNToEmitC/TTNNToEmitC.h" +#include "ttmlir/Dialect/TTNN/IR/TTNN.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h" + #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/FuncConversions.h" @@ -12,11 +17,6 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" -#include "ttmlir/Dialect/TTNN/IR/TTNN.h" -#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" -#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" -#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h" - using namespace mlir; using namespace mlir::tt; @@ -48,14 +48,20 @@ struct ConvertTTNNToEmitCPass void runOnOperation() override { mlir::ConversionTarget target(getContext()); + // EmitC is legal, TTNN is illegal + // target.addLegalDialect(); target.addIllegalDialect(); - target.addLegalOp(); + + // mlir::ModuleOp is legal only if no attributes are present on it + // + target.addDynamicallyLegalOp( + [&](mlir::ModuleOp op) { return op->getAttrs().empty(); }); // Add header imports to front of module // { - auto module = getOperation(); + mlir::ModuleOp module = getOperation(); OpBuilder builder(module); if (module.getBodyRegion().empty()) { @@ -107,7 +113,7 @@ struct ConvertTTNNToEmitCPass return; } } - }; + } }; } // namespace diff --git a/lib/Conversion/TosaToTTIR/CMakeLists.txt b/lib/Conversion/TosaToTTIR/CMakeLists.txt index 41baf75c67..56000eb652 100644 --- a/lib/Conversion/TosaToTTIR/CMakeLists.txt +++ b/lib/Conversion/TosaToTTIR/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_library(TTMLIRTosaToTTIR - TosaToTTIR.cpp + TosaToTTIRPass.cpp + TosaToTTIRPatterns.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/ttmlir/Conversion/TosaToTTIR diff --git a/lib/Conversion/TosaToTTIR/TosaToTTIR.cpp b/lib/Conversion/TosaToTTIR/TosaToTTIR.cpp deleted file mode 100644 index 6c6a7faf56..0000000000 --- a/lib/Conversion/TosaToTTIR/TosaToTTIR.cpp +++ /dev/null @@ -1,122 +0,0 @@ -// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC -// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttmlir/Conversion/TosaToTTIR/TosaToTTIR.h" -#include "ttmlir/Dialect/TT/IR/TT.h" -#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" -#include "ttmlir/Dialect/TTIR/IR/TTIR.h" -#include "ttmlir/Dialect/TTIR/IR/TTIROps.h" - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Func/Transforms/FuncConversions.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/ValueRange.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" - -using namespace mlir; -using namespace tt; - -namespace mlir::tt::ttir { - -#define GEN_PASS_DEF_CONVERTTOSATOTTIR -#include "ttmlir/Conversion/Passes.h.inc" - -} // namespace mlir::tt::ttir - -namespace { - -template -class TosaToTTIROpConversionPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - -public: - LogicalResult - matchAndRewrite(SrcOp srcOp, Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if constexpr (std::is_same::value) { - assert(srcOp.getShift() == 0); - } - - auto outputType = mlir::cast(srcOp.getResult().getType()); - auto outputTensor = rewriter.create( - srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); - rewriter.replaceOpWithNewOp( - srcOp, TypeRange(outputTensor.getType()), adaptor.getOperands(), - ValueRange(outputTensor), - rewriter.getArrayAttr( - SmallVector(adaptor.getOperands().size() + 1, - rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); - return success(); - } -}; - -struct ConvertTosaToTTIRPass - : public ttir::impl::ConvertTosaToTTIRBase { - void runOnOperation() override { - mlir::ConversionTarget target(getContext()); - - target.addIllegalDialect(); - - target.addLegalDialect(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - - // For now keep the same type assuming tosa ops operate on builtin tensor. - TypeConverter typeConverter; - typeConverter.addConversion([](Type type) { - assert(isa(type) && - "only ranked tensor type supported"); - return type; - }); - RewritePatternSet patterns(&getContext()); - - // Add conversion patterns. - patterns - .add>( - typeConverter, &getContext()); - patterns - .add>( - typeConverter, &getContext()); - patterns.add< - TosaToTTIROpConversionPattern>( - typeConverter, &getContext()); - patterns.add< - TosaToTTIROpConversionPattern>( - typeConverter, &getContext()); - patterns.add< - TosaToTTIROpConversionPattern>( - typeConverter, &getContext()); - patterns.add>( - typeConverter, &getContext()); - - // Apply conversion. - if (failed( - applyFullConversion(getOperation(), target, std::move(patterns)))) { - signalPassFailure(); - return; - } - } -}; - -} // namespace - -namespace mlir::tt { - -std::unique_ptr> createConvertTosaToTTIRPass() { - return std::make_unique(); -} - -} // namespace mlir::tt diff --git a/lib/Conversion/TosaToTTIR/TosaToTTIRPass.cpp b/lib/Conversion/TosaToTTIR/TosaToTTIRPass.cpp new file mode 100644 index 0000000000..183d58ccaa --- /dev/null +++ b/lib/Conversion/TosaToTTIR/TosaToTTIRPass.cpp @@ -0,0 +1,74 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "ttmlir/Conversion/TosaToTTIR/TosaToTTIR.h" +#include "ttmlir/Dialect/TTIR/IR/TTIR.h" + +using namespace mlir; +using namespace mlir::tt; + +namespace mlir::tt::ttir { + +#define GEN_PASS_DEF_CONVERTTOSATOTTIR +#include "ttmlir/Conversion/Passes.h.inc" + +} // namespace mlir::tt::ttir + +namespace { + +struct ConvertTosaToTTIRPass + : public ttir::impl::ConvertTosaToTTIRBase { + void runOnOperation() override { + mlir::ConversionTarget target(getContext()); + + target.addIllegalDialect(); + + target.addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + + // For now keep the same type assuming tosa ops operate on builtin tensor. + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { + assert(isa(type) && + "only ranked tensor type supported"); + return type; + }); + RewritePatternSet patterns(&getContext()); + + // Add conversion patterns. + populateTosaToTTIRPatterns(&getContext(), patterns, typeConverter); + + // Apply conversion. + if (failed( + applyFullConversion(getOperation(), target, std::move(patterns)))) { + signalPassFailure(); + return; + } + } +}; + +} // namespace + +namespace mlir::tt { + +std::unique_ptr> createConvertTosaToTTIRPass() { + return std::make_unique(); +} + +} // namespace mlir::tt diff --git a/lib/Conversion/TosaToTTIR/TosaToTTIRPatterns.cpp b/lib/Conversion/TosaToTTIR/TosaToTTIRPatterns.cpp new file mode 100644 index 0000000000..8d4e4caafd --- /dev/null +++ b/lib/Conversion/TosaToTTIR/TosaToTTIRPatterns.cpp @@ -0,0 +1,327 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "ttmlir/Conversion/TosaToTTIR/TosaToTTIR.h" +#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" +#include "ttmlir/Dialect/TTIR/IR/TTIROps.h" + +using namespace mlir; +using namespace mlir::tt; + +namespace { + +// TODO(sdjukic): extract this pattern into separate file and use it for both +// TOSA and StableHLO + +template +class TosaToTTIRDefaultDPSOpConversionPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(SrcOp srcOp, Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + LogicalResult legalityResult = + checkConversionLegality(srcOp, adaptor, rewriter); + if (!legalityResult.succeeded()) { + return legalityResult; + } + + RankedTensorType outputType = + mlir::cast(srcOp.getResult().getType()); + tensor::EmptyOp outputTensor = rewriter.create( + srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); + rewriter.replaceOpWithNewOp( + srcOp, TypeRange(outputTensor.getType()), adaptor.getOperands(), + ValueRange(outputTensor), + rewriter.getArrayAttr( + SmallVector(adaptor.getOperands().size() + 1, + rewriter.getAttr( + OperandConstraint::AnyDeviceTile)))); + return success(); + } + +private: + virtual LogicalResult + checkConversionLegality(SrcOp srcOp, Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + return success(); + } +}; + +class TosaToTTIRMultiplyOpConversionPattern + : public TosaToTTIRDefaultDPSOpConversionPattern< + tosa::MulOp, mlir::tt::ttir::MultiplyOp> { + using TosaToTTIRDefaultDPSOpConversionPattern< + tosa::MulOp, + mlir::tt::ttir::MultiplyOp>::TosaToTTIRDefaultDPSOpConversionPattern; + +private: + LogicalResult + checkConversionLegality(tosa::MulOp srcOp, tosa::MulOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (srcOp.getShift() != 0) { + return rewriter.notifyMatchFailure( + srcOp, "TTIR MultiplyOp doesn't support shifted multiply."); + } + return success(); + } +}; + +class TosaToTTIRClampOpConversionPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(tosa::ClampOp srcOp, tosa::ClampOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + RankedTensorType outputType = + mlir::cast(srcOp.getResult().getType()); + + tensor::EmptyOp outputTensor = rewriter.create( + srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); + + rewriter.replaceOpWithNewOp( + srcOp, TypeRange(outputTensor.getType()), adaptor.getOperands()[0], + outputTensor, adaptor.getMinFp(), adaptor.getMaxFp(), + rewriter.getArrayAttr( + SmallVector(adaptor.getOperands().size() + 1, + rewriter.getAttr( + OperandConstraint::AnyDeviceTile)))); + return success(); + } +}; + +class TosaToTTIRMatmulOpConversionPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using Adaptor = tosa::MatMulOp::Adaptor; + +public: + LogicalResult + matchAndRewrite(tosa::MatMulOp srcOp, Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + LogicalResult legalityResult = + checkConversionLegality(srcOp, adaptor, rewriter); + if (!legalityResult.succeeded()) { + return legalityResult; + } + RankedTensorType outputType = + mlir::cast(srcOp.getResult().getType()); + tensor::EmptyOp outputTensor = rewriter.create( + srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); + ValueRange operands = adaptor.getOperands(); + + rewriter.replaceOpWithNewOp( + srcOp, TypeRange(outputTensor.getType()), operands[0], operands[1], + outputTensor, + + rewriter.getArrayAttr( + SmallVector(adaptor.getOperands().size() + 1, + rewriter.getAttr( + OperandConstraint::AnyDeviceTile)))); + return success(); + } + +private: + LogicalResult + checkConversionLegality(tosa::MatMulOp srcOp, Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + if (srcOp.getQuantizationInfo().has_value()) { + return rewriter.notifyMatchFailure( + srcOp, "TTIR MatmulOp currently doesn't support quantization."); + } + return success(); + } +}; + +template +class TosaToTTIRReduceOpConversionPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(SrcOp srcOp, Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + RankedTensorType outputType = + mlir::cast(srcOp.getResult().getType()); + tensor::EmptyOp outputTensor = rewriter.create( + srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); + + rewriter.replaceOpWithNewOp( + srcOp, outputTensor.getType(), adaptor.getInput(), outputTensor, + true /*keepdim*/, + rewriter.getArrayAttr(SmallVector(1, adaptor.getAxisAttr())), + rewriter.getArrayAttr( + SmallVector(adaptor.getOperands().size() + 1, + rewriter.getAttr( + OperandConstraint::AnyDeviceTile)))); + return success(); + } +}; + +class TosaToTTIRMaxPool2DOpConversionPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using Adaptor = tosa::MaxPool2dOp::Adaptor; + +public: + LogicalResult + matchAndRewrite(tosa::MaxPool2dOp srcOp, Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto outputType = mlir::cast(srcOp.getResult().getType()); + auto outputTensor = rewriter.create( + srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); + + auto dims = srcOp.getKernelAttr(); + auto strides = srcOp.getStrideAttr(); + auto pad = srcOp.getPadAttr(); + rewriter.replaceOpWithNewOp( + srcOp, TypeRange(outputTensor.getType()), adaptor.getInput(), + outputTensor, dims[0], dims[1], strides[0], strides[1], 1, 1, false, + pad[2], pad[3], pad[0], pad[1], + rewriter.getArrayAttr( + SmallVector(adaptor.getOperands().size() + 1, + rewriter.getAttr( + OperandConstraint::AnyDeviceTile)))); + return success(); + } +}; + +void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx, + RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add>( + typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add>( + typeConverter, ctx); + patterns.add>( + typeConverter, ctx); + patterns.add>( + typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add>( + typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add>( + typeConverter, ctx); +} + +void addElementwiseBinaryOpsConversionPatterns(MLIRContext *ctx, + RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add>( + typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add>(typeConverter, ctx); +} + +void addElementwiseTernaryOpsConversionPatterns(MLIRContext *ctx, + RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add>(typeConverter, ctx); +} + +void addLogicalOpsConversionPatterns(MLIRContext *ctx, + RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add>(typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add>(typeConverter, ctx); +} + +void addCompareOpsConversionPatterns(MLIRContext *ctx, + RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add>(typeConverter, ctx); + patterns.add>(typeConverter, + ctx); + patterns.add>(typeConverter, ctx); +} + +void addMatmulOpsConversionPatterns(MLIRContext *ctx, + RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add(typeConverter, ctx); +} + +void addReductionOpsConversionPatterns(MLIRContext *ctx, + RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add>( + typeConverter, ctx); + patterns.add>( + typeConverter, ctx); +} + +void addPoolingOpsConversionPatterns(MLIRContext *ctx, + RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add(typeConverter, ctx); +} +} // namespace + +namespace mlir::tt { + +void populateTosaToTTIRPatterns(MLIRContext *ctx, RewritePatternSet &patterns, + TypeConverter &typeConverter) { + addElementwiseUnaryOpsConversionPatterns(ctx, patterns, typeConverter); + addElementwiseBinaryOpsConversionPatterns(ctx, patterns, typeConverter); + addElementwiseTernaryOpsConversionPatterns(ctx, patterns, typeConverter); + addLogicalOpsConversionPatterns(ctx, patterns, typeConverter); + addCompareOpsConversionPatterns(ctx, patterns, typeConverter); + addMatmulOpsConversionPatterns(ctx, patterns, typeConverter); + addReductionOpsConversionPatterns(ctx, patterns, typeConverter); + addPoolingOpsConversionPatterns(ctx, patterns, typeConverter); + + patterns.add(typeConverter, ctx); +} + +} // namespace mlir::tt diff --git a/lib/Dialect/TT/IR/TTDialect.cpp b/lib/Dialect/TT/IR/TTDialect.cpp index 6f629d6977..1ac8a22239 100644 --- a/lib/Dialect/TT/IR/TTDialect.cpp +++ b/lib/Dialect/TT/IR/TTDialect.cpp @@ -13,13 +13,13 @@ using namespace mlir; using namespace mlir::tt; -// This is needed to hoist tt.layout attributes as named attributes declared at -// the module level. +// This is needed to hoist tt.metal_layout attributes as named attributes +// declared at the module level. struct TTOpAsmDialectInterface : public OpAsmDialectInterface { using OpAsmDialectInterface::OpAsmDialectInterface; AliasResult getAlias(Attribute attr, raw_ostream &os) const override { - if (llvm::isa(attr)) { + if (llvm::isa(attr)) { os << "layout"; return AliasResult::OverridableAlias; } diff --git a/lib/Dialect/TT/IR/TTOpsTypes.cpp b/lib/Dialect/TT/IR/TTOpsTypes.cpp index bbdd4e2590..12166e4433 100644 --- a/lib/Dialect/TT/IR/TTOpsTypes.cpp +++ b/lib/Dialect/TT/IR/TTOpsTypes.cpp @@ -466,7 +466,7 @@ calculateLogicalShardShape(mlir::ArrayRef tensorShape, return shardShape; } -LayoutAttr LayoutAttr::get( +MetalLayoutAttr MetalLayoutAttr::get( ::mlir::MLIRContext *context, ArrayRef tensorShape, Type elementType, MemorySpace memorySpace, GridAttr grid, ArrayRef> collapseIntervals, @@ -483,7 +483,7 @@ LayoutAttr LayoutAttr::get( return get(context, linear, oobVal, grid, memref, memLayout); } -LayoutAttr LayoutAttr::get( +MetalLayoutAttr MetalLayoutAttr::get( ::mlir::MLIRContext *context, RankedTensorType ty, MemorySpace memorySpace, GridAttr grid, ArrayRef> collapseIntervals, @@ -493,9 +493,11 @@ LayoutAttr LayoutAttr::get( collapseIntervals, oobVal, memLayout); } -LayoutAttr LayoutAttr::get(::mlir::MLIRContext *context, RankedTensorType ty, - MemorySpace memorySpace, GridAttr grid, - Type elementType, TensorMemoryLayout memLayout) { +MetalLayoutAttr MetalLayoutAttr::get(::mlir::MLIRContext *context, + RankedTensorType ty, + MemorySpace memorySpace, GridAttr grid, + Type elementType, + TensorMemoryLayout memLayout) { assert(ty); assert(grid); return get(context, ty.getShape(), elementType, memorySpace, grid, {{0, -1}}, @@ -506,7 +508,7 @@ LayoutAttr LayoutAttr::get(::mlir::MLIRContext *context, RankedTensorType ty, // compute the physical shape of the tensor, i.e the shape of the tensor // after the dimensions have been collapsed onto a grid. llvm::SmallVector -LayoutAttr::getPhysicalShape(ArrayRef logicalShape) const { +MetalLayoutAttr::getPhysicalShape(ArrayRef logicalShape) const { llvm::SmallVector physicalShape(getGrid().getShape().size()); SmallVector logicalShapeExprs( llvm::map_range(logicalShape, [context = getContext()](std::int64_t e) { @@ -525,7 +527,7 @@ LayoutAttr::getPhysicalShape(ArrayRef logicalShape) const { } llvm::SmallVector -LayoutAttr::getStride(ArrayRef logicalShape) const { +MetalLayoutAttr::getStride(ArrayRef logicalShape) const { llvm::SmallVector stride(logicalShape.size()); @@ -574,7 +576,7 @@ LayoutAttr::getStride(ArrayRef logicalShape) const { } llvm::SmallVector -LayoutAttr::getShardShape(bool convertTileToScalar) const { +MetalLayoutAttr::getShardShape(bool convertTileToScalar) const { SmallVector shardShape(getMemref().getShape()); auto elementType = getElementType(); if (mlir::isa(elementType) && convertTileToScalar) { @@ -583,11 +585,11 @@ LayoutAttr::getShardShape(bool convertTileToScalar) const { return shardShape; } -mlir::Type LayoutAttr::getElementType() const { +mlir::Type MetalLayoutAttr::getElementType() const { return getMemref().getElementType(); } -mlir::Type LayoutAttr::getScalarElementType() const { +mlir::Type MetalLayoutAttr::getScalarElementType() const { auto elementType = getElementType(); if (mlir::isa(elementType)) { return mlir::cast(elementType).getElementType(); @@ -595,33 +597,33 @@ mlir::Type LayoutAttr::getScalarElementType() const { return elementType; } -bool LayoutAttr::hasShardedTensorMemoryLayout() const { +bool MetalLayoutAttr::hasShardedTensorMemoryLayout() const { return (getMemLayout() == TensorMemoryLayout::HeightSharded or getMemLayout() == TensorMemoryLayout::WidthSharded or getMemLayout() == TensorMemoryLayout::BlockSharded); } -bool LayoutAttr::hasInterleavedTensorMemoryLayout() const { +bool MetalLayoutAttr::hasInterleavedTensorMemoryLayout() const { return (getMemLayout() == TensorMemoryLayout::Interleaved); } -bool LayoutAttr::hasShardedL1TensorMemoryLayout() const { +bool MetalLayoutAttr::hasShardedL1TensorMemoryLayout() const { return ::mlir::tt::isL1MemorySpace(getMemorySpace()) and (getMemLayout() == TensorMemoryLayout::HeightSharded or getMemLayout() == TensorMemoryLayout::WidthSharded or getMemLayout() == TensorMemoryLayout::BlockSharded); } -bool LayoutAttr::hasInterleavedL1TensorMemoryLayout() const { +bool MetalLayoutAttr::hasInterleavedL1TensorMemoryLayout() const { return ::mlir::tt::isL1MemorySpace(getMemorySpace()) and (getMemLayout() == TensorMemoryLayout::Interleaved); } -bool LayoutAttr::isTiled() const { +bool MetalLayoutAttr::isTiled() const { return ::mlir::isa<::mlir::tt::TileType>(getElementType()); } -uint64_t LayoutAttr::getElementSizeBytes() const { +uint64_t MetalLayoutAttr::getElementSizeBytes() const { mlir::Type elementType = getElementType(); if (mlir::isa(elementType)) { auto tileType = mlir::cast(elementType); @@ -630,7 +632,7 @@ uint64_t LayoutAttr::getElementSizeBytes() const { return elementType.getIntOrFloatBitWidth() / 8; } -uint64_t LayoutAttr::getMemrefSizeBytes() const { +uint64_t MetalLayoutAttr::getMemrefSizeBytes() const { MemRefType ty = getMemref(); auto shape = ty.getShape(); uint64_t size = getElementSizeBytes(); @@ -638,57 +640,60 @@ uint64_t LayoutAttr::getMemrefSizeBytes() const { std::multiplies()); } -LayoutAttr LayoutAttr::withGrid( +MetalLayoutAttr MetalLayoutAttr::withGrid( ::mlir::MLIRContext *context, ArrayRef tensorShape, GridAttr grid, ArrayRef> collapseIntervals) { return get(context, tensorShape, getElementType(), getMemorySpace(), grid, collapseIntervals, getOobVal(), getMemLayout()); } -LayoutAttr LayoutAttr::withGrid( +MetalLayoutAttr MetalLayoutAttr::withGrid( ::mlir::MLIRContext *context, RankedTensorType ty, GridAttr grid, ArrayRef> collapseIntervals) { assert(ty); - return LayoutAttr::withGrid(context, ty.getShape(), grid, collapseIntervals); + return MetalLayoutAttr::withGrid(context, ty.getShape(), grid, + collapseIntervals); } -LayoutAttr LayoutAttr::withElementType(::mlir::MLIRContext *context, - Type elementType) { - return LayoutAttr::get( +MetalLayoutAttr MetalLayoutAttr::withElementType(::mlir::MLIRContext *context, + Type elementType) { + return MetalLayoutAttr::get( context, getLinear(), getOobVal(), getGrid(), buildMemRef(context, getShardShape(), elementType, getMemorySpace()), getMemLayout()); } -LayoutAttr LayoutAttr::withMemorySpace(::mlir::MLIRContext *context, - MemorySpace memorySpace) { - return LayoutAttr::get( +MetalLayoutAttr MetalLayoutAttr::withMemorySpace(::mlir::MLIRContext *context, + MemorySpace memorySpace) { + return MetalLayoutAttr::get( context, getLinear(), getOobVal(), getGrid(), buildMemRef(context, getShardShape(), getElementType(), memorySpace), getMemLayout()); } -LayoutAttr LayoutAttr::withMemoryLayout(::mlir::MLIRContext *context, - TensorMemoryLayout memLayout) { - return LayoutAttr::get( +MetalLayoutAttr +MetalLayoutAttr::withMemoryLayout(::mlir::MLIRContext *context, + TensorMemoryLayout memLayout) { + return MetalLayoutAttr::get( context, getLinear(), getOobVal(), getGrid(), buildMemRef( context, getShardShape(), getElementType(), getMemorySpace()), memLayout); } -LayoutAttr LayoutAttr::withShardShape(::mlir::MLIRContext *context, - llvm::SmallVector shardShape) { - return LayoutAttr::get( +MetalLayoutAttr +MetalLayoutAttr::withShardShape(::mlir::MLIRContext *context, + llvm::SmallVector shardShape) { + return MetalLayoutAttr::get( context, getLinear(), getOobVal(), getGrid(), buildMemRef( context, shardShape, getElementType(), getMemorySpace()), getMemLayout()); } -MemorySpace LayoutAttr::getMemorySpace() const { +MemorySpace MetalLayoutAttr::getMemorySpace() const { return mlir::cast(getMemref().getMemorySpace()) .getValue(); } @@ -696,7 +701,7 @@ MemorySpace LayoutAttr::getMemorySpace() const { // Returns shape of the tensor after tilization is applied to the two inner most // dimensions. llvm::SmallVector -LayoutAttr::getTiledShape(llvm::ArrayRef tensorShape) const { +MetalLayoutAttr::getTiledShape(llvm::ArrayRef tensorShape) const { assert(isTiled() && "Expected a tiled layout"); mlir::AffineMap linear = getLinear(); @@ -716,7 +721,7 @@ LayoutAttr::getTiledShape(llvm::ArrayRef tensorShape) const { return ttmlir::utils::evalShape(tiled, tensorShape); } -mlir::AffineMap LayoutAttr::getIdentityTileLinearMap() const { +mlir::AffineMap MetalLayoutAttr::getIdentityTileLinearMap() const { assert(isTiled() && "Expected a tiled layout"); return mlir::AffineMap::getMultiDimIdentityMap(getLinear().getNumResults(), @@ -735,7 +740,7 @@ mlir::AffineMap LayoutAttr::getIdentityTileLinearMap() const { // (d0, d1)[2, 3] -> // (0, d0 floordiv 2, d1 floordiv 3, (d0 mod 2) * 3 + d1 mod 3) // -mlir::AffineMap LayoutAttr::replaceMemoryMapSymbolsWithShardShape( +mlir::AffineMap MetalLayoutAttr::replaceMemoryMapSymbolsWithShardShape( AffineMap physicalMemoryMap) const { mlir::SmallVector shardShape = getShardShape(false /*convertTileToScalar*/); @@ -763,8 +768,8 @@ mlir::AffineMap LayoutAttr::replaceMemoryMapSymbolsWithShardShape( // grid. Then it composes the logical grid projection with physical memory // mapping. mlir::AffineMap -LayoutAttr::projectOnto(mlir::AffineMap linearMap, - mlir::AffineMap physicalMemoryMap) const { +MetalLayoutAttr::projectOnto(mlir::AffineMap linearMap, + mlir::AffineMap physicalMemoryMap) const { assert(getGrid().getShape().size() == physicalMemoryMap.getNumDims() && "Layout and device grids must have same number of dimensions"); assert(getLinear().getNumResults() == physicalMemoryMap.getNumDims() && @@ -1013,7 +1018,7 @@ DeviceAttr DeviceAttr::get(::mlir::MLIRContext *context, // Sample the last index in the tensor to get the last addressable element of // the tensor to determine its footprint in memory. uint64_t DeviceAttr::getLayoutSizeBytes(ArrayRef tensorScalarShape, - LayoutAttr layout, + MetalLayoutAttr layout, MemorySpace memorySpace) const { SmallVector shape = layout.isTiled() ? layout.getTiledShape(tensorScalarShape) @@ -1035,9 +1040,9 @@ uint64_t DeviceAttr::getLayoutSizeBytes(ArrayRef tensorScalarShape, uint64_t DeviceAttr::getTensorSizeBytes(RankedTensorType tensorType, MemorySpace memorySpace) const { assert(tensorType.getEncoding()); - return getLayoutSizeBytes(tensorType.getShape(), - mlir::cast(tensorType.getEncoding()), - memorySpace); + return getLayoutSizeBytes( + tensorType.getShape(), + mlir::cast(tensorType.getEncoding()), memorySpace); } ::mlir::LogicalResult diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index ec415e090b..44af2f2c4b 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -45,6 +45,37 @@ ::mlir::LogicalResult mlir::tt::ttir::ClampOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// ArangeOp +//===----------------------------------------------------------------------===// + +::mlir::LogicalResult mlir::tt::ttir::ArangeOp::verify() { + int64_t start = getStart(); + int64_t end = getEnd(); + int64_t step = getStep(); + + if (step == 0) { + return emitOpError("Step value cannot be zero"); + } + + int64_t numValues = (end - start) / step; + + if (numValues <= 0) { + return emitOpError() << "Invalid range: start=" << start << ", end=" << end + << ", step=" << step; + } + + if (numValues != getType().getDimSize(getArangeDimension())) { + return emitOpError() << "Output tensor shape must be " << numValues + << " at dim " << getArangeDimension() + << " (since start=" << start << ", end=" << end + << ", step=" << step << "), but got " + << getType().getDimSize(getArangeDimension()); + } + + return success(); +} + //===----------------------------------------------------------------------===// // ConstantOp //===----------------------------------------------------------------------===// @@ -307,11 +338,6 @@ ::mlir::LogicalResult mlir::tt::ttir::ReshapeOp::verify() { return emitOpError("Shape attribute must be non-empty"); } - // Check that the shape attribute has at most 5 elements - if (shape_size > 5) { - return emitOpError("Shape attribute must have at most 5 elements"); - } - // Cardinality of the input and output tensors must be the same if (inputType.getNumElements() != outputType.getNumElements()) { return emitOpError( @@ -358,6 +384,15 @@ ::mlir::LogicalResult mlir::tt::ttir::ReshapeOp::verify() { return success(); } +// ReshapeOp folder +::mlir::OpFoldResult mlir::tt::ttir::ReshapeOp::fold(FoldAdaptor adaptor) { + + if (getType() == getOperand(0).getType()) { + return getOperand(0); + } + return nullptr; +} + //===----------------------------------------------------------------------===// // SliceOp //===----------------------------------------------------------------------===// @@ -596,6 +631,100 @@ ::mlir::LogicalResult mlir::tt::ttir::IndexOp::verify() { } // ANCHOR_END: decomposing_an_op_index_ttir_verify +//===----------------------------------------------------------------------===// +// SelectOp +//===----------------------------------------------------------------------===// + +// SelectOp verification +::mlir::LogicalResult mlir::tt::ttir::SelectOp::verify() { + ::mlir::RankedTensorType inputType = getInput().getType(); + ::mlir::RankedTensorType outputType = getOutput().getType(); + + if (inputType.getRank() != outputType.getRank()) { + return emitOpError("Input and output tensors must have the same rank."); + } + + if (inputType.getElementType() != outputType.getElementType()) { + return emitOpError("Input and output tensors must have the same element " + "type."); + } + + int32_t dim = getDim(); + int32_t origDim = dim; + if (dim < 0) { + dim += inputType.getRank(); + } + + if (dim < 0 || dim >= inputType.getRank()) { + return emitOpError() << "Invalid dimension " << origDim + << " for select op with input tensor rank " + << inputType.getRank(); + } + + int32_t dimSize = inputType.getDimSize(dim); + + int32_t stride = getStride(); + if (stride == 0) { + stride = dimSize; + } + + if (stride < 0) { + return emitOpError() << "Invalid stride " << stride << " for dimension " + << dim << ", stride must be non-negative"; + } + + if (stride > dimSize) { + return emitOpError() << "Invalid stride " << stride << " for dimension " + << dim << " with size " << dimSize + << ". stride must be less than or equal to the " + "dimension size"; + } + + int32_t begin = getBegin(); + int32_t length = getLength(); + if (begin < 0 || begin >= dimSize) { + return emitOpError() << "Invalid begin index " << begin << " for dimension " + << dim << " with size " << dimSize + << ". begin must be " + "in the range [0, dimSize)"; + } + + if (length < 1 || length > stride) { + return emitOpError() << "Invalid length " << length << " for begin index " + << begin << " and stride " << stride + << " for dimension " << dim << " with size " << dimSize + << ". stride must be greater than or equal to length"; + } + + if (begin + length > dimSize) { + return emitOpError() << "Invalid length " << length << " for begin index " + << begin << " and dimension " << dim << " with size " + << dimSize + << ". begin + length must be less than or " + "equal to the dimension size"; + } + + // Get the number of slices as the number of times the stride fits in the + // dimension size starting from the begin index. + int32_t numSlices = (dimSize - begin + stride - 1) / stride; + int32_t totalLength = 0; + for (int32_t i = 0; i < numSlices; i++) { + int32_t newBegin = begin + i * stride; + int32_t newEnd = std::min(newBegin + length, dimSize); + totalLength += newEnd - newBegin; + } + + if (totalLength != outputType.getDimSize(dim)) { + return emitOpError() << "Sum of all slices must be equal to the output " + "dimension size for the given dimension. Expected " + "output dimension size: " + << outputType.getDimSize(dim) << ", but got " + << totalLength; + } + + return success(); +} + //===----------------------------------------------------------------------===// // SqueezeOp //===----------------------------------------------------------------------===// @@ -783,9 +912,9 @@ ::mlir::LogicalResult mlir::tt::ttir::ToLayoutOp::verify() { mlir::tt::ttir::ToLayoutOp::CompoundComponents mlir::tt::ttir::ToLayoutOp::compoundComponents() { auto inputLayout = - mlir::cast(getInput().getType().getEncoding()); + mlir::cast(getInput().getType().getEncoding()); auto outputLayout = - mlir::cast(getOutput().getType().getEncoding()); + mlir::cast(getOutput().getType().getEncoding()); bool isLayoutChange = inputLayout.getLinear() != outputLayout.getLinear(); bool isGridChange = inputLayout.getGrid() != outputLayout.getGrid(); bool isShardChange = @@ -801,6 +930,158 @@ mlir::tt::ttir::ToLayoutOp::compoundComponents() { isMemoryLayoutChange}; } +//===----------------------------------------------------------------------===// +// LinearOp +//===----------------------------------------------------------------------===// + +// LinearOp verification +::mlir::LogicalResult mlir::tt::ttir::LinearOp::verify() { + ::mlir::RankedTensorType inputAType = getA().getType(); + ::mlir::RankedTensorType inputBType = getB().getType(); + std::optional<::mlir::RankedTensorType> biasType = + getBias() ? std::make_optional(getBias().getType()) : std::nullopt; + ::mlir::RankedTensorType outputType = getOutput().getType(); + + llvm::ArrayRef outputShape = outputType.getShape(); + llvm::SmallVector inputAShape(inputAType.getShape()); + llvm::SmallVector inputBShape(inputBType.getShape()); + + // Verify that the input A is at least 1D tensor. + if (inputAType.getRank() < 1) { + return emitOpError("Input A must be at least a 1D tensor"); + } + + // Verify that the input B is at least 1D tensor. + if (inputBType.getRank() < 1) { + return emitOpError("Input B must be at least a 1D tensor"); + } + + // If input A is a vector (1D tensor), 1 is prepended to its dimension for the + // purpose of the matrix multiplication. After the matrix multiplication, the + // prepended dimension is removed. + if (inputAType.getRank() == 1) { + inputAShape.insert(inputAShape.begin(), 1); + } + + // If input B is a vector (1D tensor), a 1 is appended to its dimension for + // the purpose of the matrix-vector product and removed afterwards. + if (inputBType.getRank() == 1) { + inputBShape.push_back(1); + } + + // Verify that the input A and input B has matching inner dimensions. + if (inputAShape[inputAShape.size() - 1] != + inputBShape[inputBShape.size() - 2]) { + return emitOpError( + "Input A[-1](" + std::to_string(inputAShape[inputAShape.size() - 1]) + + ") and B[-2](" + std::to_string(inputBShape[inputBShape.size() - 2]) + + ") must have matching inner dimensions"); + } + + llvm::SmallVector expectedOutputShape; + // Verify that the batch dimensions are broadcast compatible and construct the + // expected output shape. + if (inputAShape.size() > 2 || inputBShape.size() > 2) { + llvm::SmallVector inputABatchDims, inputBBatchDims; + + if (inputAShape.size() > 2) { + inputABatchDims.insert(inputABatchDims.begin(), inputAShape.begin(), + inputAShape.end() - 2); + } + + if (inputBShape.size() > 2) { + inputBBatchDims.insert(inputBBatchDims.begin(), inputBShape.begin(), + inputBShape.end() - 2); + } + + // Verify that the batch dimensions of input A and B are broadcast + // compatible. + llvm::SmallVector broadcastedShape; + if (!OpTrait::util::getBroadcastedShape(inputABatchDims, inputBBatchDims, + broadcastedShape)) { + + return emitOpError("Batch dimensions of input A(" + + ttmlir::utils::join(inputABatchDims, ",") + + ") and B(" + + ttmlir::utils::join(inputBBatchDims, ",") + + ") are not broadcast compatible"); + } + + // Insert the broadcasted batch dimensions in the expected output shape. + expectedOutputShape.insert(expectedOutputShape.begin(), + broadcastedShape.begin(), + broadcastedShape.end()); + } + + // Insert the input A and B inner dimensions in expected output shape. + // Consider the case where input A and B are vectors. In that case, + // the dimension 1 is ommited from the output shape. + if (inputAType.getRank() > 1) { + expectedOutputShape.push_back(inputAShape[inputAShape.size() - 2]); + } + + if (inputBType.getRank() > 1) { + expectedOutputShape.push_back(inputBShape[inputBShape.size() - 1]); + } + + if (biasType) { + // Verify that the input bias is at least 1D tensor. + if (biasType.value().getRank() < 1) { + return emitOpError("Bias must be at least a 1D tensor"); + } + + llvm::SmallVector biasShape(biasType.value().getShape()); + + // Verify that the dimensions of the matmul of A and B are broadcast + // compatible with input bias. + llvm::SmallVector matmulShape = expectedOutputShape; + if (!OpTrait::util::getBroadcastedShape(matmulShape, biasShape, + expectedOutputShape)) { + return emitOpError("Bias shape(" + ttmlir::utils::join(biasShape, ",") + + ") is not broadcast compatible with the matmul output " + "shape(" + + ttmlir::utils::join(matmulShape, ",") + ")"); + } + } + + // Check the case of a vector-vector product. At this moment we don't support + // scalars in IR, hence check that the output is at least 1D tensor of size 1. + if (expectedOutputShape.size() == 0) { + if (outputType.getRank() < 1) { + return emitOpError("Scalar output is not supported, output must be at " + "least a 1D tensor"); + } + + if (outputType.getRank() > 1 || outputType.getShape()[0] != 1) { + return emitOpError("Scalar output must be a 1D tensor of size 1"); + } + + return success(); + } + + // Verify that the output shape dimension count is correct. + if (outputShape.size() != expectedOutputShape.size()) { + return emitOpError("Output shape rank(" + + std::to_string(outputShape.size()) + + ") must match the expected output shape rank(" + + std::to_string(expectedOutputShape.size()) + ")"); + } + + // Verify each dim of the output shape. + for (size_t i = 0; i < outputShape.size(); i++) { + if (outputShape[i] != expectedOutputShape[i]) { + return emitOpError( + "Output shape dimension[" + std::to_string(i) + "](" + + std::to_string(outputShape[i]) + + ") doesn't match the expected output shape dimension[" + + std::to_string(i) + "](" + std::to_string(expectedOutputShape[i]) + + ")"); + } + } + + return success(); +} + //===----------------------------------------------------------------------===// // MatmulOp //===----------------------------------------------------------------------===// @@ -939,7 +1220,7 @@ ::mlir::LogicalResult mlir::tt::ttir::MatmulOp::verify() { // AllocOp verification ::mlir::LogicalResult mlir::tt::ttir::AllocOp::verify() { - auto layout = mlir::dyn_cast_or_null( + auto layout = mlir::dyn_cast_or_null( getResult().getType().getEncoding()); if (not layout) { return emitOpError("Result type missing layout attribute"); @@ -1012,6 +1293,223 @@ ::mlir::LogicalResult mlir::tt::ttir::AllGatherOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// AllReduceOp +//===----------------------------------------------------------------------===// + +// AllReduceOp verification +::mlir::LogicalResult mlir::tt::ttir::AllReduceOp::verify() { + ::mlir::RankedTensorType inputType = + mlir::cast(getInputs().front().getType()); + int32_t dim = getDim(); + + if (dim >= inputType.getRank()) { + return emitOpError("Invalid dimension for all_reduce op."); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// MeshShardOp +//===----------------------------------------------------------------------===// + +// MeshShardOp verification +::mlir::LogicalResult mlir::tt::ttir::MeshShardOp::verify() { + auto shardType = getShardType(); + + // currently we are only supporting replicate or devices from StableHLO + if (shardType != mlir::tt::MeshShardType::Replicate && + shardType != mlir::tt::MeshShardType::Devices) { + return emitOpError("Invalid shard_type for mesh_shard op."); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// ScatterOp +//===----------------------------------------------------------------------===// + +bool matchSimpleBlock(mlir::Region ®ion) { + if (!region.hasOneBlock()) { + return false; + } + mlir::Block &block = region.front(); + if (block.getNumArguments() != 2) { + return false; + } + auto argType1 = + mlir::cast(block.getArgument(0).getType()); + auto argType2 = + mlir::cast(block.getArgument(1).getType()); + if (!argType1 || !argType2) { + return false; + } + if (block.getOperations().size() != 1) { + return false; + } + mlir::tt::ttir::YieldOp returnOp = + mlir::cast(&block.front()); + if (!returnOp) { + return false; + } + if (returnOp.getNumOperands() != 1 || + returnOp.getOperand(0) != block.getArgument(1)) { + return false; + } + return true; +} + +::mlir::LogicalResult mlir::tt::ttir::ScatterOp::verify() { + + ArrayRef inputShape = + mlir::cast(getInput().getType()).getShape(); + + if (getUpdateWindowDims().size() + getInsertedWindowDims().size() != + inputShape.size()) { + return emitOpError("Batching currently not supported"); + } + + for (uint64_t insertedWindowDims : getInsertedWindowDims()) { + if (inputShape[insertedWindowDims] != 1) { + return emitOpError("Dimension size to slice into must be 1"); + } + } + + // We currently do not support custom functions in the scatter function, + // which is a possbility in StableHLO dialect. See issue: + // https://github.com/tenstorrent/tt-mlir/issues/1278 + if (!matchSimpleBlock(getUpdateComputation())) { + return emitOpError( + "Currently not supporting custom scatter function in TTNN " + "dialect and TT-metal."); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// UpdateCacheOp +//===----------------------------------------------------------------------===// + +::mlir::LogicalResult mlir::tt::ttir::UpdateCacheOp::verify() { + if (getBatchOffset() != 0) { + return emitOpError( + "Only single-batch is supported. Batch offset must be 0"); + } + + const ::mlir::RankedTensorType cacheType = getCache().getType(); + const ::mlir::RankedTensorType inputType = getInput().getType(); + + const DataType cacheDataType = + elementTypeToDataType(cacheType.getElementType()); + const DataType inputDataType = + elementTypeToDataType(inputType.getElementType()); + + if (cacheDataType != inputDataType) { + return emitOpError( + "Cache and input tensors must have the same dtype. " + "Got cache dtype = " + + DataTypeEnumToString(cacheDataType) + + ", input dtype = " + DataTypeEnumToString(inputDataType)); + } + + if (cacheType.getRank() != 4) { + return emitOpError("Cache tensor must be a 4D tensor"); + } + + if (inputType.getRank() != 4) { + return emitOpError("Input tensor must be a 4D tensor"); + } + + if (inputType.getShape()[2] != 1) { + return emitOpError("Input tensor requires that dim 2 have size 1, got " + "input dim 2 size = " + + std::to_string(inputType.getShape()[2])); + } + + if (cacheType.getShape()[0] != inputType.getShape()[0] || + cacheType.getShape()[1] != inputType.getShape()[1] || + cacheType.getShape()[3] != inputType.getShape()[3]) { + return emitOpError("Cache tensor shape must match input tensor shape on " + "all dimensions except dim 2. Got cache shape (" + + std::to_string(cacheType.getShape()[0]) + ", " + + std::to_string(cacheType.getShape()[1]) + ", " + + std::to_string(cacheType.getShape()[2]) + ", " + + std::to_string(cacheType.getShape()[3]) + + "), input shape ()" + + std::to_string(inputType.getShape()[0]) + "x" + + std::to_string(inputType.getShape()[1]) + "x" + + std::to_string(inputType.getShape()[2]) + "x" + + std::to_string(inputType.getShape()[3]) + ")"); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// FillCacheOp +//===----------------------------------------------------------------------===// + +::mlir::LogicalResult mlir::tt::ttir::FillCacheOp::verify() { + if (getBatchOffset() != 0) { + return emitOpError( + "Only single-batch is supported. Batch offset must be 0"); + } + + const ::mlir::RankedTensorType cacheType = getCache().getType(); + const ::mlir::RankedTensorType inputType = getInput().getType(); + + const DataType cacheDataType = + elementTypeToDataType(cacheType.getElementType()); + const DataType inputDataType = + elementTypeToDataType(inputType.getElementType()); + + if (cacheDataType != inputDataType) { + return emitOpError( + "Cache and input tensors must have the same dtype. " + "Got cache dtype = " + + DataTypeEnumToString(cacheDataType) + + ", input dtype = " + DataTypeEnumToString(inputDataType)); + } + + if (cacheType.getRank() != 4) { + return emitOpError("Cache tensor must be a 4D tensor"); + } + + if (inputType.getRank() != 4) { + return emitOpError("Input tensor must be a 4D tensor"); + } + + if (inputType.getShape()[2] > cacheType.getShape()[2]) { + return emitOpError( + "Input tensor requires that dim 2 have a size which is less than or " + "equal to the size of dim 2 of the cache tensor. Got cache dim 2 size " + "= " + + std::to_string(cacheType.getShape()[2]) + + ", input dim 2 size = " + std::to_string(inputType.getShape()[2])); + } + + if (cacheType.getShape()[0] != inputType.getShape()[0] || + cacheType.getShape()[1] != inputType.getShape()[1] || + cacheType.getShape()[3] != inputType.getShape()[3]) { + return emitOpError("Cache tensor shape must match input tensor shape on " + "all dimensions except dim 2. Got cache shape (" + + std::to_string(cacheType.getShape()[0]) + ", " + + std::to_string(cacheType.getShape()[1]) + ", " + + std::to_string(cacheType.getShape()[2]) + ", " + + std::to_string(cacheType.getShape()[3]) + + "), input shape (" + + std::to_string(inputType.getShape()[0]) + ", " + + std::to_string(inputType.getShape()[1]) + ", " + + std::to_string(inputType.getShape()[2]) + ", " + + std::to_string(inputType.getShape()[3]) + ")"); + } + + return success(); +} + //===----------------------------------------------------------------------===// // GenericOp //===----------------------------------------------------------------------===// @@ -1093,6 +1591,13 @@ void mlir::tt::ttir::DivOp::buildGenericRegion(::mlir::OpBuilder &opBuilder, block); } +// MaximumOp generic region builder +void mlir::tt::ttir::MaximumOp::buildGenericRegion(::mlir::OpBuilder &opBuilder, + ::mlir::Block *block) { + buildGenericEltwiseBinaryRegion(getLoc(), opBuilder, + block); +} + //===----------------------------------------------------------------------===// // KernelOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/TTIR/IR/TTIROpsInterfaces.cpp b/lib/Dialect/TTIR/IR/TTIROpsInterfaces.cpp index 84409174a3..10619f24b8 100644 --- a/lib/Dialect/TTIR/IR/TTIROpsInterfaces.cpp +++ b/lib/Dialect/TTIR/IR/TTIROpsInterfaces.cpp @@ -17,37 +17,33 @@ #include "llvm/ADT/SmallVector.h" mlir::LogicalResult -mlir::tt::ttir::detail::verifyElementwiseOp(mlir::Operation *op) { +mlir::tt::ttir::detail::verifyBroadcastable(mlir::Operation *op) { + const auto getShape = [](const Value val) { + return mlir::cast(val.getType()).getShape(); + }; + + const auto operandSegmentSizes = + op->getAttrOfType("operandSegmentSizes"); + // DPS operands shouldn't affect the result shape. + const auto outputSegmentSize = + operandSegmentSizes[operandSegmentSizes.size() - 1]; + const auto operandShapes = llvm::map_range(op->getOperands(), getShape); llvm::SmallVector broadcastedShape; - mlir::OperandRange operands = op->getOperands(); - mlir::OperandRange::iterator operand_it = operands.begin(); - llvm::SmallVector prevOperandShape( - mlir::cast((*operand_it).getType()).getShape()); - - while (++operand_it != operands.end()) { - llvm::SmallVector nextOperandShape( - mlir::cast((*operand_it).getType()).getShape()); - - if (!OpTrait::util::getBroadcastedShape(prevOperandShape, nextOperandShape, + for (const auto operandShape : + llvm::drop_end(operandShapes, outputSegmentSize)) { + const auto prevBroadcastedShape = broadcastedShape; + if (!OpTrait::util::getBroadcastedShape(prevBroadcastedShape, operandShape, broadcastedShape)) { return op->emitOpError("Operands are not broadcast compatible"); } - prevOperandShape = broadcastedShape; } - llvm::SmallVector resultShape( - mlir::cast(op->getResult(0).getType()) - .getShape()); + // Check that the result shape matches the broadcasted shape of the operands. + llvm::SmallVector resultShape(getShape(op->getResults().front())); if (broadcastedShape != resultShape) { return op->emitOpError( "Result shape must match operand shapes after broadcasting"); } - TypeID expectedBaseTy = op->getResultTypes().front().getTypeID(); - if (!llvm::all_of(op->getOperandTypes(), - [&](Type t) { return t.getTypeID() == expectedBaseTy; })) { - return op->emitOpError() << "All operands/results must have the same type"; - } - return success(); } diff --git a/lib/Dialect/TTIR/Transforms/Allocate.cpp b/lib/Dialect/TTIR/Transforms/Allocate.cpp index 37e788385c..a643f041c3 100644 --- a/lib/Dialect/TTIR/Transforms/Allocate.cpp +++ b/lib/Dialect/TTIR/Transforms/Allocate.cpp @@ -22,13 +22,13 @@ inline MemorySpace getMemorySpace(MemRefType memref) { return mlir::cast(memref.getMemorySpace()).getValue(); } -inline MemorySpace getMemorySpace(LayoutAttr layout) { +inline MemorySpace getMemorySpace(MetalLayoutAttr layout) { return getMemorySpace(layout.getMemref()); } inline MemorySpace getMemorySpace(RankedTensorType ty) { assert(ty.getEncoding()); - auto layout = mlir::cast(ty.getEncoding()); + auto layout = mlir::cast(ty.getEncoding()); return getMemorySpace(layout); } diff --git a/lib/Dialect/TTIR/Transforms/Broadcast.cpp b/lib/Dialect/TTIR/Transforms/Broadcast.cpp new file mode 100644 index 0000000000..7823b021ed --- /dev/null +++ b/lib/Dialect/TTIR/Transforms/Broadcast.cpp @@ -0,0 +1,68 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttmlir/Dialect/TT/IR/TT.h" +#include "ttmlir/Dialect/TTIR/Transforms/Passes.h" + +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include + +namespace mlir::tt::ttir { +#define GEN_PASS_DEF_TTIRBROADCASTFOLD +#include "ttmlir/Dialect/TTIR/Transforms/Passes.h.inc" + +//===----------------------------------------------------------------------===// +// Broadcast Folding pass +// Our backend supports implicit broadcast of operands, so explicit broadcast +// instructions are folded. +// +// For Example: +// +// %0 = tensor.empty() : tensor<512xf32> +// %1 = "ttir.broadcast"(%arg0, %0) (tensor<1xf32>, tensor<512xf32>) -> +// tensor<512xf32> %2 = tensor.empty() : tensor<512xf32> %3 = "ttir.maximum"(%1, +// %arg1, %2) (tensor<512xf32>, tensor<512xf32>, tensor<512xf32>) -> +// tensor<512xf32> +// +// After folding: +// +// %0 = tensor.empty() : tensor<512xf32> +// %1 = "ttir.maximum"(%arg0, %arg1, %0) (tensor<1xf32>, tensor<512xf32>, +// tensor<512xf32>) -> tensor<512xf32> +//===----------------------------------------------------------------------===// + +class TTIRBroadcastFoldRewriter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(BroadcastOp op, + PatternRewriter &rewriter) const final { + + rewriter.replaceOp(op, op->getOperand(0)); + return success(); + } +}; + +class TTIRBroadcastFold + : public impl::TTIRBroadcastFoldBase { +public: + using impl::TTIRBroadcastFoldBase::TTIRBroadcastFoldBase; + + void runOnOperation() final { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + FrozenRewritePatternSet patternSet(std::move(patterns)); + if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) { + signalPassFailure(); + return; + } + } + + void getDependentDialects(mlir::DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } +}; + +} // namespace mlir::tt::ttir diff --git a/lib/Dialect/TTIR/Transforms/CMakeLists.txt b/lib/Dialect/TTIR/Transforms/CMakeLists.txt index f5fec45a8b..597c55e3ca 100644 --- a/lib/Dialect/TTIR/Transforms/CMakeLists.txt +++ b/lib/Dialect/TTIR/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRTTIRTransforms Allocate.cpp + Broadcast.cpp Constant.cpp Generic.cpp Layout.cpp diff --git a/lib/Dialect/TTIR/Transforms/Generic.cpp b/lib/Dialect/TTIR/Transforms/Generic.cpp index 005e12c079..3bf96f3cd6 100644 --- a/lib/Dialect/TTIR/Transforms/Generic.cpp +++ b/lib/Dialect/TTIR/Transforms/Generic.cpp @@ -257,7 +257,7 @@ class TTIRGenericRegionRewriter auto resEncoding = mlir::cast(op->getResult(0).getType()).getEncoding(); if (resEncoding) { - auto resLayout = mlir::cast(resEncoding); + auto resLayout = mlir::cast(resEncoding); gridAttr = resLayout.getGrid(); } @@ -339,7 +339,7 @@ struct TTIRGenericOperandsToMemrefRewriter auto matchingOperand = generic.getMatchingOperand(blockArgNumber); auto operandType = matchingOperand.getType(); - auto bufferLayout = mlir::cast( + auto bufferLayout = mlir::cast( mlir::cast(operandType).getEncoding()); auto bufferType = operandType; @@ -349,7 +349,7 @@ struct TTIRGenericOperandsToMemrefRewriter assert(static_cast(cbIndex) < generic.getCbs().size()); auto cb = generic.getCbs()[cbIndex]; auto cbType = cb.getType(); - auto cbLayout = mlir::cast( + auto cbLayout = mlir::cast( mlir::cast(cbType).getEncoding()); bufferLayout = cbLayout; bufferType = cbType; @@ -387,7 +387,7 @@ class TTIRGenericRegionMemrefTypeConverter : public TypeConverter { if (mlir::isa(encoding)) { return type; } - auto layout = mlir::cast(type.getEncoding()); + auto layout = mlir::cast(type.getEncoding()); auto buffer = BufferAttr::get(ctx, layout.getMemref(), BufferAccess::Alias); return RankedTensorType::get(buffer.getShape(), type.getElementType(), @@ -451,11 +451,11 @@ class TTIRGenericOpCBsRewriter : public OpRewritePattern { // Enforcing tiled layout as in kernel we always want to work with tiles. auto desiredElementType = rewriter.getType(ty.getElementType()); - auto desiredLayout = rewriter.getAttr( + auto desiredLayout = rewriter.getAttr( ty, MemorySpace::DeviceL1, generic.getGrid(), desiredElementType); auto operandTy = operand.getType(); - auto operandLayout = mlir::cast( + auto operandLayout = mlir::cast( mlir::cast(operandTy).getEncoding()); if (desiredLayout.getGrid() == operandLayout.getGrid()) { diff --git a/lib/Dialect/TTIR/Transforms/Layout.cpp b/lib/Dialect/TTIR/Transforms/Layout.cpp index d7eef6732d..c3ccbf1a44 100644 --- a/lib/Dialect/TTIR/Transforms/Layout.cpp +++ b/lib/Dialect/TTIR/Transforms/Layout.cpp @@ -38,20 +38,21 @@ class TTIRLayoutTensorTypeConverter : public TypeConverter { TTIRLayoutTensorTypeConverter(MLIRContext *ctx, MemorySpace initMemorySpace, GridAttr deviceGrid) { addConversion([](Type type) { return type; }); - addConversion([ctx, initMemorySpace, - deviceGrid](RankedTensorType type) -> Type { - auto layout = type.getEncoding(); - if (layout) { - return type; - } - std::int64_t deviceGridRank = deviceGrid.getShape().size(); - // Default to single core grid - auto tensorGrid = GridAttr::get(ctx, deviceGridRank); - // Default to initMemorySpace, the optimizer might decide otherwise - auto newLayout = LayoutAttr::get(ctx, type, initMemorySpace, tensorGrid); - return RankedTensorType::get(type.getShape(), type.getElementType(), - newLayout); - }); + addConversion( + [ctx, initMemorySpace, deviceGrid](RankedTensorType type) -> Type { + auto layout = type.getEncoding(); + if (layout) { + return type; + } + std::int64_t deviceGridRank = deviceGrid.getShape().size(); + // Default to single core grid + auto tensorGrid = GridAttr::get(ctx, deviceGridRank); + // Default to initMemorySpace, the optimizer might decide otherwise + auto newLayout = + MetalLayoutAttr::get(ctx, type, initMemorySpace, tensorGrid); + return RankedTensorType::get(type.getShape(), type.getElementType(), + newLayout); + }); } }; @@ -129,7 +130,7 @@ createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input, TensorMemoryLayout desiredMemLayout, bool tiled) { auto ty = mlir::cast(input.getType()); - auto currLayout = mlir::cast(ty.getEncoding()); + auto currLayout = mlir::cast(ty.getEncoding()); auto currMemorySpace = currLayout.getMemorySpace(); auto currElementType = currLayout.getElementType(); auto currMemLayout = currLayout.getMemLayout(); @@ -142,9 +143,9 @@ createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input, return std::nullopt; } - auto desiredLayout = - rewriter.getAttr(ty, desiredMemorySpace, currLayout.getGrid(), - desiredElementType, desiredMemLayout); + auto desiredLayout = rewriter.getAttr( + ty, desiredMemorySpace, currLayout.getGrid(), desiredElementType, + desiredMemLayout); tensor::EmptyOp existingEmpty = input.getDefiningOp(); if (existingEmpty) { @@ -343,7 +344,7 @@ class TTIRSplitCompoundLayoutRewriter : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; Value createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input, - LayoutAttr desiredLayout) const { + MetalLayoutAttr desiredLayout) const { auto ty = mlir::cast(input.getType()); auto output = rewriter.create( loc, ty.getShape(), ty.getElementType(), desiredLayout); @@ -353,7 +354,7 @@ class TTIRSplitCompoundLayoutRewriter : public OpRewritePattern { } Value bounce(PatternRewriter &rewriter, ToLayoutOp op, - LayoutAttr bounceLayout) const { + MetalLayoutAttr bounceLayout) const { auto bounced = createToLayoutOp(rewriter, op.getLoc(), op.getInput(), bounceLayout); return rewriter.replaceOpWithNewOp( @@ -375,8 +376,8 @@ class TTIRSplitCompoundLayoutRewriter : public OpRewritePattern { auto inputType = mlir::cast(op.getInput().getType()); auto outputType = mlir::cast(op.getOutput().getType()); - auto inputLayout = mlir::cast(inputType.getEncoding()); - auto outputLayout = mlir::cast(outputType.getEncoding()); + auto inputLayout = mlir::cast(inputType.getEncoding()); + auto outputLayout = mlir::cast(outputType.getEncoding()); bool inputL1 = inputLayout.getMemorySpace() == MemorySpace::DeviceL1; bool outputL1 = outputLayout.getMemorySpace() == MemorySpace::DeviceL1; diff --git a/lib/Dialect/TTMetal/IR/TTMetalOps.cpp b/lib/Dialect/TTMetal/IR/TTMetalOps.cpp index 49baf51e01..7f78c1afcb 100644 --- a/lib/Dialect/TTMetal/IR/TTMetalOps.cpp +++ b/lib/Dialect/TTMetal/IR/TTMetalOps.cpp @@ -17,7 +17,7 @@ namespace mlir::tt::ttmetal { ::mlir::LogicalResult HostWriteOp::verify() { ::mlir::RankedTensorType outputTy = getOutput().getType(); auto outputLayout = - mlir::dyn_cast_or_null(outputTy.getEncoding()); + mlir::dyn_cast_or_null(outputTy.getEncoding()); if (not outputLayout) { return emitOpError("Input tensor missing layout attribute"); } @@ -30,7 +30,7 @@ ::mlir::LogicalResult HostWriteOp::verify() { ::mlir::LogicalResult HostReadOp::verify() { ::mlir::RankedTensorType outputTy = getOutput().getType(); auto outputLayout = - mlir::dyn_cast_or_null(outputTy.getEncoding()); + mlir::dyn_cast_or_null(outputTy.getEncoding()); if (not outputLayout) { return emitOpError("Input tensor missing layout attribute"); } @@ -41,7 +41,7 @@ ::mlir::LogicalResult HostReadOp::verify() { } ::mlir::LogicalResult AllocOp::verify() { - auto layout = mlir::dyn_cast_or_null( + auto layout = mlir::dyn_cast_or_null( getResult().getType().getEncoding()); if (not layout) { return emitOpError("Result type missing layout attribute"); @@ -76,7 +76,7 @@ ::mlir::LogicalResult AllocOp::verify() { ::mlir::LogicalResult DispatchOp::verify() { // Assert inputs/outputs device memspace for (auto operand : getOperands()) { - auto layout = mlir::dyn_cast_or_null( + auto layout = mlir::dyn_cast_or_null( mlir::cast(operand.getType()).getEncoding()); if (not layout) { return emitOpError("Input tensor missing layout attribute"); diff --git a/lib/Dialect/TTNN/Analysis/CMakeLists.txt b/lib/Dialect/TTNN/Analysis/CMakeLists.txt index 35d8f88ab3..4db2d78b9c 100644 --- a/lib/Dialect/TTNN/Analysis/CMakeLists.txt +++ b/lib/Dialect/TTNN/Analysis/CMakeLists.txt @@ -15,6 +15,6 @@ add_mlir_dialect_library(MLIRTTNNAnalysis MLIRTTNNPassesIncGen MLIRTTOpsIncGen - LINK_LIBS + LINK_LIBS PUBLIC MLIRScheduler ) diff --git a/lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp b/lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp index b83409d477..8d5f22bfc4 100644 --- a/lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp +++ b/lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp @@ -217,9 +217,11 @@ void DFShardingPolicy::pickOpShardLayouts(ShardSolver &shardSolver, maxCoreUsage = accMaxCoreUsage[op][layoutIterator.index()]; selectedLayout = layoutIterator.get(); } else if (accMaxCoreUsage[op][layoutIterator.index()] == maxCoreUsage) { + assert(layoutIterator->getMemLayout() && + "TensorMemoryLayout is not set"); // If we have a tie, prefer layout that is not BlockSharded. // - if (layoutIterator->getMemLayout() != + if (layoutIterator->getMemLayout().getValue() != ttnn::TensorMemoryLayout::BlockSharded) { selectedLayout = layoutIterator.get(); } diff --git a/lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp b/lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp index c0b3ff102f..69a07af168 100644 --- a/lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp +++ b/lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp @@ -3,19 +3,23 @@ // SPDX-License-Identifier: Apache-2.0 #include "ttmlir/Dialect/TTNN/Analysis/L1InterleavedPolicy.h" -#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" -#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" +#include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h" #include "ttmlir/Scheduler/Scheduler.h" namespace mlir::tt::ttnn { -uint64_t getOpOutputLayoutUsage( - Operation *op, - llvm::DenseMap> &legalLayouts, - DeviceAttr &deviceAttr) { - TTNNLayoutAttr opLayout = legalLayouts.lookup(op).front(); - assert(opLayout.hasInterleavedL1TensorMemoryLayout()); +uint64_t getOpOutputL1Usage(Operation *op, TTNNLayoutAttr opLayout, + DeviceAttr &deviceAttr) { + // In case the opLayout is not in L1 memory space, L1 memory usage is 0. + // + if (opLayout.hasDRAMBufferType()) { + return 0; + } + // L1 memory usage of the ops without output tensors cannot be calculated. + // So far, this is only false for ttnn.get_device op. + // + assert(mlir::isa(op->getResult(0).getType())); llvm::ArrayRef opOutputTensorShape = mlir::cast(op->getResult(0).getType()).getShape(); @@ -24,132 +28,327 @@ uint64_t getOpOutputLayoutUsage( return opL1OutputUsage; } -void L1InterleavedPolicy::run() { - rootOp->walk([&](func::FuncOp func) { - DeviceAttr deviceAttr = getCurrentScopeDevice(func); - mlir::tt::scheduler::Scheduler scheduler(&func); - llvm::SmallVector scheduleableOps; - llvm::DenseMap selectedOpLayout; - Operation *currentOp = nullptr; +L1InterleavedPolicy::OpConfig L1InterleavedPolicy::getGreedyConfig( + Operation *baseOp, llvm::DenseMap &opsL1Usage) { + uint64_t numOfOps, bitIndex, currentMask; + uint64_t currentL1Usage, optimalL1Usage; + llvm::DenseMap optimalLayouts; + llvm::SmallVector optimalPrecedence; + + constexpr uint64_t maxNumOfOps = sizeof(numOfOps) * 8; + numOfOps = opsL1Usage.size(); + assert(numOfOps <= maxNumOfOps); + + optimalL1Usage = 0; + for (currentMask = 0; currentMask < (1 << numOfOps); currentMask++) { + std::bitset bitset(currentMask); + llvm::DenseMap currentLayouts; + llvm::SmallVector currentPrecedence, optimalL1Precedence, + L1Precedence; - // TODO(fbajraktari): Add algorithm description. Currently, the algorithm - // is the same as for DFSharding policy, but works only for L1 interleaved. + // Calculate the L1 usage of the current configuration. // - l1ChainConfigs->push_back(L1ChainConfig()); - while (scheduler.hasUnscheduledOps()) { - scheduleableOps = scheduler.getScheduleableOps(); + currentL1Usage = 0; + bitIndex = 0; + for (const auto &[op, l1Usage] : opsL1Usage) { + if (bitset[bitIndex]) { + // In case we have an operand with L1 interleaved layout, we need to + // figure out its schedule among the other operands with L1 interleaved + // layout. Therefore, we insert all of them into the L1Precedence where + // calculate the optimal L1Precedence and then concatenate it with the + // currentPrecedence. + // + currentL1Usage += l1Usage.outputL1Usage; + currentLayouts[op] = getL1InterleavedLayout(op); + + // Skip the baseOp. + // + if (baseOp != op) { + L1Precedence.push_back(op); + } + } else { + // It is optimal to first schedule all ops with DRAM output layout. + // Therefore, we can directly insert them into the + // currentOptimalPrecedence. + // + currentLayouts[op] = getDRAMLayout(op); - // Before starting a l1 chain, schedule layout/memory management ops - // first until they are exhausted from schedulable ops. + // Skip the baseOp. + // + if (baseOp != op) { + currentPrecedence.push_back(op); + } + } + bitIndex += 1; + } + + // Calculate the optimal L1Precedence. + // + bool isMaskLegal = false; + uint64_t minRequiredL1Usage = getAvailableL1CacheSize(); + + std::sort(L1Precedence.begin(), L1Precedence.end()); + do { + // Check if the current order of L1Precedence is legal. // - if (l1ChainConfigs->back().isEmpty()) { - for (auto *op : scheduleableOps) { - if (isa(op)) { - currentOp = op; - break; - } + bool isLegal = true; + uint64_t intermediateL1Usage = 0; + uint64_t intermediateRequiredL1Usage = 0; + for (Operation *op : L1Precedence) { + if (intermediateL1Usage + opsL1Usage[op].requiredL1Usage > + getAvailableL1CacheSize()) { + isLegal = false; + break; } + + intermediateRequiredL1Usage = + std::max(intermediateRequiredL1Usage, + intermediateL1Usage + opsL1Usage[op].requiredL1Usage); + intermediateL1Usage += opsL1Usage[op].outputL1Usage; } - if (currentOp == nullptr) { - currentOp = scheduleableOps[0]; + // Pick optimal L1Precedence among all legal L1Precedence. + // The one that requires the least amount of L1 cache overall is + // considered optimal. + // + if (isLegal && intermediateRequiredL1Usage < minRequiredL1Usage) { + isMaskLegal = true; + minRequiredL1Usage = intermediateRequiredL1Usage; + optimalL1Precedence = L1Precedence; } + } while (std::next_permutation(L1Precedence.begin(), L1Precedence.end())); + + if (isMaskLegal && optimalL1Usage < currentL1Usage && + currentL1Usage <= getAvailableL1CacheSize()) { - // Schedule currentOp. + // Append the legal L1Precedence to the currentPrecedence and therefore + // create a complete precedence for the baseOp and currentMask. // - scheduler.scheduleOp(currentOp); + currentPrecedence.insert(currentPrecedence.end(), + optimalL1Precedence.begin(), + optimalL1Precedence.end()); - // Skip starting sharding chain if currentOp is a memory management op. + // Update the optimal configuration. // - if (l1ChainConfigs->back().isEmpty() && isa(currentOp)) { - currentOp = nullptr; - continue; - } + optimalL1Usage = currentL1Usage; + optimalLayouts = std::move(currentLayouts); + optimalPrecedence = std::move(currentPrecedence); + } + } - if (scheduler.hasUnscheduledOps()) { - scheduleableOps = scheduler.getScheduleableOps(); + // Create the optimal config. + // + OpConfig optimalConfig; + optimalConfig.baseOp = baseOp; + optimalConfig.layouts = std::move(optimalLayouts); + optimalConfig.precedence = std::move(optimalPrecedence); - // Check if currentOp has a valid successor. + return optimalConfig; +} + +void L1InterleavedPolicy::run() { + for (Operation &funcOp : rootOp->getRegion(0).getOps()) { + func::FuncOp func = dyn_cast(funcOp); + DeviceAttr deviceAttr = getCurrentScopeDevice(func); + + // Start the policy. + // + llvm::DenseMap OpMemSpecMap; + mlir::tt::scheduler::Scheduler scheduler(&func); + llvm::SmallVector scheduleableOps; + + while (scheduler.hasUnscheduledOps()) { + scheduleableOps = scheduler.getScheduleableOps(); + + for (Operation *op : scheduleableOps) { + // Schedule the op. // - Operation *nextOp = nullptr; - for (auto *op : scheduleableOps) { - for (auto operand : op->getOperands()) { - if (operand.getDefiningOp() == currentOp) { - nextOp = op; - break; - } + scheduler.scheduleOp(op); + + // Find optimal configuration for the op. + // + llvm::DenseMap opsL1Usage; + llvm::SmallVector opsPrecedence; + + // Generate optimal configuration for the current op based on the + // outputs of its operands and its legal output layouts. + // + if (isAnalyzable(op)) { + + // Create the OpMemSpec. + // + OpMemSpec OpMemSpec; + assert(hasDRAMBufferType(op)); + OpMemSpec.layout = getDRAMLayout(op); + OpMemSpec.requiredL1Usage = 0; + OpMemSpecMap[op] = OpMemSpec; + + if (op->hasOneUse() && hasL1BufferType(op)) { + L1Usage l1Usage; + l1Usage.outputL1Usage = + getOpOutputL1Usage(op, getL1InterleavedLayout(op), deviceAttr); + l1Usage.requiredL1Usage = 0; + opsL1Usage[op] = l1Usage; } } - if (nextOp) { + for (auto operand : op->getOperands()) { + // Skip block arguments (%arg0, %arg1, ...) + // + if (::llvm::isa(operand)) { + continue; + } - // V1: Check that currentOp is not fork/join op. + Operation *operandOp = operand.getDefiningOp(); + + // Skip non-analyzable operands. // - bool validForL1Interleaved = - currentOp->hasOneUse() && - legalLayouts.lookup(currentOp).size() > 0 && - legalLayouts.lookup(nextOp).size() > 0; - - if (validForL1Interleaved) { - // Figure out this const based on exec data, but will be replaced - // with API. + if (isAnalyzable(operandOp)) { + TTNNLayoutAttr operandOpLayout = OpMemSpecMap[operandOp].layout; + + // Take into consideration only the operands with L1 interleaved + // memory space. // - constexpr float tensorL1UsageCap = 0.8; - uint64_t currentOpL1OutputUsage = - getOpOutputLayoutUsage(currentOp, legalLayouts, deviceAttr); - uint64_t nextOpL1OutputUsage = - getOpOutputLayoutUsage(nextOp, legalLayouts, deviceAttr); - bool l1UsageValid = (currentOpL1OutputUsage + nextOpL1OutputUsage) < - tensorL1UsageCap * usableL1CacheSize; - - if (l1UsageValid) { - selectedOpLayout[currentOp] = - legalLayouts.lookup(currentOp).front(); - - // Add currentOp to l1 chain config. - // - OpL1MemSpec shardSpec; - shardSpec.op = currentOp; - - // Hardcoded tensor split factor for now, until pipeline OP - // support is added. - // - shardSpec.tensorSplitFactor = 1; - l1ChainConfigs->back().addOpL1MemSpec(std::move(shardSpec)); - - // Update currentOp pointer. - // - currentOp = nextOp; - continue; + if (operandOpLayout.hasInterleavedL1TensorMemoryLayout()) { + L1Usage l1Usage; + l1Usage.outputL1Usage = + getOpOutputL1Usage(operandOp, operandOpLayout, deviceAttr); + l1Usage.requiredL1Usage = OpMemSpecMap[operandOp].requiredL1Usage; + opsL1Usage[operandOp] = l1Usage; + } + // In case the operand has DRAM layout, we can insert it into the + // precedence directly. If the op is analyzable, it means that it + // is definitely schedulable. + // + else { + opsPrecedence.push_back(operandOp); + } + } + // In case the operand is not analyzable, i.e. there are no legal + // layouts for this operand, we can insert it into the precedence + // directly if it is schedulable since it does not use DRAM nor L1 + // memory. + // + else { + if (scheduler.isTTShedulableOp(operandOp)) { + opsPrecedence.push_back(operandOp); } } } - currentOp = nullptr; - if (!l1ChainConfigs->back().isEmpty()) { - l1ChainConfigs->back().build(); - l1ChainConfigs->push_back(L1ChainConfig()); + // Greedily find the optimal configuration. + // + OpConfig optimalConfig = getGreedyConfig(op, opsL1Usage); + for (const auto &[op, layout] : optimalConfig.layouts) { + OpMemSpecMap[op].layout = layout; + } + + // Override op's precedence. + // + opsPrecedence.insert(opsPrecedence.end(), + optimalConfig.precedence.begin(), + optimalConfig.precedence.end()); + precedenceMap[op] = std::move(opsPrecedence); + + // Update op's requiredL1Usage if the op is analyzable. + // + if (isAnalyzable(op)) { + uint64_t intermediateRequiredL1Usage = 0; + uint64_t intermediateL1Usage = 0; + for (auto operand : op->getOperands()) { + // Skip block arguments (%arg0, %arg1, ...) + // + if (::llvm::isa(operand)) { + continue; + } + + Operation *operandOp = operand.getDefiningOp(); + + // Skip non-analyzable operands. + // + if (isAnalyzable(operandOp)) { + intermediateRequiredL1Usage = + std::max(intermediateRequiredL1Usage, + intermediateL1Usage + + OpMemSpecMap[operandOp].requiredL1Usage); + intermediateL1Usage += getOpOutputL1Usage( + operandOp, OpMemSpecMap[operandOp].layout, deviceAttr); + } + } + OpMemSpecMap[op].requiredL1Usage = std::max( + intermediateRequiredL1Usage, + intermediateL1Usage + + getOpOutputL1Usage(op, OpMemSpecMap[op].layout, deviceAttr)); } } } - if (l1ChainConfigs->back().isEmpty()) { - l1ChainConfigs->pop_back(); - } + // Construct the schedule. + // + constructSchedule(func); - // Schedule + // Build, Resolve and Complete the L1 chain. + // This implementation is only here unitl we are able to merge + // L1ChainConfigs. + // TODO(fbajraktari): Fix this hack. // - (*schedule)[func] = scheduler.getSchedule(); + l1ChainConfigs->push_back(L1ChainConfig()); + llvm::DenseMap selectedOpLayout; + for (auto &OpMemSpec : OpMemSpecMap) { + OpL1MemSpec opL1MemSpec; + opL1MemSpec.op = OpMemSpec.first; + opL1MemSpec.tensorSplitFactor = 1; + selectedOpLayout[OpMemSpec.first] = OpMemSpec.second.layout; + l1ChainConfigs->back().addOpL1MemSpec(opL1MemSpec); + } + l1ChainConfigs->back().build(); + l1ChainConfigs->back().resolve(); + std::unordered_set memReconfigEdges; + l1ChainConfigs->back().complete(selectedOpLayout, memReconfigEdges); + } +} - // Resolve l1 chain configs. +bool L1InterleavedPolicy::isAnalyzable(Operation *op) { + // Skip operations that are not analyzed by the LegalLayoutAnalysis. + // + if (legalLayouts.count(op) > 0) { + // Skip operations that are filterd out by the MemoryLayoutAnalysis. // - for (auto &l1ChainConfig : *l1ChainConfigs) { - l1ChainConfig.resolve(); + return legalLayouts[op].size() > 0; + } + return false; +} - std::unordered_set memReconfigEdges; - l1ChainConfig.complete(selectedOpLayout, memReconfigEdges); - } - }); +bool L1InterleavedPolicy::hasDRAMBufferType(Operation *op) { + return std::find_if(legalLayouts[op].begin(), legalLayouts[op].end(), + [](TTNNLayoutAttr layout) { + return layout.hasDRAMBufferType(); + }) != legalLayouts[op].end(); +} + +TTNNLayoutAttr L1InterleavedPolicy::getDRAMLayout(Operation *op) { + assert(hasDRAMBufferType(op)); + auto dramLayoutIter = std::find_if( + legalLayouts[op].begin(), legalLayouts[op].end(), + [](TTNNLayoutAttr layout) { return layout.hasDRAMBufferType(); }); + return *dramLayoutIter; +} + +bool L1InterleavedPolicy::hasL1BufferType(Operation *op) { + return std::find_if(legalLayouts[op].begin(), legalLayouts[op].end(), + [](TTNNLayoutAttr layout) { + return layout.hasInterleavedL1TensorMemoryLayout(); + }) != legalLayouts[op].end(); +} + +TTNNLayoutAttr L1InterleavedPolicy::getL1InterleavedLayout(Operation *op) { + assert(hasL1BufferType(op)); + auto l1InterleaveLayoutIter = + std::find_if(legalLayouts[op].begin(), legalLayouts[op].end(), + [](TTNNLayoutAttr layout) { + return layout.hasInterleavedL1TensorMemoryLayout(); + }); + return *l1InterleaveLayoutIter; } } // namespace mlir::tt::ttnn diff --git a/lib/Dialect/TTNN/Analysis/LegalLayoutAnalysis.cpp b/lib/Dialect/TTNN/Analysis/LegalLayoutAnalysis.cpp index 5ef3f47312..3f4ef25ab2 100644 --- a/lib/Dialect/TTNN/Analysis/LegalLayoutAnalysis.cpp +++ b/lib/Dialect/TTNN/Analysis/LegalLayoutAnalysis.cpp @@ -124,37 +124,40 @@ bool LegalLayoutAnalysis::applyOverrides() { elementType = TileType::get(op->getContext(), elementType); } - analysisResult.push_back( - TTNNLayoutAttr::get(op->getContext(), tensorShape, elementType, - layoutOverride.bufferType.value(), grid, - layoutOverride.tensorMemoryLayout.value())); + analysisResult.push_back(TTNNLayoutAttr::get( + op->getContext(), tensorShape, elementType, + layoutOverride.bufferType.value(), grid, + TensorMemoryLayoutAttr::get(op->getContext(), + layoutOverride.tensorMemoryLayout.value()))); return true; } bool incompatibleWithOverride( const TTNNLayoutAttr &layout, - const std::optional &override) { - if (not override.has_value()) { + const std::optional &layoutOverride) { + if (not layoutOverride.has_value()) { return false; } - if (override->grid.has_value()) { - if (layout.getGrid().getShape()[0] != override->grid.value()[0] or - layout.getGrid().getShape()[1] != override->grid.value()[1]) { + if (layoutOverride->grid.has_value()) { + if (layout.getGrid().getShape()[0] != layoutOverride->grid.value()[0] || + layout.getGrid().getShape()[1] != layoutOverride->grid.value()[1]) { return true; } } - if (override->bufferType.has_value() and - layout.getBufferType() != override->bufferType.value()) { + if (layoutOverride->bufferType.has_value() && + layout.getBufferType() != layoutOverride->bufferType.value()) { return true; } - if (override->tensorMemoryLayout.has_value() and - layout.getMemLayout() != override->tensorMemoryLayout.value()) { + if (layoutOverride->tensorMemoryLayout.has_value() && + layout.getMemLayout().getValue() != + layoutOverride->tensorMemoryLayout.value()) { return true; } - if (override->memoryLayout.has_value() and - layout.isTiled() != (override->memoryLayout.value() == Layout::Tile)) { + if (layoutOverride->memoryLayout.has_value() && + layout.isTiled() != + (layoutOverride->memoryLayout.value() == Layout::Tile)) { return true; } return false; @@ -166,6 +169,14 @@ void LegalLayoutAnalysis::analysisImplementation() { return; } + if (!isa(op->getResult(0).getType())) { + return; + } + + if (llvm::isa(op)) { + return; + } + // Get output tensor type. RankedTensorType tensorType = mlir::cast(op->getResult(0).getType()); @@ -199,7 +210,7 @@ void LegalLayoutAnalysis::analysisImplementation() { std::vector shardedResults; bool rowMajorAllowed = analysisInput.rowMajorEnabled; - if (override.has_value() and override->memoryLayout.has_value() and + if (override.has_value() && override->memoryLayout.has_value() && override->memoryLayout.value() == Layout::RowMajor) { // Force allow row major if override is set. rowMajorAllowed = true; @@ -207,18 +218,22 @@ void LegalLayoutAnalysis::analysisImplementation() { // Generate both TILE and ROW_MAJOR layouts. for (Type elementType : {scalarElementType, tileElementType}) { - if (not rowMajorAllowed and elementType == scalarElementType) { + if (not rowMajorAllowed && elementType == scalarElementType) { continue; } // DRAM analysisResult.push_back(TTNNLayoutAttr::get( op->getContext(), tensorShape, elementType, BufferType::DRAM, - analysisInput.maxGrid, TensorMemoryLayout::Interleaved)); + analysisInput.maxGrid, + TensorMemoryLayoutAttr::get(op->getContext(), + TensorMemoryLayout::Interleaved))); // L1 Interleaved (same as above). analysisResult.push_back(TTNNLayoutAttr::get( op->getContext(), tensorShape, elementType, BufferType::L1, - analysisInput.maxGrid, TensorMemoryLayout::Interleaved)); + analysisInput.maxGrid, + TensorMemoryLayoutAttr::get(op->getContext(), + TensorMemoryLayout::Interleaved))); // L1 Sharded TTNNLayoutAttr shardedBase = @@ -268,7 +283,7 @@ void LegalLayoutAnalysis::analysisImplementation() { shardedResults.erase( std::remove_if(shardedResults.begin(), shardedResults.end(), [this](TTNNLayoutAttr layout) { - return !tensorShapeCompatibleWithShard(op, layout) or + return !tensorShapeCompatibleWithShard(op, layout) || !mockIsOutputTensorLegalForOp(op, layout); }), shardedResults.end()); diff --git a/lib/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.cpp b/lib/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.cpp index a89c5842b9..f3db4ed7bf 100644 --- a/lib/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.cpp +++ b/lib/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.cpp @@ -5,6 +5,7 @@ #include "ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.h" #include "ttmlir/Dialect/TTNN/Analysis/DFShardingPolicy.h" #include "ttmlir/Dialect/TTNN/Analysis/L1InterleavedPolicy.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" namespace mlir::tt::ttnn { @@ -35,14 +36,15 @@ filterShardedOnly(const llvm::DenseMap> } llvm::DenseMap> -filterL1InterleavedOnly( +filterDRAMAndL1Interleaved( const llvm::DenseMap> &legalLayouts) { llvm::DenseMap> l1InterleavedLayouts; for (const auto &opLayouts : legalLayouts) { std::vector opL1InterleavedLayouts; for (const auto &layout : opLayouts.second) { - if (layout.hasInterleavedL1TensorMemoryLayout()) { + if (layout.hasDRAMBufferType() || + layout.hasInterleavedL1TensorMemoryLayout()) { opL1InterleavedLayouts.push_back(layout); } } @@ -68,7 +70,8 @@ void MemoryLayoutAnalysis::analysisImplementation() { } case MemoryLayoutAnalysisPolicyType::L1Interleaved: { L1InterleavedPolicy l1InterleavedPolicy( - op, l1ChainConfigs, filterL1InterleavedOnly(analysisInput.legalLayouts), + op, l1ChainConfigs, + filterDRAMAndL1Interleaved(analysisInput.legalLayouts), analysisResult.schedule, analysisInput.usableL1CacheSize); l1InterleavedPolicy.run(); break; diff --git a/lib/Dialect/TTNN/IR/CMakeLists.txt b/lib/Dialect/TTNN/IR/CMakeLists.txt index 1620e96b5c..2fb004e0f3 100644 --- a/lib/Dialect/TTNN/IR/CMakeLists.txt +++ b/lib/Dialect/TTNN/IR/CMakeLists.txt @@ -4,6 +4,8 @@ add_mlir_dialect_library(MLIRTTNNDialect TTNNOps.cpp TTNNOpModelInterface.cpp TTNNOpsTypes.cpp + TTNNWorkaroundInterface.cpp + TTNNWorkarounds.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/ttmlir @@ -11,10 +13,13 @@ add_mlir_dialect_library(MLIRTTNNDialect DEPENDS MLIRTTNNOpsIncGen MLIRTTOpsIncGen + MLIRTTNNWorkaroundInterfaceIncGen + TTNNOpModelLib LINK_LIBS PUBLIC TTMLIRTTNNUtils MLIRSCFToEmitC MLIRLinalgDialect MLIRMLProgramDialect + TTNNOpModelLib ) diff --git a/lib/Dialect/TTNN/IR/TTNNOpModelInterface.cpp b/lib/Dialect/TTNN/IR/TTNNOpModelInterface.cpp index 9079a60194..344a4a4831 100644 --- a/lib/Dialect/TTNN/IR/TTNNOpModelInterface.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOpModelInterface.cpp @@ -5,6 +5,9 @@ #include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOpModelInterface.cpp.inc" +#include "ttmlir/OpModel/TTNN/TTNNOpModel.h" + +#include #include namespace mlir::tt::ttnn { @@ -22,14 +25,16 @@ size_t ReluOp::getOpPerfCycles(const std::vector &input_layouts, std::tuple ReluOp::getOpL1Usage(const std::vector &input_layouts, const TTNNLayoutAttr &output_layout) { - // TODO(mbezulj) wire to tt-metal once we have API - return std::make_tuple(1024, 2048, 1024); + assert(input_layouts.size() == 1); + return op_model::ttnn::ReluOpInterface::getOpL1Usage(input_layouts[0], + output_layout); } bool ReluOp::isOpLegal(const std::vector &input_layouts, const TTNNLayoutAttr &output_layout) { - // TODO(mbezulj) wire to tt-metal once we have API - return true; + assert(input_layouts.size() == 1); + return op_model::ttnn::ReluOpInterface::isLegal(input_layouts[0], + output_layout); } } // namespace mlir::tt::ttnn diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp index 8550b8796d..cca75a7b26 100644 --- a/lib/Dialect/TTNN/IR/TTNNOps.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp @@ -42,7 +42,7 @@ ::mlir::LogicalResult mlir::tt::ttnn::ClampOp::verify() { const RankedTensorType outputTensorType = mlir::cast(outputs.front().getType()); - if (inputTensorType != outputTensorType) { + if (inputTensorType.getShape() != outputTensorType.getShape()) { return emitOpError("input and output must have same shape."); } @@ -140,6 +140,32 @@ ::mlir::LogicalResult mlir::tt::ttnn::MaxPool2dOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// ArangeOp +//===----------------------------------------------------------------------===// + +::mlir::LogicalResult mlir::tt::ttnn::ArangeOp::verify() { + + if (getStep() == 0) { + return emitOpError("Step cannot be zero."); + } + + int64_t numValues = (getEnd() - getStart()) / getStep(); + + if (numValues <= 0) { + return emitOpError("Invalid range: start=") + << getStart() << ", end=" << getEnd() << ", step=" << getStep(); + } + + std::vector expectedShape = {1, 1, 1, numValues}; + if (getType().getShape().vec() != expectedShape) { + return emitOpError() << "Output tensor shape must be " << expectedShape + << ", but got " << getType().getShape(); + } + + return success(); +} + //===----------------------------------------------------------------------===// // EmptyOp //===----------------------------------------------------------------------===// @@ -164,25 +190,12 @@ ::mlir::LogicalResult mlir::tt::ttnn::EmptyOp::verify() { // DataType and Layout // - mlir::MemRefType memref = layoutAttr.getMemref(); - Type elementType = memref.getElementType(); if (getLayout().has_value()) { - ttnn::Layout ttnnLayoutEnum; - if (llvm::isa(elementType)) { - ttnnLayoutEnum = ttnn::Layout::Tile; - } else { - ttnnLayoutEnum = ttnn::Layout::RowMajor; - } + ttnn::Layout ttnnLayoutEnum = layoutAttr.getLayout(); assert(ttnnLayoutEnum == getLayoutAttr().getValue()); } if (getDtype().has_value()) { - tt::DataType dtype; - if (llvm::isa(elementType)) { - auto tileType = mlir::cast(elementType); - dtype = tileType.getDataType(); - } else { - dtype = elementTypeToDataType(elementType); - } + tt::DataType dtype = layoutAttr.getDataType(); assert(dtype == getDtype()); } @@ -192,10 +205,11 @@ ::mlir::LogicalResult mlir::tt::ttnn::EmptyOp::verify() { // if (getMemoryConfig().has_value()) { ttnn::BufferType bufferType = layoutAttr.getBufferType(); - ttnn::TensorMemoryLayout tensorMemoryLayout = layoutAttr.getMemLayout(); + ttnn::TensorMemoryLayoutAttr tensorMemoryLayoutAttr = + layoutAttr.getMemLayout(); assert(bufferType == getMemoryConfig()->getBufferType().getValue()); - assert(tensorMemoryLayout == - getMemoryConfig()->getTensorMemoryLayout().getValue()); + assert(tensorMemoryLayoutAttr == + getMemoryConfig()->getTensorMemoryLayout()); } // // ============================== @@ -289,11 +303,6 @@ ::mlir::LogicalResult mlir::tt::ttnn::ReshapeOp::verify() { return emitOpError("Shape attribute must be non-empty"); } - // Check that the shape attribute has at most 5 elements - if (shape_size > 5) { - return emitOpError("Shape attribute must have at most 5 elements"); - } - // Cardinality of the input and output tensors must be the same if (inputType.getNumElements() != outputType.getNumElements()) { return emitOpError( @@ -534,9 +543,10 @@ ::mlir::LogicalResult mlir::tt::ttnn::EmbeddingOp::verify() { //===----------------------------------------------------------------------===// // Utility methods -static bool isValidDeviceLayout(TensorMemoryLayout layout) { - return layout == TensorMemoryLayout::Interleaved || - isShardedMemoryLayout(layout); +static bool isValidDeviceLayout(TensorMemoryLayoutAttr memLayoutAttr) { + return memLayoutAttr && + (memLayoutAttr.getValue() == TensorMemoryLayout::Interleaved || + isShardedMemoryLayout(memLayoutAttr.getValue())); } // ToMemoryConfigOp verification @@ -554,11 +564,7 @@ ::mlir::LogicalResult mlir::tt::ttnn::ToMemoryConfigOp::verify() { return emitOpError("Output tensor type missing layout attribute"); } BufferType outputBufferType = outputLayout.getBufferType(); - TensorMemoryLayout outputMemoryLayout = outputLayout.getMemLayout(); - if (isSystemBufferType(outputBufferType) && - outputMemoryLayout != TensorMemoryLayout::None) { - return emitOpError("System memory space only supports undef memory layout"); - } + TensorMemoryLayoutAttr outputMemoryLayout = outputLayout.getMemLayout(); if (isDeviceBufferType(outputBufferType) && !isValidDeviceLayout(outputMemoryLayout)) { @@ -567,7 +573,7 @@ ::mlir::LogicalResult mlir::tt::ttnn::ToMemoryConfigOp::verify() { } if (outputBufferType == BufferType::DRAM && - outputMemoryLayout != TensorMemoryLayout::Interleaved) { + outputMemoryLayout.getValue() != TensorMemoryLayout::Interleaved) { return emitOpError( "Device DRAM memory space only supports interleaved memory layout"); } @@ -581,7 +587,7 @@ ::mlir::LogicalResult mlir::tt::ttnn::ToMemoryConfigOp::verify() { if (shardShape.size() != 2) { return emitOpError("Shard shape must be 2D"); } - if (outputMemoryLayout == TensorMemoryLayout::BlockSharded) { + if (outputMemoryLayout.getValue() == TensorMemoryLayout::BlockSharded) { // TTNN tiles are (32, 32), shard shape must evenly divide the tile shape if (shardShape[0] % TILE_HEIGHT != 0 or shardShape[1] % TILE_WIDTH != 0) { return emitOpError( @@ -592,6 +598,158 @@ ::mlir::LogicalResult mlir::tt::ttnn::ToMemoryConfigOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// LinearOp +//===----------------------------------------------------------------------===// + +// LinearOp verification +::mlir::LogicalResult mlir::tt::ttnn::LinearOp::verify() { + ::mlir::RankedTensorType inputAType = getA().getType(); + ::mlir::RankedTensorType inputBType = getB().getType(); + std::optional<::mlir::RankedTensorType> biasType = + getBias() ? std::make_optional(getBias().getType()) : std::nullopt; + ::mlir::RankedTensorType outputType = getOutput().getType(); + + llvm::ArrayRef outputShape = outputType.getShape(); + llvm::SmallVector inputAShape(inputAType.getShape()); + llvm::SmallVector inputBShape(inputBType.getShape()); + + // Verify that the input A is at least 1D tensor. + if (inputAType.getRank() < 1) { + return emitOpError("Input A must be at least a 1D tensor"); + } + + // Verify that the input B is at least 1D tensor. + if (inputBType.getRank() < 1) { + return emitOpError("Input B must be at least a 1D tensor"); + } + + // If input A is a vector (1D tensor), 1 is prepended to its dimension for the + // purpose of the matrix multiplication. After the matrix multiplication, the + // prepended dimension is removed. + if (inputAType.getRank() == 1) { + inputAShape.insert(inputAShape.begin(), 1); + } + + // If input B is a vector (1D tensor), a 1 is appended to its dimension for + // the purpose of the matrix-vector product and removed afterwards. + if (inputBType.getRank() == 1) { + inputBShape.push_back(1); + } + + // Verify that the input A and input B has matching inner dimensions. + if (inputAShape[inputAShape.size() - 1] != + inputBShape[inputBShape.size() - 2]) { + return emitOpError( + "Input A[-1](" + std::to_string(inputAShape[inputAShape.size() - 1]) + + ") and B[-2](" + std::to_string(inputBShape[inputBShape.size() - 2]) + + ") must have matching inner dimensions"); + } + + llvm::SmallVector expectedOutputShape; + // Verify that the batch dimensions are broadcast compatible and construct the + // expected output shape. + if (inputAShape.size() > 2 || inputBShape.size() > 2) { + llvm::SmallVector inputABatchDims, inputBBatchDims; + + if (inputAShape.size() > 2) { + inputABatchDims.insert(inputABatchDims.begin(), inputAShape.begin(), + inputAShape.end() - 2); + } + + if (inputBShape.size() > 2) { + inputBBatchDims.insert(inputBBatchDims.begin(), inputBShape.begin(), + inputBShape.end() - 2); + } + + // Verify that the batch dimensions of input A and B are broadcast + // compatible. + llvm::SmallVector broadcastedShape; + if (!OpTrait::util::getBroadcastedShape(inputABatchDims, inputBBatchDims, + broadcastedShape)) { + + return emitOpError("Batch dimensions of input A(" + + ttmlir::utils::join(inputABatchDims, ",") + + ") and B(" + + ttmlir::utils::join(inputBBatchDims, ",") + + ") are not broadcast compatible"); + } + + // Insert the broadcasted batch dimensions in the expected output shape. + expectedOutputShape.insert(expectedOutputShape.begin(), + broadcastedShape.begin(), + broadcastedShape.end()); + } + + // Insert the input A and B inner dimensions in expected output shape. + // Consider the case where input A and B are vectors. In that case, + // the dimension 1 is ommited from the output shape. + if (inputAType.getRank() > 1) { + expectedOutputShape.push_back(inputAShape[inputAShape.size() - 2]); + } + + if (inputBType.getRank() > 1) { + expectedOutputShape.push_back(inputBShape[inputBShape.size() - 1]); + } + + if (biasType) { + // Verify that the input bias is at least 1D tensor. + if (biasType.value().getRank() < 1) { + return emitOpError("Bias must be at least a 1D tensor"); + } + + llvm::SmallVector biasShape(biasType.value().getShape()); + + // Verify that the dimensions of the matmul of A and B are broadcast + // compatible with input bias. + llvm::SmallVector matmulShape = expectedOutputShape; + if (!OpTrait::util::getBroadcastedShape(matmulShape, biasShape, + expectedOutputShape)) { + return emitOpError("Bias shape(" + ttmlir::utils::join(biasShape, ",") + + ") is not broadcast compatible with the matmul output " + "shape(" + + ttmlir::utils::join(matmulShape, ",") + ")"); + } + } + + // Check the case of a vector-vector product. At this moment we don't support + // scalars in IR, hence check that the output is at least 1D tensor of size 1. + if (expectedOutputShape.size() == 0) { + if (outputType.getRank() < 1) { + return emitOpError("Scalar output is not supported, output must be at " + "least a 1D tensor"); + } + + if (outputType.getRank() > 1 || outputType.getShape()[0] != 1) { + return emitOpError("Scalar output must be a 1D tensor of size 1"); + } + + return success(); + } + + // Verify that the output shape dimension count is correct. + if (outputShape.size() != expectedOutputShape.size()) { + return emitOpError("Output shape rank(" + + std::to_string(outputShape.size()) + + ") must match the expected output shape rank(" + + std::to_string(expectedOutputShape.size()) + ")"); + } + + // Verify each dim of the output shape. + for (size_t i = 0; i < outputShape.size(); i++) { + if (outputShape[i] != expectedOutputShape[i]) { + return emitOpError( + "Output shape dimension[" + std::to_string(i) + "](" + + std::to_string(outputShape[i]) + + ") doesn't match the expected output shape dimension[" + + std::to_string(i) + "](" + std::to_string(expectedOutputShape[i]) + + ")"); + } + } + + return success(); +} + //===----------------------------------------------------------------------===// // MatmulOp //===----------------------------------------------------------------------===// @@ -785,6 +943,10 @@ ::mlir::LogicalResult mlir::tt::ttnn::SoftmaxOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// AllGatherOp +//===----------------------------------------------------------------------===// + ::mlir::LogicalResult AllGatherOp::verify() { ::mlir::RankedTensorType inputType = getInput().getType(); int32_t dim = getDim(); @@ -796,9 +958,134 @@ ::mlir::LogicalResult AllGatherOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// ReduceScatterOp +//===----------------------------------------------------------------------===// + ::mlir::LogicalResult ReduceScatterOp::verify() { // TODO(gfengTT) return success(); } +//===----------------------------------------------------------------------===// +// UpdateCacheOp +//===----------------------------------------------------------------------===// + +::mlir::LogicalResult UpdateCacheOp::verify() { + if (getBatchOffset() != 0) { + return emitOpError( + "Only single-batch is supported. Batch offset must be 0"); + } + + const ::mlir::RankedTensorType cacheType = getCache().getType(); + const ::mlir::RankedTensorType inputType = getInput().getType(); + + const DataType cacheDataType = + elementTypeToDataType(cacheType.getElementType()); + const DataType inputDataType = + elementTypeToDataType(inputType.getElementType()); + + if (cacheDataType != inputDataType) { + return emitOpError( + "Cache and input tensors must have the same dtype. " + "Got cache dtype = " + + DataTypeEnumToString(cacheDataType) + + ", input dtype = " + DataTypeEnumToString(inputDataType)); + } + + if (cacheType.getRank() != 4) { + return emitOpError("Cache tensor must be a 4D tensor"); + } + + if (inputType.getRank() != 4) { + return emitOpError("Input tensor must be a 4D tensor"); + } + + if (inputType.getShape()[2] != 1) { + return emitOpError("Input tensor requires that dim 2 have size 1, got " + "input dim 2 size = " + + std::to_string(inputType.getShape()[2])); + } + + if (cacheType.getShape()[0] != inputType.getShape()[0] || + cacheType.getShape()[1] != inputType.getShape()[1] || + cacheType.getShape()[3] != inputType.getShape()[3]) { + return emitOpError("Cache tensor shape must match input tensor shape on " + "all dimensions except dim 2. Got cache shape (" + + std::to_string(cacheType.getShape()[0]) + ", " + + std::to_string(cacheType.getShape()[1]) + ", " + + std::to_string(cacheType.getShape()[2]) + ", " + + std::to_string(cacheType.getShape()[3]) + + "), input shape ()" + + std::to_string(inputType.getShape()[0]) + "x" + + std::to_string(inputType.getShape()[1]) + "x" + + std::to_string(inputType.getShape()[2]) + "x" + + std::to_string(inputType.getShape()[3]) + ")"); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// FillCacheOp +//===----------------------------------------------------------------------===// + +::mlir::LogicalResult FillCacheOp::verify() { + if (getBatchOffset() != 0) { + return emitOpError( + "Only single-batch is supported. Batch offset must be 0"); + } + + const ::mlir::RankedTensorType cacheType = getCache().getType(); + const ::mlir::RankedTensorType inputType = getInput().getType(); + + const DataType cacheDataType = + elementTypeToDataType(cacheType.getElementType()); + const DataType inputDataType = + elementTypeToDataType(inputType.getElementType()); + + if (cacheDataType != inputDataType) { + return emitOpError( + "Cache and input tensors must have the same dtype. " + "Got cache dtype = " + + DataTypeEnumToString(cacheDataType) + + ", input dtype = " + DataTypeEnumToString(inputDataType)); + } + + if (cacheType.getRank() != 4) { + return emitOpError("Cache tensor must be a 4D tensor"); + } + + if (inputType.getRank() != 4) { + return emitOpError("Input tensor must be a 4D tensor"); + } + + if (inputType.getShape()[2] > cacheType.getShape()[2]) { + return emitOpError( + "Input tensor requires that dim 2 have a size which is less than or " + "equal to the size of dim 2 of the cache tensor. Got cache dim 2 size " + "= " + + std::to_string(cacheType.getShape()[2]) + + ", input dim 2 size = " + std::to_string(inputType.getShape()[2])); + } + + if (cacheType.getShape()[0] != inputType.getShape()[0] || + cacheType.getShape()[1] != inputType.getShape()[1] || + cacheType.getShape()[3] != inputType.getShape()[3]) { + return emitOpError("Cache tensor shape must match input tensor shape on " + "all dimensions except dim 2. Got cache shape (" + + std::to_string(cacheType.getShape()[0]) + ", " + + std::to_string(cacheType.getShape()[1]) + ", " + + std::to_string(cacheType.getShape()[2]) + ", " + + std::to_string(cacheType.getShape()[3]) + + "), input shape (" + + std::to_string(inputType.getShape()[0]) + ", " + + std::to_string(inputType.getShape()[1]) + ", " + + std::to_string(inputType.getShape()[2]) + ", " + + std::to_string(inputType.getShape()[3]) + ")"); + } + + return success(); +} + } // namespace mlir::tt::ttnn diff --git a/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp b/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp index b18a6b1d64..43c5984ed9 100644 --- a/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp @@ -14,45 +14,58 @@ using namespace mlir::tt::ttnn; -// Check if tensor is on host -inline bool isSystemBufferType(BufferType bufferType) { - return bufferType == BufferType::SystemMemory; +// Check if the tensor is tiled +bool TTNNLayoutAttr::isTiled() const { + return ::mlir::isa<::mlir::tt::TileType>(getElementType()); } -// Check if the tensor is on device -inline bool isDeviceBufferType(BufferType bufferType) { - return bufferType == BufferType::DRAM || bufferType == BufferType::L1; +// Get layout of the tensor (RowMajor/Tile) +Layout TTNNLayoutAttr::getLayout() const { + return isTiled() ? Layout::Tile : Layout::RowMajor; } -// Check if tensor is in L1 memory -inline bool isL1BufferType(BufferType bufferType) { - return bufferType == BufferType::L1; +// Get optinoal memory layout +std::optional TTNNLayoutAttr::getMemLayoutOpt() const { + return getMemLayout() ? std::make_optional(getMemLayout().getValue()) + : std::nullopt; } -// Check if the tensor is tiled -bool TTNNLayoutAttr::isTiled() const { - return ::mlir::isa<::mlir::tt::TileType>(getElementType()); +// Check if the tensor memory buffer type is L1 +bool TTNNLayoutAttr::hasL1BufferType() const { + return isL1BufferType(getBufferType()); +} + +// Check if the tensor memory buffer type is DRAM +bool TTNNLayoutAttr::hasDRAMBufferType() const { + return isDRAMBufferType(getBufferType()); } // Check if the tensor memory layout is sharded bool TTNNLayoutAttr::hasShardedTensorMemoryLayout() const { - return (getMemLayout() == TensorMemoryLayout::HeightSharded || - getMemLayout() == TensorMemoryLayout::WidthSharded || - getMemLayout() == TensorMemoryLayout::BlockSharded); + return isDeviceBufferType() && + (getMemLayout().getValue() == TensorMemoryLayout::HeightSharded || + getMemLayout().getValue() == TensorMemoryLayout::WidthSharded || + getMemLayout().getValue() == TensorMemoryLayout::BlockSharded); } // Check if the tensor memory layout is sharded in L1 memory bool TTNNLayoutAttr::hasShardedL1TensorMemoryLayout() const { - return isL1BufferType(getBufferType()) && - (getMemLayout() == TensorMemoryLayout::HeightSharded || - getMemLayout() == TensorMemoryLayout::WidthSharded || - getMemLayout() == TensorMemoryLayout::BlockSharded); + return hasL1BufferType() && + (getMemLayout().getValue() == TensorMemoryLayout::HeightSharded || + getMemLayout().getValue() == TensorMemoryLayout::WidthSharded || + getMemLayout().getValue() == TensorMemoryLayout::BlockSharded); } // Check if the tensor memory layout is interleaved and in L1 memory bool TTNNLayoutAttr::hasInterleavedL1TensorMemoryLayout() const { - return isL1BufferType(getBufferType()) && - (getMemLayout() == TensorMemoryLayout::Interleaved); + return hasL1BufferType() && + (getMemLayout().getValue() == TensorMemoryLayout::Interleaved); +} + +// Check if the tensor memory layout is interleaved and in DRAM memory +bool TTNNLayoutAttr::hasInterleavedDRAMTensorMemoryLayout() const { + return hasDRAMBufferType() && + (getMemLayout().getValue() == TensorMemoryLayout::Interleaved); } // Get stride given tensor logical shape @@ -129,19 +142,19 @@ mlir::Type TTNNLayoutAttr::getScalarElementType() const { return elementType; } -// Extract data type from the memref. Example: -// memref<2x2xf32> -> f32 -// memref<2x2x!tt.tile<32x32xf32>> -> f32 -mlir::tt::DataType TTNNLayoutAttr::getDataTypeFromMemRef() const { +// Get scalar element type. +// Example: memref<2x2xf32> -> f32 +// Example: memref<2x2x!tt.tile<32x32xf32>> -> f32 +// +// return The scalar element type. +mlir::tt::DataType TTNNLayoutAttr::getDataType() const { Type elementType = getElementType(); - DataType dtype = DataType::Float32; - if (llvm::isa(elementType)) { + if (isTiled()) { TileType tileType = mlir::cast(elementType); - dtype = tileType.getDataType(); - } else { - dtype = elementTypeToDataType(elementType); + return tileType.getDataType(); } - return dtype; + + return elementTypeToDataType(elementType); } // Gets the size of shard in bytes @@ -149,10 +162,10 @@ mlir::tt::DataType TTNNLayoutAttr::getDataTypeFromMemRef() const { // This function returns the size of the shard in bytes. // Size is calculated by multiplying shard shape with element size. // -// /return The size of the shard in bytes. +// return The size of the shard in bytes. uint64_t TTNNLayoutAttr::getElementSizeBytes() const { mlir::Type elementType = getElementType(); - if (mlir::isa(elementType)) { + if (isTiled()) { TileType tileType = mlir::cast(elementType); return tileType.getSizeBytes(); } @@ -161,21 +174,31 @@ uint64_t TTNNLayoutAttr::getElementSizeBytes() const { // Get shard shape // -// This function returns the shape of the shard. If element type is TileType -// and convertTileToScalar is true, then the shape is converted to scalar shape. -// Example: (convertToScalar = true) memref<2x2x!tt.tile<32x32xf32>> -> {64, 64} -// Example: (convertToScalar = false) memref<2x2x!tt.tile<32x32xf32>> -> {2, 2} -// Example: memref<128x128xf32> -> {128, 128} +// Return the shape of the shard. +// Example: memref<2x2x!tt.tile<32x32xf32>> -> { 2, 2 } +// Example: memref<128x128xf32> -> { 128, 128 } +// Example: memref<2x3!tt.tile<32x32xf32>> -> { 2, 3 } // -// /param convertTileToScalar If true, convert tile shape to scalar shape. -// /return The shape of the shard. -llvm::SmallVector -TTNNLayoutAttr::getShardShape(bool convertTileToScalar) const { +// return The shape of the shard. +llvm::SmallVector TTNNLayoutAttr::getShardShape() const { + return SmallVector(getMemref().getShape()); +} + +// Get scalar shard shape +// +// If the element type is TileType, this function returns the scalar shape of +// the shard. +// Example: memref<2x2x!tt.tile<32x32xf32>> -> { 64, 64 } +// Example: memref<128x128xf32> -> { 128, 128 } +// Example: memref<2x3!tt.tile<32x32xf32>> -> { 64, 96 } +// +// return The scalar shape of the shard. +llvm::SmallVector TTNNLayoutAttr::getScalarShardShape() const { SmallVector shardShape(getMemref().getShape()); - Type elementType = getElementType(); - if (mlir::isa(elementType) && convertTileToScalar) { - return mlir::cast(elementType).getScalarShape(shardShape); + if (isTiled()) { + return mlir::cast(getElementType()).getScalarShape(shardShape); } + return shardShape; } @@ -188,8 +211,8 @@ TTNNLayoutAttr::getShardShape(bool convertTileToScalar) const { // d2) and tile shape (32, 32) The result is (90, 10) which is then divided by // tile shape (32, 32) -> (3, 1) // -// /param tensorShape The shape of the tensor -// /return The size of the tensor in tiles. +// param tensorShape The shape of the tensor +// return The size of the tensor in tiles. llvm::SmallVector TTNNLayoutAttr::getTiledShape(llvm::ArrayRef tensorShape) const { assert(isTiled() && "Expected a tiled layout"); @@ -224,10 +247,9 @@ TTNNLayoutAttr::getTiledShape(llvm::ArrayRef tensorShape) const { // Element size for TileType is tile width * tile height * sizeof(element). // For scalar types, element size is sizeof(element). // -// /return The size of the shard in bytes. +// return The size of the shard in bytes. uint64_t TTNNLayoutAttr::getShardSizeInBytes() const { - MemRefType ty = getMemref(); - ArrayRef shape = ty.getShape(); + SmallVector shape = getShardShape(); uint64_t size = getElementSizeBytes(); return std::accumulate(shape.begin(), shape.end(), size, std::multiplies()); @@ -238,7 +260,7 @@ uint64_t TTNNLayoutAttr::getShardSizeInBytes() const { // This function returns a new identity affine map // with the same number of dimensions as the linear map. // -// /return The new identity affine map. +// return The new identity affine map. mlir::AffineMap TTNNLayoutAttr::getIdentityTileLinearMap() const { assert(isTiled() && "Expected a tiled layout"); @@ -251,12 +273,11 @@ mlir::AffineMap TTNNLayoutAttr::getIdentityTileLinearMap() const { // This function takes a physical memory map and replaces the symbols with the // shard shape // -// /param physicalMemoryMap The physical memory map (d0, d1)[s0, s1] -// /return New memory map with symbols replaced with shard shape. +// param physicalMemoryMap The physical memory map (d0, d1)[s0, s1] +// return New memory map with symbols replaced with shard shape. mlir::AffineMap TTNNLayoutAttr::replaceMemoryMapSymbolsWithShardShape( AffineMap physicalMemoryMap) const { - mlir::SmallVector shardShape = - getShardShape(false /*convertTileToScalar*/); + mlir::SmallVector shardShape = getShardShape(); assert(physicalMemoryMap.getNumSymbols() == shardShape.size() && "Physical memory map must have same number of symbols as logical " "shard rank"); @@ -299,11 +320,11 @@ int64_t TTNNLayoutAttr::getTensorSizeInBytes(ArrayRef tensorShape, // This function creates a new TTNNLayoutAttr with the given parameters. // The element type, buffer type and memory layout are preserved. // -// /param context The MLIR context. -// /param tensorShape The shape of the tensor (i.e 6x10x10) -// /param grid The grid where the tensor will be placed (i.e 2x3) -// /param collapseIntervals The intervals to collapse (i.e. {{0, -1}}) -// /return The constructed TTNNLayoutAttr +// param context The MLIR context. +// param tensorShape The shape of the tensor (i.e 6x10x10) +// param grid The grid where the tensor will be placed (i.e 2x3) +// param collapseIntervals The intervals to collapse (i.e. {{0, -1}}) +// return The constructed TTNNLayoutAttr TTNNLayoutAttr TTNNLayoutAttr::withGrid( ::mlir::MLIRContext *context, ArrayRef tensorShape, GridAttr grid, ArrayRef> collapseIntervals) { @@ -317,10 +338,10 @@ TTNNLayoutAttr TTNNLayoutAttr::withGrid( // The shape of the tensor, buffer type, element type and memory layout are // preserved. // -// /param context The MLIR context. -// /param grid The grid where the tensor will be placed. -// /param collapseIntervals The intervals to collapse (i.e. {{0, -1}}) -// /return The constructed TTNNLayoutAttr +// param context The MLIR context. +// param grid The grid where the tensor will be placed. +// param collapseIntervals The intervals to collapse (i.e. {{0, -1}}) +// return The constructed TTNNLayoutAttr TTNNLayoutAttr TTNNLayoutAttr::withGrid( ::mlir::MLIRContext *context, RankedTensorType ty, GridAttr grid, ArrayRef> collapseIntervals) { @@ -334,14 +355,14 @@ TTNNLayoutAttr TTNNLayoutAttr::withGrid( // This function creates a deep copy of the current TTNNLayoutAttr and // replaces the element type with the given one. // -// /param context The MLIR context. -// /param elementType The new element type. -// /return The new TTNNLayoutAttr with the given element type. +// param context The MLIR context. +// param elementType The new element type. +// return The new TTNNLayoutAttr with the given element type. TTNNLayoutAttr TTNNLayoutAttr::withElementType(::mlir::MLIRContext *context, Type elementType) { return TTNNLayoutAttr::get( context, getLinear(), getGrid(), - buildMemRef(context, getShardShape(), + buildMemRef(context, getScalarShardShape(), elementType, getBufferType()), getMemLayout()); } @@ -351,14 +372,14 @@ TTNNLayoutAttr TTNNLayoutAttr::withElementType(::mlir::MLIRContext *context, // This function creates a deep copy of the current TTNNLayoutAttr and // replaces the memory space with the given one. // -// /param context The MLIR context. -// /param memorySpace The new memory space. -// /return The new TTNNLayoutAttr with the given memory space. +// param context The MLIR context. +// param memorySpace The new memory space. +// return The new TTNNLayoutAttr with the given memory space. TTNNLayoutAttr TTNNLayoutAttr::withBufferType(::mlir::MLIRContext *context, BufferType memorySpace) { return TTNNLayoutAttr::get( context, getLinear(), getGrid(), - buildMemRef(context, getShardShape(), + buildMemRef(context, getScalarShardShape(), getElementType(), memorySpace), getMemLayout()); } @@ -368,16 +389,33 @@ TTNNLayoutAttr TTNNLayoutAttr::withBufferType(::mlir::MLIRContext *context, // This function creates a deep copy of the current TTNNLayoutAttr and // replaces the memory layout with the given one. // -// /param context The MLIR context. -// /param memLayout The new memory layout. -// /return The new TTNNLayoutAttr with the given memory layout. -TTNNLayoutAttr TTNNLayoutAttr::withMemoryLayout(::mlir::MLIRContext *context, - TensorMemoryLayout memLayout) { +// param context The MLIR context. +// param memLayoutAttr The new memory layout. +// return The new TTNNLayoutAttr with the given memory layout. +TTNNLayoutAttr +TTNNLayoutAttr::withMemoryLayout(::mlir::MLIRContext *context, + TensorMemoryLayoutAttr memLayoutAttr) { return TTNNLayoutAttr::get( context, getLinear(), getGrid(), buildMemRef( - context, getShardShape(), getElementType(), getBufferType()), - memLayout); + context, getScalarShardShape(), getElementType(), getBufferType()), + memLayoutAttr); +} + +// Construct a new TTNNLayoutAttr +// +// This function creates a deep copy of the current TTNNLayoutAttr and +// replaces the memory layout with the given one. +// +// param context The MLIR context. +// param memLayout The new memory layout. +// return The new TTNNLayoutAttr with the given memory layout. +TTNNLayoutAttr TTNNLayoutAttr::withMemoryLayout(::mlir::MLIRContext *context, + TensorMemoryLayout memLayout) { + + TensorMemoryLayoutAttr memLayoutAttr = + TensorMemoryLayoutAttr::get(context, memLayout); + return withMemoryLayout(context, memLayoutAttr); } // Construct a new TTNNLayoutAttr @@ -385,9 +423,9 @@ TTNNLayoutAttr TTNNLayoutAttr::withMemoryLayout(::mlir::MLIRContext *context, // This function creates a deep copy of the current TTNNLayoutAttr and // replaces shard shape with the given one. // -// /param context The MLIR context. -// /param shardShape The new shard shape. -// /return The new TTNNLayoutAttr with the given shard shape. +// param context The MLIR context. +// param shardShape The new shard shape. +// return The new TTNNLayoutAttr with the given shard shape. TTNNLayoutAttr TTNNLayoutAttr::withShardShape(::mlir::MLIRContext *context, llvm::SmallVector shardShape) { @@ -402,18 +440,18 @@ TTNNLayoutAttr::withShardShape(::mlir::MLIRContext *context, // // This function constructs a new TTNNLayoutAttr with the given parameters. // -// /param context The MLIR context. -// /param tensorShape The shape of the tensor (i.e 6x10x10) -// /param elementType The type of the element i.e TileType/FloatType/IntegerType -// /param bufferType The type of the buffer -// /param grid The grid where the tensor will be placed (i.e 2x3) -// /param collapseIntervals The intervals to collapse (i.e. {{0, -1}}) -// /param memLayout The memory layout of the tensor -// /return The constructed TTNNLayoutAttr +// param context The MLIR context. +// param tensorShape The shape of the tensor (i.e 6x10x10) +// param elementType The type of the element i.e TileType/FloatType/IntegerType +// param bufferType The type of the buffer +// param grid The grid where the tensor will be placed (i.e 2x3) +// param collapseIntervals The intervals to collapse (i.e. {{0, -1}}) +// param memLayout The memory layout of the tensor +// return The constructed TTNNLayoutAttr TTNNLayoutAttr TTNNLayoutAttr::get( ::mlir::MLIRContext *context, ArrayRef tensorShape, Type elementType, BufferType bufferType, GridAttr grid, - TensorMemoryLayout memLayout, + TensorMemoryLayoutAttr memLayoutAttr, ArrayRef> collapseIntervals) { // Construct a new affine map which will be used to map from logical // space to physical space @@ -426,5 +464,35 @@ TTNNLayoutAttr TTNNLayoutAttr::get( // Build memref type with the given parameters MemRefType memRefType = buildMemRef( context, shardShape, elementType, bufferType); - return get(context, linear, grid, memRefType, memLayout); + return get(context, linear, grid, memRefType, memLayoutAttr); +} + +// Construct a new MemoryConfig +// +// This function creates a deep copy of the current MemoryConfigAttr and +// replaces the buffer type with the given one. +// +// param context The MLIR context. +// param buffer type The new buffer type. +// return The new MemoryConfigAttr with the given buffer type. +MemoryConfigAttr MemoryConfigAttr::withBufferType(::mlir::MLIRContext *context, + BufferType bufferType) { + return MemoryConfigAttr::get(context, + BufferTypeAttr::get(context, bufferType), + getShardSpec(), getTensorMemoryLayout()); +} + +// Construct a new MemoryConfig +// +// This function creates a deep copy of the current MemoryConfig and +// replaces the memory layout with the given one. +// +// param context The MLIR context. +// param memLayout The new memory layout. +// return The new MemoryConfig with the given memory layout. +MemoryConfigAttr +MemoryConfigAttr::withMemoryLayout(::mlir::MLIRContext *context, + TensorMemoryLayout memLayout) { + return MemoryConfigAttr::get(context, getBufferType(), getShardSpec(), + TensorMemoryLayoutAttr::get(context, memLayout)); } diff --git a/lib/Dialect/TTNN/IR/TTNNWorkaroundInterface.cpp b/lib/Dialect/TTNN/IR/TTNNWorkaroundInterface.cpp new file mode 100644 index 0000000000..88d49a4545 --- /dev/null +++ b/lib/Dialect/TTNN/IR/TTNNWorkaroundInterface.cpp @@ -0,0 +1,90 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 +#include "ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.h" + +#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h" +#include "ttmlir/Utils.h" + +#include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include + +namespace mlir::tt::ttnn::wa { +#include "ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.cpp.inc" + +// Verifier function for TTNN Workaround Interface. +mlir::LogicalResult verifyTTNNWorkaroundInterface(mlir::Operation *op) { + + // Verify that the number of input and output operand workarounds is the same + // as the number of tensor operands and tensor results. + size_t cntTensorInputs = + llvm::count_if(op->getOperands(), ttmlir::utils::isRankedTensor); + size_t cntTensorResults = + llvm::count_if(op->getResults(), ttmlir::utils::isRankedTensor); + + TTNNWorkaroundInterface workaroundOp = + mlir::cast(op); + + TTNNOperandsWorkarounds workarounds = workaroundOp.getOperandsWorkarounds(); + + if (workarounds.getInputOperandWorkarounds().size() != cntTensorInputs) { + return op->emitOpError() + << "Number of input operand workarounds " + << workarounds.getInputOperandWorkarounds().size() + << " does not match the number of tensor inputs " << cntTensorInputs; + } + + if (workarounds.getOutputOperandWorkarounds().size() != cntTensorResults) { + return op->emitOpError() << "Number of output operand workarounds " + << " does not match the number of tensor results " + << cntTensorResults; + } + + // For DPS ops, verify that the output workaround is the same as the input + // init workaround. + if (mlir::isa(op)) { + DestinationStyleOpInterface dpsOp = + mlir::cast(op); + + // Go through all the operands and for each DPS init operand, check if the + // output workaround is the same. + int dpsDestinationIndex = 0; + for (int64_t i = 0; i < op->getNumOperands(); i++) { + OpOperand &operand = op->getOpOperand(i); + + // Skip if the output result isn't a tensor. + if (!ttmlir::utils::isRankedTensor(operand.get())) { + dpsDestinationIndex++; + continue; + } + + // Skip if the operand is not a DPS init. + if (!dpsOp.isDpsInit(&operand)) { + dpsDestinationIndex++; + continue; + } + + // Get the tied output result for the DPS destination operand. + OpResult tiedOutputResult = dpsOp.getTiedOpResult(&operand); + + // Check if the output workaround is the same as the input DPS destination + // workaround. + if (workarounds.getOutputOperandWorkarounds()[tiedOutputResult + .getResultNumber()] != + workarounds.getInputOperandWorkarounds()[dpsDestinationIndex]) { + return op->emitOpError() + << "DPS output workaround does not match " + "the input DPS destination operand workaround " + << tiedOutputResult.getResultNumber() << " and " + << dpsDestinationIndex; + } + + dpsDestinationIndex++; + } + } + + // All checks passed, return success. + return mlir::success(); +} +} // namespace mlir::tt::ttnn::wa diff --git a/lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp b/lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp new file mode 100644 index 0000000000..0dd7eaaafd --- /dev/null +++ b/lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp @@ -0,0 +1,74 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h" + +#include "ttmlir/Utils.h" + +#include "llvm/ADT/SmallVector.h" + +namespace mlir::tt::ttnn::wa { + +// Operand workarounds factory method +TTNNOperandWorkarounds +TTNNOperandWorkarounds::createEmptyTTNNOperandWorkarounds() { + return TTNNOperandWorkarounds(); +} + +// Operands workarounds factory method +TTNNOperandsWorkarounds +TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds(int inputSize, + int outputSize) { + llvm::SmallVector inputOperandWorkarounds( + inputSize, TTNNOperandWorkarounds::createEmptyTTNNOperandWorkarounds()); + llvm::SmallVector outputOperandWorkarounds( + outputSize, TTNNOperandWorkarounds::createEmptyTTNNOperandWorkarounds()); + return TTNNOperandsWorkarounds(inputOperandWorkarounds, + outputOperandWorkarounds); +} + +// Method to apply tensor workarounds. If the workaround is present, it +// applies the workaround, and returns both the target workaround argument and +// a flag indicating whether the workaround was applied. +WorkaroundResult applyWorkarounds(const TTNNOperandWorkarounds &workaround, + const TTNNLayoutAttr &inputLayoutAttr) { + WorkaroundResult result; + result.targetTensorLayoutResult.first = + workaround.tensorLayoutWorkaround.value_or(inputLayoutAttr.getLayout()); + result.targetTensorLayoutResult.second = + result.targetTensorLayoutResult.first != inputLayoutAttr.getLayout(); + + result.targetTensorBufferTypeResult.first = + workaround.tensorBufferTypeWorkaround.value_or( + inputLayoutAttr.getBufferType()); + result.targetTensorBufferTypeResult.second = + result.targetTensorBufferTypeResult.first != + inputLayoutAttr.getBufferType(); + + // If the tensor memory layout workaround is present, apply it. + // Otherwise, return the input tensor memory layout, which may be + // nullopt if tensor is on host. + result.targetTensorMemoryLayoutResult.first = + workaround.tensorMemoryLayoutWorkaround.has_value() + ? workaround.tensorMemoryLayoutWorkaround + : inputLayoutAttr.getMemLayoutOpt(); + result.targetTensorMemoryLayoutResult.second = + result.targetTensorMemoryLayoutResult.first != + inputLayoutAttr.getMemLayoutOpt(); + + return result; +} + +// Operands workarounds factory method. +TTNNOperandsWorkarounds +TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds(Operation *op) { + size_t tensorInputs = + llvm::count_if(op->getOperands(), ttmlir::utils::isRankedTensor); + size_t tensorResults = + llvm::count_if(op->getResults(), ttmlir::utils::isRankedTensor); + + return TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds( + tensorInputs, tensorResults); +} +} // namespace mlir::tt::ttnn::wa diff --git a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp index a49292ed93..a125942f48 100644 --- a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp +++ b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp @@ -63,6 +63,14 @@ void createTTNNPipelineLoweringPasses( pm.addPass(mlir::createRemoveDeadValuesPass()); } +// Create a pass to workaround issues in the TTNN dialect. +void createTTNNPipelineWorkaroundPass( + OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) { + if (options.workaroundPassEnabled) { + pm.addPass(createTTNNWorkarounds()); + } +} + void createTTNNPipelineLayoutDecompositionPass( OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) { pm.addPass(createTTNNDecomposeLayouts()); @@ -108,10 +116,24 @@ void createTTNNPipelineDeallocPassFromString(OpPassManager &pm, createTTNNPipelineDeallocPass(pm, *optionsStruct); } +void createTTNNPipelineTTIRBroadcastFoldPass( + OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) { + pm.addPass(mlir::tt::ttir::createTTIRBroadcastFold()); +} + +void createTTNNPipelineTTIRBroadcastFoldPassFromString(OpPassManager &pm, + std::string options) { + auto optionsStruct = + TTIRToTTNNBackendPipelineOptions::createFromString(options); + createTTNNPipelineTTIRBroadcastFoldPass(pm, *optionsStruct); +} + void createTTIRToTTNNBackendPipeline( OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) { createTTNNPipelineTTIRPasses(pm, options); + createTTNNPipelineTTIRBroadcastFoldPass(pm, options); createTTNNPipelineLoweringPasses(pm, options); + createTTNNPipelineWorkaroundPass(pm, options); createTTNNPipelineAnalysisPasses(pm, options); createTTNNPipelineLayoutDecompositionPass(pm, options); createTTNNPipelineDeallocPass(pm, options); diff --git a/lib/Dialect/TTNN/Transforms/CMakeLists.txt b/lib/Dialect/TTNN/Transforms/CMakeLists.txt index 3f075148b0..fd21e03d0c 100644 --- a/lib/Dialect/TTNN/Transforms/CMakeLists.txt +++ b/lib/Dialect/TTNN/Transforms/CMakeLists.txt @@ -1,8 +1,9 @@ add_mlir_dialect_library(MLIRTTNNTransforms - TTNNLayout.cpp - Passes.cpp Optimizer.cpp + Passes.cpp + TTNNLayout.cpp TTNNToCpp.cpp + TTNNWorkarounds.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/ttmlir diff --git a/lib/Dialect/TTNN/Transforms/Optimizer.cpp b/lib/Dialect/TTNN/Transforms/Optimizer.cpp index ad259773fd..9ada2dbb5d 100644 --- a/lib/Dialect/TTNN/Transforms/Optimizer.cpp +++ b/lib/Dialect/TTNN/Transforms/Optimizer.cpp @@ -176,6 +176,10 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase { return; } + if (llvm::isa(op)) { + return; + } + RankedTensorType tensorType = mlir::cast(op->getResult(0).getType()); LegalLayoutAnalysis legalLayoutAnalysis = @@ -276,13 +280,14 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase { // if (isa(op)) { BufferType bufferType = layoutAttr.getBufferType(); - TensorMemoryLayout tensorMemoryLayout = layoutAttr.getMemLayout(); + TensorMemoryLayoutAttr tensorMemoryLayoutAttr = + layoutAttr.getMemLayout(); op->getOperands().back().setType(newTensorType); EmptyOp emptyOp = mlir::cast(op->getOperands().back().getDefiningOp()); - emptyOp.setDtype(layoutAttr.getDataTypeFromMemRef()); + emptyOp.setDtype(layoutAttr.getDataType()); if (layoutAttr.isTiled()) { emptyOp.setLayout(ttnn::Layout::Tile); } else { @@ -290,13 +295,12 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase { } emptyOp.setMemoryConfigAttr(ttnn::MemoryConfigAttr::get( op->getContext(), - TensorMemoryLayoutAttr::get(op->getContext(), - tensorMemoryLayout), BufferTypeAttr::get(op->getContext(), bufferType), ShardSpecAttr::get( op->getContext(), ShapeAttr::get(op->getContext(), - layoutAttr.getMemref().getShape())))); + layoutAttr.getMemref().getShape())), + tensorMemoryLayoutAttr)); } // TODO(mtopalovic): Temp workaround for generic ToLayoutOp. Allign // MemoryConfigAttr with layout attribute of its output tensor. This @@ -305,19 +309,19 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase { // else if (isa(op)) { BufferType bufferType = layoutAttr.getBufferType(); - TensorMemoryLayout tensorMemoryLayout = layoutAttr.getMemLayout(); + TensorMemoryLayoutAttr tensorMemoryLayoutAttr = + layoutAttr.getMemLayout(); // Update the device op with the new tensor type. // ttnn::ToLayoutOp toLayoutOp = llvm::cast(op); toLayoutOp.setMemoryConfigAttr(ttnn::MemoryConfigAttr::get( op->getContext(), - ttnn::TensorMemoryLayoutAttr::get(op->getContext(), - tensorMemoryLayout), ttnn::BufferTypeAttr::get(op->getContext(), bufferType), ttnn::ShardSpecAttr::get( op->getContext(), ttnn::ShapeAttr::get(op->getContext(), - layoutAttr.getMemref().getShape())))); + layoutAttr.getMemref().getShape())), + tensorMemoryLayoutAttr)); } } }); @@ -453,18 +457,18 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase { consumerOpOutputLayout.getGrid())); BufferType outputBufferType = consumerOpOutputLayout.getBufferType(); - TensorMemoryLayout outputTensorMemoryLayout = + TensorMemoryLayoutAttr outputTensorMemoryLayoutAttr = consumerOpOutputLayout.getMemLayout(); - MemRefType outputMemref = consumerOpOutputLayout.getMemref(); + llvm::SmallVector shardShape = + consumerOpOutputLayout.getShardShape(); MemoryConfigAttr outputMemConfigAttr = MemoryConfigAttr::get( consumerOp->getContext(), - TensorMemoryLayoutAttr::get(consumerOp->getContext(), - outputTensorMemoryLayout), BufferTypeAttr::get(consumerOp->getContext(), outputBufferType), - ShardSpecAttr::get(consumerOp->getContext(), - ShapeAttr::get(consumerOp->getContext(), - outputMemref.getShape()))); + ShardSpecAttr::get( + consumerOp->getContext(), + ShapeAttr::get(consumerOp->getContext(), shardShape)), + outputTensorMemoryLayoutAttr); // If producerOp is a toLayoutOp, adjust its output layout(update // inplace) to reflect consumerOp's output layout. If producerOp is not a @@ -478,10 +482,9 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase { } else { OpBuilder builder(consumerOp); - DataTypeAttr outputDataType = - DataTypeAttr::get(consumerOp->getContext(), - utils::getDataTypeFromMemRef(outputMemref)); - Layout outputLayoutEnum = utils::getLayoutFromMemRef(outputMemref); + DataTypeAttr outputDataType = DataTypeAttr::get( + consumerOp->getContext(), consumerOpOutputLayout.getDataType()); + Layout outputLayoutEnum = consumerOpOutputLayout.getLayout(); LayoutAttr outputLayout = LayoutAttr::get(consumerOp->getContext(), outputLayoutEnum); Operation *memoryReconfigOp = builder.create( diff --git a/lib/Dialect/TTNN/Transforms/Passes.cpp b/lib/Dialect/TTNN/Transforms/Passes.cpp index 79bfeb4049..c842c4075b 100644 --- a/lib/Dialect/TTNN/Transforms/Passes.cpp +++ b/lib/Dialect/TTNN/Transforms/Passes.cpp @@ -3,15 +3,27 @@ // SPDX-License-Identifier: Apache-2.0 #include "ttmlir/Dialect/TTNN/Transforms/Passes.h" + +#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h" +#include "ttmlir/Dialect/TTNN/Utils/Utils.h" + #include "mlir/Analysis/Liveness.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" -#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h" -#include "ttmlir/Dialect/TTNN/Utils/Utils.h" +#include "mlir/IR/TypeRange.h" +#include "mlir/IR/ValueRange.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" namespace mlir::tt::ttnn { #define GEN_PASS_DEF_TTNNDEALLOCATE #define GEN_PASS_DEF_TTNNDECOMPOSELAYOUTS +#define GEN_PASS_DEF_TTNNCREATEINPUTGENERATORS #include "ttmlir/Dialect/TTNN/Transforms/Passes.h.inc" class TTNNDeallocate : public impl::TTNNDeallocateBase { @@ -130,16 +142,15 @@ class TTNNDecomposeLayouts ttnn::BufferType bufferType; ttnn::Layout layoutEnum; DataType dataType; - ttnn::TensorMemoryLayout tensorMemoryLayout; + ttnn::TensorMemoryLayoutAttr tensorMemoryLayout; llvm::ArrayRef shardShape; ttnn::MemoryConfigAttr createMemoryConfigAttr(MLIRContext *context) const { return ttnn::MemoryConfigAttr::get( - context, - ttnn::TensorMemoryLayoutAttr::get(context, tensorMemoryLayout), - ttnn::BufferTypeAttr::get(context, bufferType), + context, ttnn::BufferTypeAttr::get(context, bufferType), ttnn::ShardSpecAttr::get(context, - ttnn::ShapeAttr::get(context, shardShape))); + ttnn::ShapeAttr::get(context, shardShape)), + tensorMemoryLayout); } bool isOnHost() const { @@ -198,24 +209,12 @@ class TTNNDecomposeLayouts } }; - ttnn::Layout getLayoutFromMemRef(mlir::MemRefType memref) const { - ttnn::Layout ttnnLayoutEnum = ttnn::Layout::RowMajor; - Type elementType = memref.getElementType(); - if (llvm::isa(elementType)) { - ttnnLayoutEnum = ttnn::Layout::Tile; - } else { - ttnnLayoutEnum = ttnn::Layout::RowMajor; - } - return ttnnLayoutEnum; - } - std::pair getInputOutputLayouts(ttnn::ToLayoutOp op) const { LayoutInfo input, output; auto inputLayoutAttr = mlir::cast(op.getInput().getType().getEncoding()); - auto inputMemref = inputLayoutAttr.getMemref(); assert(op.getMemoryConfig().has_value()); MemoryConfigAttr outputMemoryConfig = op.getMemoryConfig().value(); @@ -223,18 +222,17 @@ class TTNNDecomposeLayouts input.bufferType = inputLayoutAttr.getBufferType(); output.bufferType = outputMemoryConfig.getBufferType().getValue(); - input.layoutEnum = getLayoutFromMemRef(inputMemref); + input.layoutEnum = inputLayoutAttr.getLayout(); output.layoutEnum = op.getLayout(); - input.dataType = ttnn::utils::getDataTypeFromMemRef(inputMemref); + input.dataType = inputLayoutAttr.getDataType(); assert(op.getDtype().has_value()); output.dataType = op.getDtype().value(); input.tensorMemoryLayout = inputLayoutAttr.getMemLayout(); - output.tensorMemoryLayout = - outputMemoryConfig.getTensorMemoryLayout().getValue(); + output.tensorMemoryLayout = outputMemoryConfig.getTensorMemoryLayout(); - input.shardShape = inputMemref.getShape(); + input.shardShape = inputLayoutAttr.getShardShape(); output.shardShape = outputMemoryConfig.getShardShapeArray(); return {input, output}; } @@ -263,8 +261,8 @@ class TTNNDecomposeLayouts // device tensor if (not opsToCreate.createToDeviceOp and output.isOnDevice()) { opsToCreate.createToMemoryConfigOp = - (input.tensorMemoryLayout != output.tensorMemoryLayout) and - (output.tensorMemoryLayout != ttnn::TensorMemoryLayout::None); + output.tensorMemoryLayout && + (input.tensorMemoryLayout != output.tensorMemoryLayout); opsToCreate.createToMemoryConfigOp |= (input.bufferType == ttnn::BufferType::DRAM and output.bufferType == ttnn::BufferType::L1) or @@ -887,4 +885,183 @@ class TTNNDecomposeLayouts } }; +class TTNNCreateInputGenerators + : public impl::TTNNCreateInputGeneratorsBase { + +public: + using impl::TTNNCreateInputGeneratorsBase< + TTNNCreateInputGenerators>::TTNNCreateInputGeneratorsBase; + + void runOnOperation() final { + ModuleOp module = getOperation(); + IRRewriter rewriter(&getContext()); + + // Ensure that the module has a single region and a single block within that + // region + assert(module->getRegions().size() == 1); + assert(module->getRegion(0).getBlocks().size() == 1); + + // Get the first block of the region at index 0 + // + Block *firstBlock = module.getBody(0); + + // Find all the func.func ops in the module + // + SmallVector forwardFuncOps; + for (mlir::Operation &op : firstBlock->getOperations()) { + if (mlir::func::FuncOp funcOp = dyn_cast(op)) { + + // Skip functions that are called elsewhere in the IR + // + // This will skip utility functions that are used by other functions, + // only top-level "forward" functions should be considered + // + if (!funcOp->getUses().empty()) { + continue; + } + + forwardFuncOps.push_back(funcOp); + } + } + + // Iterate over all the func ops and add input tensor generator functions + // + for (mlir::func::FuncOp forwardFuncOp : forwardFuncOps) { + // Get all the input tensors for the current forward func + // + llvm::SmallVector inputTensors; + for (auto input : forwardFuncOp.getFunctionType().getInputs()) { + inputTensors.push_back(llvm::cast(input)); + } + + // Create a new function that will generate the input tensors + // + std::string inputGenFuncName = + "createInputsFor_" + forwardFuncOp.getName().str(); + + // Create function type + // + mlir::TypeRange returnTypeRange = + mlir::TypeRange(forwardFuncOp.getFunctionType().getInputs()); + FunctionType functionType = + mlir::FunctionType::get(&getContext(), {}, returnTypeRange); + + // Set insertion point to end of first block + // + rewriter.setInsertionPointToEnd(firstBlock); + + // Create the function + // + func::FuncOp inputGenFuncOp = rewriter.create( + module->getLoc(), inputGenFuncName, functionType); + + // Add a Block to func op and set insertion point to the beginning of the + // Block + // + ::mlir::Block *currFnBlock = inputGenFuncOp.addEntryBlock(); + rewriter.setInsertionPointToStart(currFnBlock); + + // Create the input tensors + // + SmallVector generatedTensors; + for (Type tensorType : returnTypeRange) { + assert(llvm::isa(tensorType)); + + RankedTensorType tensor = + llvm::cast(tensorType); + + // Get the layout attribute + // + ttnn::TTNNLayoutAttr layoutAttr = + mlir::cast(tensor.getEncoding()); + + // Get the shape of the tensor, tensor layout, and data type + // + ShapeAttr shapeAttr = + ttnn::ShapeAttr::get(&getContext(), tensor.getShape()); + ttnn::LayoutAttr tensorLayoutAttr = + ttnn::LayoutAttr::get(&getContext(), layoutAttr.getLayout()); + DataTypeAttr dTypeAttr = + DataTypeAttr::get(&getContext(), layoutAttr.getDataType()); + + // Create a new tensor + // + // TODO(svuckovic): Move from ttnn::EmptyOp to ttnn::OnesOp once #1476 + // lands + // + mlir::Value tensorValue = rewriter.create( + forwardFuncOp->getLoc(), tensorType, nullptr, shapeAttr, dTypeAttr, + tensorLayoutAttr, nullptr); + + generatedTensors.push_back(tensorValue); + } + + // Return the generated tensors + // + rewriter.create(forwardFuncOp->getLoc(), + generatedTensors); + } + + // Create a main function to call input generators and forward funcs + // + { + // Create a new function that will generate the input tensors + // + std::string mainFuncName = "main"; + + // Create function type + // + mlir::TypeRange returnTypeRange = mlir::TypeRange(rewriter.getI32Type()); + FunctionType functionType = + mlir::FunctionType::get(&getContext(), {}, returnTypeRange); + + // Set insertion point to end of first block + // + rewriter.setInsertionPointToEnd(firstBlock); + + // Create the function + // + func::FuncOp mainFuncOp = rewriter.create( + module->getLoc(), mainFuncName, functionType); + + ::mlir::Block *currFnBlock = mainFuncOp.addEntryBlock(); + + // Set insertion point to the beginning of the block + // + rewriter.setInsertionPointToStart(currFnBlock); + + // Call the input generators + // + for (mlir::func::FuncOp forwardFuncOp : forwardFuncOps) { + std::string inputGenFuncName = + "createInputsFor_" + forwardFuncOp.getName().str(); + + // Get the input generator function + // + mlir::func::FuncOp inputGenFuncOp = + module.lookupSymbol(inputGenFuncName); + + // Call the input generator function + // + func::CallOp createdTensors = rewriter.create( + forwardFuncOp->getLoc(), inputGenFuncOp, ValueRange()); + + rewriter.create(forwardFuncOp->getLoc(), + forwardFuncOp, + createdTensors->getResults()); + } + + // Return 0 + // + // func::ReturnOp requires a Value to be returned, which means that an SSA + // needs to be returned, hence create a constant 0 via arith::ConstantOp + // + Value constantZero = rewriter.create( + rewriter.getUnknownLoc(), rewriter.getI32Type(), + rewriter.getI32IntegerAttr(0)); + rewriter.create(mainFuncOp->getLoc(), constantZero); + } + } +}; + } // namespace mlir::tt::ttnn diff --git a/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp b/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp index eebfdc13f3..712e12ad08 100644 --- a/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp +++ b/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp @@ -75,7 +75,7 @@ class TTNNLayoutTensorTypeConverter : public TypeConverter { TTNNLayoutAttr newLayout = TTNNLayoutAttr::get( ctx, type.getShape(), type.getElementType(), g_defaultMemorySpaceHost, - tensorGrid, TensorMemoryLayout::None, collapseDimsRef); + tensorGrid, nullptr /* memLayoutAttr */, collapseDimsRef); return RankedTensorType::get(type.getShape(), type.getElementType(), newLayout); }); @@ -154,23 +154,22 @@ class TTNNLayoutTensorTypeRewriter : public RewritePattern { static std::optional createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input, BufferType desiredBufferType, - TensorMemoryLayout desiredMemLayout, bool tiled) { + TensorMemoryLayoutAttr desiredMemLayoutAttr, bool tiled) { // Get type RankedTensorType ty = mlir::cast(input.getType()); // Get ttnn layout from the type - TTNNLayoutAttr tensorConfig = mlir::cast(ty.getEncoding()); + TTNNLayoutAttr ttnnLayoutAttr = mlir::cast(ty.getEncoding()); // Get buffer type (i.e DRAM/L1 etc) - BufferType currBufferType = tensorConfig.getBufferType(); + BufferType currBufferType = ttnnLayoutAttr.getBufferType(); // Get the current element type (i.e bf16/TileType etc) - Type currElementType = tensorConfig.getElementType(); + Type currElementType = ttnnLayoutAttr.getElementType(); - // Get the mem layout attribute (i.e interleaved/sharded or null in case of - // System) - TensorMemoryLayout currMemLayout = tensorConfig.getMemLayout(); + // Get mem layout. If the tensor is on host layout is null + TensorMemoryLayoutAttr currMemLayout = ttnnLayoutAttr.getMemLayout(); // Get element type that should be used in the new ttnn layout Type desiredElementType = @@ -181,7 +180,7 @@ createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input, // the desired ones, we don't need to do anything if (currBufferType == desiredBufferType && currElementType == desiredElementType && - currMemLayout == desiredMemLayout) { + currMemLayout == desiredMemLayoutAttr) { return std::nullopt; } @@ -189,7 +188,7 @@ createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input, // memory layout TTNNLayoutAttr desiredLayout = rewriter.getAttr( ty.getShape(), desiredElementType, desiredBufferType, - tensorConfig.getGrid(), desiredMemLayout, g_defaultCollapseDims); + ttnnLayoutAttr.getGrid(), desiredMemLayoutAttr, g_defaultCollapseDims); // If the input tensor is a constant or empty tensor, we can replace it with a // new tensor with the desired layout @@ -214,6 +213,28 @@ createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input, .getResult(); } + // If the input tensor is an arange, we want to set the desired layout just + // like the other creation ops. However, a caveat is that in ttnn, arange is + // hardcoded to be ROW_MAJOR. So we must ensure that the layout we assign to + // it is ROW_MAJOR - and to make it tile layout we still must insert + // ToLayoutOp on its output. We can do this by setting the element type to + // ty.getElementType() in case desiredElementType is a TileType. + ttir::ArangeOp existingArange = input.getDefiningOp(); + if (existingArange) { + TTNNLayoutAttr arangeLayout = rewriter.getAttr( + ty.getShape(), ty.getElementType(), desiredBufferType, + ttnnLayoutAttr.getGrid(), desiredMemLayoutAttr, g_defaultCollapseDims); + input = + rewriter + .replaceOpWithNewOp( + existingArange, + mlir::RankedTensorType::get(ty.getShape(), ty.getElementType(), + arangeLayout), + existingArange.getStart(), existingArange.getEnd(), + existingArange.getStep(), existingArange.getArangeDimension()) + .getResult(); + } + // If the input tensor is not a constant or empty tensor, we need to create a // new tensor with the desired layout which will be used as the output of the // ToLayoutOp @@ -242,15 +263,34 @@ createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input, utils::toTTTensorMemoryLayout(g_defaultMemoryLayout); tt::TensorMemoryLayout desiredMemoryLayout = getLegalTensorMemoryLayout( operandConstraint, desiredMemorySpace, ttMemoryLayout); - TensorMemoryLayout ttnnMemoryLayout = - utils::toTTNNTensorMemoryLayout(desiredMemoryLayout); + TensorMemoryLayoutAttr ttnnMemoryLayoutAttr; + if (desiredMemoryLayout != tt::TensorMemoryLayout::None) { + TensorMemoryLayout ttnnMemoryLayout = + utils::toTTNNTensorMemoryLayout(desiredMemoryLayout); + ttnnMemoryLayoutAttr = + TensorMemoryLayoutAttr::get(rewriter.getContext(), ttnnMemoryLayout); + } // Check if the tensor should be tiled bool tiled = !bitEnumContainsAny(operandConstraint, OperandConstraint::Scalar); return createToLayoutOp(rewriter, loc, input, desiredBufferType, - ttnnMemoryLayout, tiled); + ttnnMemoryLayoutAttr, tiled); +} + +static bool changeLayoutToHost(DestinationStyleOpInterface &op, + OpOperand &operand, PatternRewriter &rewriter) { + Location newLoc = appendInputSuffix(op.getLoc(), operand.getOperandNumber()); + std::optional layout = + createToLayoutOp(rewriter, newLoc, operand.get(), + BufferType::SystemMemory, nullptr, false /* tiled */); + if (layout.has_value()) { + rewriter.modifyOpInPlace( + op, [&]() { op->setOperand(operand.getOperandNumber(), *layout); }); + return true; + } + return false; } // Updates the layout of the operands of a TTIR ops which have DPS operands. @@ -278,6 +318,19 @@ class TTNNLayoutDPSOperandsRewriter // TTNN Conv2d moves input, weight, and bias from host to device // itself. Inserting the ToLayoutOp on these operands is thus problematic. if (mlir::isa(op.getOperation()) && !isResult) { + // For the weight input of the conv2d op, it specifically needs to be on + // host, so we create a host to layout op (issue + // https://github.com/tenstorrent/tt-mlir/issues/1528). + if (operand.getOperandNumber() == 1) { + modified = changeLayoutToHost(op, operand, rewriter); + } + continue; + } + + // If the operand is a BroadcastOp or a ToLayout op do not put a + // ToLayoutOp on its output + if (operand.get().getDefiningOp() || + operand.get().getDefiningOp()) { continue; } @@ -326,7 +379,7 @@ class TTNNLayoutFuncReturnRewriter appendInputSuffix(op.getLoc(), operand.getOperandNumber()); std::optional layout = createToLayoutOp( rewriter, newLoc, operand.get(), BufferType::SystemMemory, - TensorMemoryLayout::None, false /* tiled */); + nullptr /* tensorMemoryLayoutAttr */, false /* tiled */); if (layout.has_value()) { rewriter.modifyOpInPlace( op, [&]() { op.setOperand(operand.getOperandNumber(), *layout); }); diff --git a/lib/Dialect/TTNN/Transforms/TTNNWorkarounds.cpp b/lib/Dialect/TTNN/Transforms/TTNNWorkarounds.cpp new file mode 100644 index 0000000000..bba5d0bcd9 --- /dev/null +++ b/lib/Dialect/TTNN/Transforms/TTNNWorkarounds.cpp @@ -0,0 +1,413 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttmlir/Dialect/TTNN/Transforms/Passes.h" + +#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h" +#include "ttmlir/Dialect/TTNN/Types/Types.h" +#include "ttmlir/Dialect/TTNN/Utils/TransformUtils.h" +#include "ttmlir/Dialect/TTNN/Utils/Utils.h" + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "ttmlir/Utils.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include + +namespace mlir::tt::ttnn { +#define GEN_PASS_DEF_TTNNWORKAROUNDS +#include "ttmlir/Dialect/TTNN/Transforms/Passes.h.inc" + +// Helper method to get the tensor layout attribute from the op operand. +static TTNNLayoutAttr getLayoutAttrFromOpOperand(OpOperand &opOperand) { + auto tensorType = mlir::cast(opOperand.get().getType()); + return mlir::cast(tensorType.getEncoding()); +} + +// Helper method to get the tensor layout attribute from the op result. +static TTNNLayoutAttr getLayoutAttrFromOpResult(OpResult &opResult) { + auto tensorType = mlir::cast(opResult.getType()); + return mlir::cast(tensorType.getEncoding()); +} + +// Helper method to get the element type for the given tensor layout and data. +static Type getElementType(MLIRContext *context, Layout tensorLayout, + DataType dataType) { + return tensorLayout == Layout::Tile + ? TileType::get(context, {ttnn::TILE_HEIGHT, ttnn::TILE_WIDTH}, + dataType) + : ttnn::utils::createRowMajorTypeFromDtype(context, dataType); +} + +// Helper method to insert a ToLayoutOp to convert the input operand to the +// desired tensor layout, buffer type and memory layout. +static mlir::Value +createToLayoutOp(wa::TTNNWorkaroundInterface &op, OpOperand &inputOperand, + PatternRewriter &rewriter, Layout targetTensorLayout, + BufferType targetTensorBufferType, + std::optional targetTensorMemoryLayout) { + TTNNLayoutAttr inputLayoutAttr = getLayoutAttrFromOpOperand(inputOperand); + + // Create element type based on tensor layout. + Type elementType = getElementType(rewriter.getContext(), targetTensorLayout, + inputLayoutAttr.getDataType()); + + // Create tensor memory layout attribute. + ttnn::TensorMemoryLayoutAttr outputMemLayoutAttr = + targetTensorMemoryLayout.has_value() + ? ttnn::TensorMemoryLayoutAttr::get(rewriter.getContext(), + targetTensorMemoryLayout.value()) + : nullptr; + + // Create the output memory config attribute. + ttnn::MemoryConfigAttr outputMemConfigAttr = ttnn::MemoryConfigAttr::get( + rewriter.getContext(), + ttnn::BufferTypeAttr::get(rewriter.getContext(), targetTensorBufferType), + ttnn::ShardSpecAttr::get( + op.getContext(), + ttnn::ShapeAttr::get(rewriter.getContext(), + inputLayoutAttr.getMemref().getShape())), + outputMemLayoutAttr); + + // Get the input operand type. + RankedTensorType inputOperandType = + mlir::cast(inputOperand.get().getType()); + + // Create a ToLayoutOp to convert the input operand to the desired + // tensor layout, buffer type and memory layout. + return rewriter + .create( + op.getLoc(), + ttnn::utils::createRankedTensorTypeWithEncoding( + inputOperandType, + inputLayoutAttr + .withElementType(rewriter.getContext(), elementType) + .withBufferType(rewriter.getContext(), targetTensorBufferType) + .withMemoryLayout(rewriter.getContext(), + outputMemLayoutAttr)), + inputOperand.get(), + LayoutAttr::get(rewriter.getContext(), targetTensorLayout), + DataTypeAttr::get(rewriter.getContext(), + inputLayoutAttr.getDataType()), + outputMemConfigAttr, + (targetTensorBufferType == ttnn::BufferType::SystemMemory) + ? nullptr + : utils::getOrInsertDevice(rewriter, op)) + ->getResult(0); +} + +// Helper method to apply workarounds to an input operand. This method inserts a +// ToLayoutOp with the specified tensor layout, buffer type, and memory layout. +// It returns true if the workarounds were successfully applied. +static bool workaroundInputOperand( + OpOperand &inputOperand, const wa::TTNNOperandWorkarounds &inputWorkaround, + PatternRewriter &rewriter, wa::TTNNWorkaroundInterface op) { + // Get the current input tensor layout, buffer type and memory layout from the + // input operand. + TTNNLayoutAttr inputLayoutAttr = getLayoutAttrFromOpOperand(inputOperand); + + // Apply the workarounds on the input operand workaround arguments + wa::WorkaroundResult inputWorkaroundResult = + applyWorkarounds(inputWorkaround, inputLayoutAttr); + + // If there were no modifications by workarounds, return false. + if (!inputWorkaroundResult.modified()) { + return false; + } + + // Apply the workarounds on the input operand by inserting the ToLayoutOp with + // the desired tensor layout, buffer type and memory layout. + mlir::Value insertedToLayoutOpValue = createToLayoutOp( + op, inputOperand, rewriter, + inputWorkaroundResult.targetTensorLayoutResult.first, + inputWorkaroundResult.targetTensorBufferTypeResult.first, + inputWorkaroundResult.targetTensorMemoryLayoutResult.first); + + // Insert to layout op between the current op and the input operand + // to convert the input operand to the desired tensor layout, buffer type. + rewriter.modifyOpInPlace(op, [&]() { + // Update the input operand with the new toLayout op operand. + op->setOperand(inputOperand.getOperandNumber(), insertedToLayoutOpValue); + }); + + return true; +} + +// Helper method to apply workarounds to output results. +// - For DPS results, this method only verifies that the output result matches +// the +// corresponding DPS destination operand. At this stage, DPS results should +// already be propagated. +// - For non-DPS operations, this method applies the necessary workarounds to +// the +// output result and returns true if the workarounds were successfully +// applied. +static bool workaroundOutputOperand( + OpResult &opResult, const wa::TTNNOperandWorkarounds &outputWorkaround, + PatternRewriter &rewriter, wa::TTNNWorkaroundInterface op) { + // Get the current output tensor layout, buffer type and memory layout from + // the input operand. + TTNNLayoutAttr opResultLayoutAttr = getLayoutAttrFromOpResult(opResult); + + // Apply the workarounds on the output result workaround arguments + wa::WorkaroundResult outputWorkaroundResult = + wa::applyWorkarounds(outputWorkaround, opResultLayoutAttr); + + // At this point, the DPS result should already be propagated, hence we only + // need to verify that the output workaround is in sync with the current DPS + // result. + assert(!(outputWorkaroundResult.modified() && + mlir::isa(op.getOperation())) && + "Output operand workarounds not supported for DPS ops"); + + // If there were no modifications by workarounds, return false. + if (!outputWorkaroundResult.modified()) { + return false; + } + + // Create the data type attribute. + Type elementType = + getElementType(rewriter.getContext(), + outputWorkaroundResult.targetTensorLayoutResult.first, + opResultLayoutAttr.getDataType()); + + // Get the input operand type. + RankedTensorType opResultType = + mlir::cast(opResult.getType()); + + // Create tensor memory layout attribute. + TensorMemoryLayoutAttr outputMemLayoutAttr = + outputWorkaroundResult.targetTensorMemoryLayoutResult.first.has_value() + ? ttnn::TensorMemoryLayoutAttr::get( + rewriter.getContext(), + outputWorkaroundResult.targetTensorMemoryLayoutResult.first + .value()) + : nullptr; + + // Create the new output result type with the updated tensor layout, buffer + // type and memory layout. + RankedTensorType newOutputResultType = + ttnn::utils::createRankedTensorTypeWithEncoding( + opResultType, + opResultLayoutAttr.withElementType(rewriter.getContext(), elementType) + .withBufferType( + rewriter.getContext(), + outputWorkaroundResult.targetTensorBufferTypeResult.first) + .withMemoryLayout(rewriter.getContext(), outputMemLayoutAttr)); + + // Update the type of result with applied workarounds. + rewriter.modifyOpInPlace(op, [&]() { + opResult.setType(newOutputResultType); + + // Some ops defines attributes with tensor layout, buffer type and memory + // layout, hence we need to update the attributes as well. For example, + // the empty op defines layout and memory_config attributes. + if (outputWorkaroundResult.targetTensorLayoutResult.second && + op->getAttrDictionary().get("layout")) { + LayoutAttr updatedLayoutAttr = rewriter.getAttr( + outputWorkaroundResult.targetTensorLayoutResult.first); + op->setAttr("layout", updatedLayoutAttr); + } + + if ((outputWorkaroundResult.targetTensorBufferTypeResult.second || + outputWorkaroundResult.targetTensorMemoryLayoutResult.second) && + op->getAttrDictionary().get("memory_config")) { + + MemoryConfigAttr currentMemoryConfig = + mlir::cast(op->getAttr("memory_config")); + + // Create the output memory config attribute. + // Check if the buffer type got updated. + if (outputWorkaroundResult.targetTensorBufferTypeResult.second) { + currentMemoryConfig = currentMemoryConfig.withBufferType( + rewriter.getContext(), + outputWorkaroundResult.targetTensorBufferTypeResult.first); + } + + // Check if the memory layout got updated. + if (outputWorkaroundResult.targetTensorMemoryLayoutResult.second) { + currentMemoryConfig = currentMemoryConfig.withMemoryLayout( + rewriter.getContext(), + outputWorkaroundResult.targetTensorMemoryLayoutResult.first + .value()); + } + + // Update the changed memory config attribute. + op->setAttr("memory_config", currentMemoryConfig); + } + }); + + return true; +} + +// Propagate the workaround changes for DPS input operands if they are applied +// in above graph transforms, either in a pattern for a current op, or in a +// pattern matched for a previous ops. +static bool propagateDpsInitChangesToDpsResults(wa::TTNNWorkaroundInterface &op, + PatternRewriter &rewriter) { + // Check if the op is a DPS op. + if (!mlir::isa(op.getOperation())) { + return false; + } + + bool modified = false; + + auto dpsOp = mlir::cast(op.getOperation()); + mlir::OperandRange dpsInits = dpsOp.getDpsInits(); + + // Iterate through all dps destination operands and propagate the changes if + // any. + for (size_t dpsInitIndex = 0; dpsInitIndex < dpsInits.size(); + dpsInitIndex++) { + OpOperand *dpsInit = dpsOp.getDpsInitOperand(dpsInitIndex); + OpResult tiedDpsResult = dpsOp.getTiedOpResult(dpsInit); + + // If the DPS destination is changed, update the DPS result as well. + if (tiedDpsResult.getType() != dpsInit->get().getType()) { + modified = true; + rewriter.modifyOpInPlace( + op, [&]() { tiedDpsResult.setType(dpsInit->get().getType()); }); + } + } + + return modified; +} + +// TTNNWorkaroundInterface rewriter applies workarounds to the operands of TTNN +// operations. TTNNWorkaroundInterface is an interface on TTNN_Op, so this +// pattern should match each op in the IR. +// +// The rewriter processes both input and output operands of TTNN operations: +// 1. **Input Operands**: The rewriter iterates through all input tensor +// operands and applies the necessary workarounds. +// - Workarounds are applied by inserting ToLayoutOp with the desired tensor +// layout, buffer type, and memory layout. +// 2. **DPS result propagation**: The rewriter propagates changes to tied DPS +// destination operands to ensure consistency with previous graph +// transformations, either in the current op match or previous op matches. +// 3. **Output Operands**: Output workarounds are applied only if the operation +// is not a DPS op. +// - At this stage, all DPS result changes should be propagated. An assertion +// ensures that the output result workaround matches +// the corresponding DPS output result. +// - Workarounds are applied by updating the output result type with the new +// tensor layout, buffer type, and memory layout. +// - For operations that define attributes with tensor layout, buffer type, +// and memory layout, these attributes are also updated. +// For example, the empty op defines layout and memory_config attributes. +class TTNNOperandsWorkaroundsRewriter + : public OpInterfaceRewritePattern { +public: + TTNNOperandsWorkaroundsRewriter(MLIRContext *ctx) + : OpInterfaceRewritePattern(ctx) {} + + LogicalResult matchAndRewrite(wa::TTNNWorkaroundInterface op, + PatternRewriter &rewriter) const final { + + // To layout op is a special case, we don't want to rewrite it. + if (mlir::isa(op.getOperation())) { + return failure(); + } + + bool modified = false; + // Get the operands workarounds for the current operation. + wa::TTNNOperandsWorkarounds operandsWorkarounds = + op.getOperandsWorkarounds(); + + // Filter out all the input tensor operands. + auto inputTensorsOperands = + llvm::make_filter_range(op->getOpOperands(), [](OpOperand &v) { + return ttmlir::utils::isRankedTensor(v.get()); + }); + + // Apply workarounds to all input tensor operands. + llvm::for_each( + llvm::zip_equal(inputTensorsOperands, + operandsWorkarounds.getInputOperandWorkarounds()), + [&](std::tuple + pair) { + modified = std::get<1>(pair).hasAnyWorkaround() && + workaroundInputOperand(std::get<0>(pair), + std::get<1>(pair), rewriter, op); + }); + + // Propagate the workaround changes for DPS input operands to DPS results if + // they are applied in above graph transforms, either in a pattern for a + // current op, or in a pattern matched for a previous ops. + modified |= propagateDpsInitChangesToDpsResults(op, rewriter); + + // Filter out all the output tensor results. + auto outputTensorResults = + llvm::make_filter_range(op->getOpResults(), [](OpResult v) { + return ttmlir::utils::isRankedTensor(v); + }); + + // Apply workarounds to all output tensor results. + llvm::for_each( + llvm::zip_equal(outputTensorResults, + operandsWorkarounds.getOutputOperandWorkarounds()), + [&](std::tuple + pair) { + modified |= std::get<1>(pair).hasAnyWorkaround() && + workaroundOutputOperand(std::get<0>(pair), + std::get<1>(pair), rewriter, op); + }); + + // Return success if the transformations were applied. + return modified ? success() : failure(); + } +}; + +// Pass to apply workarounds to the operands of TTNN operations. +class TTNNWorkarounds : public impl::TTNNWorkaroundsBase { +public: + using impl::TTNNWorkaroundsBase::TTNNWorkaroundsBase; + + void runOnOperation() final { + { + // Placeholder for workaround decomposition patterns. + } + { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + + FrozenRewritePatternSet patternSet(std::move(patterns)); + GreedyRewriteConfig config = GreedyRewriteConfig(); + // This configuration specifies that the rewriter should traverse the IR + // in a top-down order. + config.useTopDownTraversal = true; + // This configuration specifies the maximum number of iterations the + // rewriter will perform on the IR. The rewriter will iterate through the + // IR until a fixpoint is reached. All workarounds should be applied + // during the first iteration. If the workarounds are not applied in the + // first iteration, it indicates a bug in the workarounds implementation. + // Although the workarounds are applied in the first iteration, the + // rewriter must iterate through the IR once more to confirm that the + // fixpoint is reached. If the fixpoint is not reached in the second + // iteration, it indicates a bug in the workarounds implementation. + config.maxIterations = 2; + if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet, + config))) { + signalPassFailure(); + return; + } + } + } +}; +} // namespace mlir::tt::ttnn diff --git a/lib/Dialect/TTNN/Utils/CMakeLists.txt b/lib/Dialect/TTNN/Utils/CMakeLists.txt index f49f829e6f..cad244c0b5 100644 --- a/lib/Dialect/TTNN/Utils/CMakeLists.txt +++ b/lib/Dialect/TTNN/Utils/CMakeLists.txt @@ -1,6 +1,9 @@ add_mlir_dialect_library(TTMLIRTTNNUtils - Utils.cpp OptimizerOverrides.cpp + PassOverrides.cpp + TransformUtils.cpp + Utils.cpp + ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/TTNN diff --git a/lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp b/lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp index c888dff6bf..157c1e50d3 100644 --- a/lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp +++ b/lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp @@ -6,183 +6,205 @@ namespace mlir::tt::ttnn { -namespace { -std::optional> -parseGrid(StringRef param, char gridSeparator, llvm::cl::Option &opt) { - SmallVector gridParts; - param.split(gridParts, gridSeparator); - if (gridParts.size() == 2) { - int64_t gridX, gridY; - if (gridParts[0].getAsInteger(10, gridX) || - gridParts[1].getAsInteger(10, gridY)) { - opt.error("Invalid grid size: " + param); - return std::nullopt; - } - return SmallVector{gridX, gridY}; - } - return std::nullopt; +void OptimizerOverridesHandler::setEnableOptimizer(bool value) { + enableOptimizer = value; +} + +void OptimizerOverridesHandler::setMemoryReconfig(bool value) { + enableMemoryReconfig = value; +} +void OptimizerOverridesHandler::setEnableMemoryLayoutAnalysis(bool value) { + enableMemoryLayoutAnalysis = value; +} +void OptimizerOverridesHandler::setEnableMemoryLayoutAnalysisPolicy( + bool value) { + enableMemoryLayoutAnalysisPolicy = value; +} +void OptimizerOverridesHandler::setMemoryLayoutAnalysisPolicy( + MemoryLayoutAnalysisPolicyType value) { + memoryLayoutAnalysisPolicy = value; } -} // namespace -bool OutputLayoutOverrideParser::parse( - llvm::cl::Option &opt, StringRef argName, StringRef arg, +void OptimizerOverridesHandler::setInputLayoutOverrides( + llvm::StringMap &value) { + inputLayoutOverrides = value; +} +void OptimizerOverridesHandler::setOutputLayoutOverrides( llvm::StringMap &value) { - SmallVector opOverrideList; - constexpr size_t kvPairSize = 2; - constexpr size_t iOpName = 0; - constexpr size_t iLayoutOverrideParams = 1; - constexpr char opSeparator = ','; - constexpr char opNameSeparator = '='; - constexpr char paramSeparator = ':'; - constexpr char gridSeparator = 'x'; - - arg.split(opOverrideList, opSeparator); - for (const StringRef override : opOverrideList) { - SmallVector opOverrideParts; - override.split(opOverrideParts, opNameSeparator); - if (opOverrideParts.size() != kvPairSize) { - opt.error("Invalid format for override grid sizes: " + override); - return true; - } + outputLayoutOverrides = value; +} - SmallVector layoutParamParts; - opOverrideParts[iLayoutOverrideParams].split(layoutParamParts, - paramSeparator); - - OutputLayoutOverrideParams params; - - for (const StringRef ¶m : layoutParamParts) { - if (auto grid = parseGrid(param, gridSeparator, opt)) { - if (params.grid.has_value()) { - opt.error("Multiple grid parameters provided: " + param); - return true; - } - params.grid = grid; - } else if (auto bufferType = symbolizeBufferType(param)) { - if (params.bufferType.has_value()) { - opt.error("Multiple buffer type parameters provided: " + param); - return true; - } - params.bufferType = bufferType; - } else if (auto tensorMemoryLayout = symbolizeTensorMemoryLayout(param)) { - if (params.tensorMemoryLayout.has_value()) { - opt.error("Multiple tensor memory layout parameters provided: " + - param); - return true; - } - params.tensorMemoryLayout = tensorMemoryLayout; - } else if (auto memoryLayout = mlir::tt::ttnn::symbolizeLayout(param)) { - if (params.memoryLayout.has_value()) { - opt.error("Multiple memory layout parameters provided: " + param); - return true; - } - params.memoryLayout = memoryLayout; - } else if (auto dataType = mlir::tt::DataTypeStringToEnum(param)) { - if (params.dataType.has_value()) { - opt.error("Multiple data type parameters provided: " + param); - return true; - } - params.dataType = dataType; - } else { - opt.error("Invalid layout parameter: " + param); - return true; - } - } +void OptimizerOverridesHandler::setSystemDescPath(std::string value) { + systemDescPath = value; +} +void OptimizerOverridesHandler::setMaxLegalLayouts(int64_t value) { + maxLegalLayouts = value; +} +void OptimizerOverridesHandler::setMeshShape(std::vector value) { + meshShape = value; +} + +bool OptimizerOverridesHandler::getEnableOptimizer() const { + return enableOptimizer; +} + +bool OptimizerOverridesHandler::getMemoryReconfig() const { + return enableMemoryReconfig; +} +bool OptimizerOverridesHandler::getEnableMemoryLayoutAnalysis() const { + return enableMemoryLayoutAnalysis; +} +bool OptimizerOverridesHandler::getEnableMemoryLayoutAnalysisPolicy() const { + return enableMemoryLayoutAnalysisPolicy; +} +MemoryLayoutAnalysisPolicyType +OptimizerOverridesHandler::getMemoryLayoutAnalysisPolicy() const { + return memoryLayoutAnalysisPolicy; +} + +std::string OptimizerOverridesHandler::getSystemDescPath() const { + return systemDescPath; +} +int64_t OptimizerOverridesHandler::getMaxLegalLayouts() const { + return maxLegalLayouts; +} +std::vector OptimizerOverridesHandler::getMeshShape() const { + return meshShape; +} + +llvm::StringMap +OptimizerOverridesHandler::getInputLayoutOverrides() const { + return inputLayoutOverrides; +} +llvm::StringMap +OptimizerOverridesHandler::getOutputLayoutOverrides() const { + return outputLayoutOverrides; +} - value[opOverrideParts[iOpName]] = params; +std::unordered_map +OptimizerOverridesHandler::getInputLayoutOverridesPybindWrapper() const { + std::unordered_map + inputLayoutOverridesWrapper; + for (auto &entry : inputLayoutOverrides) { + inputLayoutOverridesWrapper[entry.getKey().str()] = entry.getValue(); } - return false; -} - -void OutputLayoutOverrideParser::print( - llvm::raw_ostream &os, - const llvm::StringMap &value) { - os << "override-output-layout="; - size_t count = 0; - for (const auto &entry : value) { - os << entry.getKey() << "="; - const OutputLayoutOverrideParams ¶ms = entry.getValue(); - // Print grid values - for (size_t i = 0; i < params.grid.value().size(); ++i) { - os << params.grid.value()[i]; - if (i < params.grid.value().size() - 1) { - os << "x"; - } - } - // Print memory space and memory layout - os << ":" << mlir::tt::ttnn::stringifyBufferType(params.bufferType.value()); - os << ":" - << mlir::tt::ttnn::stringifyTensorMemoryLayout( - params.tensorMemoryLayout.value()); - os << ":" << mlir::tt::ttnn::stringifyLayout(params.memoryLayout.value()); - os << ":" << mlir::tt::DataTypeEnumToString(params.dataType.value()); - if (++count < value.size()) { - os << ","; - } + return inputLayoutOverridesWrapper; +} + +std::unordered_map +OptimizerOverridesHandler::getOutputLayoutOverridesPybindWrapper() const { + std::unordered_map + outputLayoutOverridesWrapper; + for (auto &entry : outputLayoutOverrides) { + outputLayoutOverridesWrapper[entry.getKey().str()] = entry.getValue(); } - os << "\n"; + return outputLayoutOverridesWrapper; } -bool InputLayoutOverrideParser::parse( - llvm::cl::Option &opt, StringRef argName, StringRef arg, - llvm::StringMap &value) { - SmallVector opOverrideList; - constexpr size_t kvPairSize = 2; - constexpr size_t iOpName = 0; - constexpr size_t iOperands = 1; - constexpr char opSeparator = ','; - constexpr char opNameSeparator = '='; - constexpr char opParamSeparator = ':'; - - arg.split(opOverrideList, opSeparator); - for (const StringRef override : opOverrideList) { - SmallVector opOverrideParts; - override.split(opOverrideParts, opNameSeparator); - if (opOverrideParts.size() != kvPairSize) { - opt.error("Invalid format for input layouts override: " + override); - return true; - } +std::string OptimizerOverridesHandler::toString() const { - SmallVector operandIndexes; - SmallVector operandIndexParts; - - // Parse operand indexes. - opOverrideParts[iOperands].split(operandIndexParts, opParamSeparator); - for (const StringRef operandIndexPart : operandIndexParts) { - int64_t operandIndexValue; - if (operandIndexPart.getAsInteger(10 /*Radix*/, operandIndexValue)) { - opt.error("Invalid operand index: " + operandIndexPart); - return true; - } - operandIndexes.push_back(operandIndexValue); - } + std::string options = ""; - // Set parsed op overrides. - value[opOverrideParts[iOpName]] = - InputLayoutOverrideParams{std::move(operandIndexes)}; + if (enableOptimizer) { + options += OptionNames::optimizerPassEnabled.str() + "=true "; } - return false; -} - -void InputLayoutOverrideParser::print( - llvm::raw_ostream &os, - const llvm::StringMap &value) { - os << "insert-memreconfig="; - size_t count = 0; - for (const auto &entry : value) { - os << entry.getKey() << "="; - const InputLayoutOverrideParams ¶ms = entry.getValue(); - for (int64_t operandIdx : params.operandIdxes) { - os << operandIdx - << (operandIdx < static_cast(params.operandIdxes.size()) - 1 - ? ':' - : char()); - } - if (++count < value.size()) { - os << ","; + + if (enableMemoryReconfig) { + options += OptionNames::memReconfigEnabled.str() + "=true "; + } + + if (enableMemoryLayoutAnalysis) { + options += OptionNames::memoryLayoutAnalysisEnabled.str() + "=true "; + } + + if (enableMemoryLayoutAnalysisPolicy) { + options += OptionNames::memoryLayoutAnalysisPolicy.str() + "=" + + MemoryLayoutAnalysisPolicyTypeParser::toString( + memoryLayoutAnalysisPolicy) + + " "; + } + + // Create input layout overrides. + // Example: + // insert-memreconfig=input0=0:1,input1=0,input2=0:1:2 + if (inputLayoutOverrides.size() > 0) { + options += OptionNames::overrideInputLayout.str() + "=" + + InputLayoutOverrideParser::toString(inputLayoutOverrides) + " "; + } + + // Create output layout overrides. + // Example: + // override-output-layout=op1=2x2:dram:interleaved:tile:fp32,op2=4x4:l1:block_sharded:row_major:fp16 + // Example: + // override-output-layout=add_1_2=1x1:dram:interleaved:row_major:f32" + if (outputLayoutOverrides.size() > 0) { + options += OptionNames::overrideOutputLayout.str() + "=" + + OutputLayoutOverrideParser::toString(outputLayoutOverrides) + + " "; + } + + if (systemDescPath.size() > 0) { + options += OptionNames::systemDescPath.str() + "=" + systemDescPath + " "; + } + + if (maxLegalLayouts > 0) { + options += OptionNames::maxLegalLayouts.str() + "=" + + std::to_string(maxLegalLayouts) + " "; + } + + if (meshShape.size() > 0) { + options += OptionNames::meshShape.str() + "="; + for (int64_t meshShapeValue : meshShape) { + options += std::to_string(meshShapeValue) + ","; } + // Remove the last comma. + options.pop_back(); } - os << "\n"; + + if (options[options.size() - 1] == ' ') { + options.pop_back(); + } + + return options; +} + +void OptimizerOverridesHandler::addInputLayoutOverride( + StringRef opName, InputLayoutOverrideParams params) { + inputLayoutOverrides[opName] = params; +} +void OptimizerOverridesHandler::addInputLayoutOverride( + StringRef opName, SmallVector &operandIdxes) { + inputLayoutOverrides[opName] = + InputLayoutOverrideParams{std::move(operandIdxes)}; +} +void OptimizerOverridesHandler::addOutputLayoutOverride( + StringRef opName, OutputLayoutOverrideParams params) { + outputLayoutOverrides[opName] = params; +} +void OptimizerOverridesHandler::addOutputLayoutOverride( + StringRef opName, SmallVector &grid, BufferType bufferType, + TensorMemoryLayout tensorMemoryLayout, tt::ttnn::Layout memoryLayout, + tt::DataType dataType) { + outputLayoutOverrides[opName] = OutputLayoutOverrideParams{ + std::move(grid), bufferType, tensorMemoryLayout, memoryLayout, dataType}; +} + +void OptimizerOverridesHandler::addInputLayoutOverridePybindWrapper( + std::string opName, std::vector &operandIdxes) { + StringRef opNameStringRef(opName); + SmallVector operandIdxesSmallVector(operandIdxes.begin(), + operandIdxes.end()); + addInputLayoutOverride(opNameStringRef, operandIdxesSmallVector); +} + +void OptimizerOverridesHandler::addOutputLayoutOverridePybindWrapper( + std::string opName, std::vector &grid, BufferType bufferType, + TensorMemoryLayout tensorMemoryLayout, tt::ttnn::Layout memoryLayout, + tt::DataType dataType) { + StringRef opNameStringRef(opName); + SmallVector gridSmallVector(grid.begin(), grid.end()); + addOutputLayoutOverride(opNameStringRef, gridSmallVector, bufferType, + tensorMemoryLayout, memoryLayout, dataType); } } // namespace mlir::tt::ttnn diff --git a/lib/Dialect/TTNN/Utils/PassOverrides.cpp b/lib/Dialect/TTNN/Utils/PassOverrides.cpp new file mode 100644 index 0000000000..ad59ea91cb --- /dev/null +++ b/lib/Dialect/TTNN/Utils/PassOverrides.cpp @@ -0,0 +1,214 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttmlir/Dialect/TTNN/Utils/PassOverrides.h" + +namespace mlir::tt::ttnn { + +namespace { +std::optional> +parseGrid(StringRef param, char gridSeparator, llvm::cl::Option &opt) { + SmallVector gridParts; + param.split(gridParts, gridSeparator); + if (gridParts.size() == 2) { + int64_t gridX, gridY; + if (gridParts[0].getAsInteger(10, gridX) || + gridParts[1].getAsInteger(10, gridY)) { + opt.error("Invalid grid size: " + param); + return std::nullopt; + } + return SmallVector{gridX, gridY}; + } + return std::nullopt; +} +} // namespace + +bool OutputLayoutOverrideParser::parse( + llvm::cl::Option &opt, StringRef argName, StringRef arg, + llvm::StringMap &value) { + SmallVector opOverrideList; + constexpr size_t kvPairSize = 2; + constexpr size_t iOpName = 0; + constexpr size_t iLayoutOverrideParams = 1; + constexpr char opSeparator = ','; + constexpr char opNameSeparator = '='; + constexpr char paramSeparator = ':'; + constexpr char gridSeparator = 'x'; + + arg.split(opOverrideList, opSeparator); + for (const StringRef override : opOverrideList) { + SmallVector opOverrideParts; + override.split(opOverrideParts, opNameSeparator); + if (opOverrideParts.size() != kvPairSize) { + opt.error("Invalid format for override grid sizes: " + override); + return true; + } + + SmallVector layoutParamParts; + opOverrideParts[iLayoutOverrideParams].split(layoutParamParts, + paramSeparator); + + OutputLayoutOverrideParams params; + + for (const StringRef ¶m : layoutParamParts) { + if (auto grid = parseGrid(param, gridSeparator, opt)) { + if (params.grid.has_value()) { + opt.error("Multiple grid parameters provided: " + param); + return true; + } + params.grid = grid; + } else if (auto bufferType = symbolizeBufferType(param)) { + if (params.bufferType.has_value()) { + opt.error("Multiple buffer type parameters provided: " + param); + return true; + } + params.bufferType = bufferType; + } else if (auto tensorMemoryLayout = symbolizeTensorMemoryLayout(param)) { + if (params.tensorMemoryLayout.has_value()) { + opt.error("Multiple tensor memory layout parameters provided: " + + param); + return true; + } + params.tensorMemoryLayout = tensorMemoryLayout; + } else if (auto memoryLayout = mlir::tt::ttnn::symbolizeLayout(param)) { + if (params.memoryLayout.has_value()) { + opt.error("Multiple memory layout parameters provided: " + param); + return true; + } + params.memoryLayout = memoryLayout; + } else if (auto dataType = mlir::tt::DataTypeStringToEnum(param)) { + if (params.dataType.has_value()) { + opt.error("Multiple data type parameters provided: " + param); + return true; + } + params.dataType = dataType; + } else { + opt.error("Invalid layout parameter: " + param); + return true; + } + } + + value[opOverrideParts[iOpName]] = params; + } + return false; +} + +std::string OutputLayoutOverrideParser::toString( + const llvm::StringMap &value) { + std::string res; + size_t count = 0; + for (const auto &entry : value) { + res += std::string(entry.getKey()) + "="; + const OutputLayoutOverrideParams ¶ms = entry.getValue(); + + // Print grid values + if (params.grid.has_value()) { + for (size_t i = 0; i < params.grid.value().size(); ++i) { + res += std::to_string(params.grid.value()[i]); + if (i < params.grid.value().size() - 1) { + res += "x"; + } + } + } + // Print memory space and memory layout + if (params.bufferType.has_value()) { + res += ":" + std::string(mlir::tt::ttnn::stringifyBufferType( + params.bufferType.value())); + } + if (params.tensorMemoryLayout.has_value()) { + res += ":" + std::string(mlir::tt::ttnn::stringifyTensorMemoryLayout( + params.tensorMemoryLayout.value())); + } + if (params.memoryLayout.has_value()) { + res += ":" + std::string(mlir::tt::ttnn::stringifyLayout( + params.memoryLayout.value())); + } + if (params.dataType.has_value()) { + res += ":" + std::string( + mlir::tt::DataTypeEnumToString(params.dataType.value())); + } + if (++count < value.size()) { + res += ","; + } + } + return res; +} + +void OutputLayoutOverrideParser::print( + llvm::raw_ostream &os, + const llvm::StringMap &value) { + os << "override-output-layout="; + os << OutputLayoutOverrideParser::toString(value); + os << "\n"; +} + +bool InputLayoutOverrideParser::parse( + llvm::cl::Option &opt, StringRef argName, StringRef arg, + llvm::StringMap &value) { + SmallVector opOverrideList; + constexpr size_t kvPairSize = 2; + constexpr size_t iOpName = 0; + constexpr size_t iOperands = 1; + constexpr char opSeparator = ','; + constexpr char opNameSeparator = '='; + constexpr char opParamSeparator = ':'; + + arg.split(opOverrideList, opSeparator); + for (const StringRef override : opOverrideList) { + SmallVector opOverrideParts; + override.split(opOverrideParts, opNameSeparator); + if (opOverrideParts.size() != kvPairSize) { + opt.error("Invalid format for input layouts override: " + override); + return true; + } + + SmallVector operandIndexes; + SmallVector operandIndexParts; + + // Parse operand indexes. + opOverrideParts[iOperands].split(operandIndexParts, opParamSeparator); + for (const StringRef operandIndexPart : operandIndexParts) { + int64_t operandIndexValue; + if (operandIndexPart.getAsInteger(10 /*Radix*/, operandIndexValue)) { + opt.error("Invalid operand index: " + operandIndexPart); + return true; + } + operandIndexes.push_back(operandIndexValue); + } + + // Set parsed op overrides. + value[opOverrideParts[iOpName]] = + InputLayoutOverrideParams{std::move(operandIndexes)}; + } + return false; +} + +std::string InputLayoutOverrideParser::toString( + const llvm::StringMap &value) { + std::string res; + size_t count = 0; + for (const auto &entry : value) { + res += std::string(entry.getKey()) + "="; + const InputLayoutOverrideParams ¶ms = entry.getValue(); + for (int64_t operandIdx : params.operandIdxes) { + res += std::to_string(operandIdx) + ":"; + } + // Remove the last colon. + res.pop_back(); + if (++count < value.size()) { + res += ","; + } + } + return res; +} + +void InputLayoutOverrideParser::print( + llvm::raw_ostream &os, + const llvm::StringMap &value) { + os << "insert-memreconfig="; + os << InputLayoutOverrideParser::toString(value); + os << "\n"; +} + +} // namespace mlir::tt::ttnn diff --git a/lib/Dialect/TTNN/Utils/TransformUtils.cpp b/lib/Dialect/TTNN/Utils/TransformUtils.cpp new file mode 100644 index 0000000000..44b01e91b3 --- /dev/null +++ b/lib/Dialect/TTNN/Utils/TransformUtils.cpp @@ -0,0 +1,30 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttmlir/Dialect/TTNN/Utils/TransformUtils.h" + +#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" + +namespace mlir::tt::ttnn::utils { +// Gets or inserts a GetDeviceOp at the top of the current block of the given +// operation. +Value getOrInsertDevice(PatternRewriter &rewriter, Operation *op) { + Block *block = op->getBlock(); + for (auto &op : block->getOperations()) { + if (auto deviceOp = dyn_cast(op)) { + return deviceOp.getResult(); + } + } + + DeviceAttr deviceAttr = getCurrentScopeDevice(op); + auto currentInsertionPoint = rewriter.saveInsertionPoint(); + rewriter.setInsertionPoint(block, block->begin()); + auto deviceOp = rewriter.create( + op->getLoc(), rewriter.getType(deviceAttr), + ttnn::MeshShapeAttr::get(op->getContext(), 1, 1)); + rewriter.restoreInsertionPoint(currentInsertionPoint); + return deviceOp.getResult(); +} +} // namespace mlir::tt::ttnn::utils diff --git a/lib/Dialect/TTNN/Utils/Utils.cpp b/lib/Dialect/TTNN/Utils/Utils.cpp index a0736219e6..514e17e521 100644 --- a/lib/Dialect/TTNN/Utils/Utils.cpp +++ b/lib/Dialect/TTNN/Utils/Utils.cpp @@ -38,11 +38,9 @@ mlir::tt::ttnn::TensorMemoryLayout toTTNNTensorMemoryLayout( return ttnn::TensorMemoryLayout::BlockSharded; case ::mlir::tt::TensorMemoryLayout::SingleBank: return ttnn::TensorMemoryLayout::SingleBank; - case ::mlir::tt::TensorMemoryLayout::None: - return ttnn::TensorMemoryLayout::None; + default: + llvm_unreachable("Unknown TensorMemoryLayout"); } - - llvm_unreachable("Unknown TensorMemoryLayout"); } mlir::tt::TensorMemoryLayout toTTTensorMemoryLayout( @@ -59,9 +57,9 @@ mlir::tt::TensorMemoryLayout toTTTensorMemoryLayout( return ::mlir::tt::TensorMemoryLayout::BlockSharded; case ttnn::TensorMemoryLayout::SingleBank: return ::mlir::tt::TensorMemoryLayout::SingleBank; - case ttnn::TensorMemoryLayout::None: - return ::mlir::tt::TensorMemoryLayout::None; } + + llvm_unreachable("Unknown TensorMemoryLayout"); } mlir::tt::MemorySpace @@ -82,18 +80,6 @@ toTTMemorySpace(const mlir::tt::ttnn::BufferType bufferType) { llvm_unreachable("Unknown MemorySpace"); } -DataType getDataTypeFromMemRef(mlir::MemRefType memref) { - Type elementType = memref.getElementType(); - DataType dtype = DataType::Float32; - if (llvm::isa(elementType)) { - auto tileType = mlir::cast(elementType); - dtype = tileType.getDataType(); - } else { - dtype = elementTypeToDataType(elementType); - } - return dtype; -} - Layout getLayoutFromMemRef(mlir::MemRefType memref) { ttnn::Layout ttnnLayoutEnum = ttnn::Layout::RowMajor; Type elementType = memref.getElementType(); @@ -134,4 +120,12 @@ Type createRowMajorTypeFromDtype(::mlir::MLIRContext *context, DataType dtype) { } } +// Helper method to create a RankedTensorType with the given encoding +RankedTensorType +createRankedTensorTypeWithEncoding(RankedTensorType tensorType, + ttnn::TTNNLayoutAttr encoding) { + return RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); +} + } // namespace mlir::tt::ttnn::utils diff --git a/lib/OpModel/CMakeLists.txt b/lib/OpModel/CMakeLists.txt new file mode 100644 index 0000000000..9c34667d09 --- /dev/null +++ b/lib/OpModel/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(TTNN) diff --git a/lib/OpModel/TTNN/CMakeLists.txt b/lib/OpModel/TTNN/CMakeLists.txt new file mode 100644 index 0000000000..094b9f1ddd --- /dev/null +++ b/lib/OpModel/TTNN/CMakeLists.txt @@ -0,0 +1,40 @@ +set(LIB_NAME TTNNOpModelLib) + +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +set(SOURCES + TTNNOpModelLib.cpp +) +add_library(${LIB_NAME} STATIC ${SOURCES}) + +message(STATUS "TTMLIR_ENABLE_OP_MODEL[${TTMLIR_ENABLE_OP_MODEL}]") +if (TTMLIR_ENABLE_OPMODEL) + # Link to tt-metal libs and include directories + target_include_directories(${LIB_NAME} PUBLIC "$") + target_link_libraries(${LIB_NAME} PUBLIC TTNN_LIBRARY TTMETAL_LIBRARY) + target_compile_definitions(${LIB_NAME} PUBLIC TTMLIR_ENABLE_OPMODEL) +else() + # link stubs implementation when op model library is disabled + message(WARNING "TTNNOpModelLib is disabled. The optimizer will not achieve optimal performance.") +endif() + +# Specify the include directories for the library +target_include_directories(${LIB_NAME} + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/ + ${PROJECT_SOURCE_DIR}/include/ttmlir/OpModel/TTNN/) + + +# Add TTNNOpModelLib to the export set +install(TARGETS ${LIB_NAME} + EXPORT TTNNOpModelLibTargets + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + RUNTIME DESTINATION bin + INCLUDES DESTINATION include) + +# Export the targets +export(EXPORT TTNNOpModelLibTargets + FILE "${CMAKE_CURRENT_BINARY_DIR}/TTNNOpModelLibTargets.cmake" + NAMESPACE TTNN::) diff --git a/lib/OpModel/TTNN/TTNNOpModelLib.cpp b/lib/OpModel/TTNN/TTNNOpModelLib.cpp new file mode 100644 index 0000000000..87bfc04150 --- /dev/null +++ b/lib/OpModel/TTNN/TTNNOpModelLib.cpp @@ -0,0 +1,183 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "TTNNOpModel.h" + +#ifdef TTMLIR_ENABLE_OPMODEL +#include "TTNNOpModelLib_Impl.h" +#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" + +#include +#include + +#include +#include +#endif // TTMLIR_ENABLE_OPMODEL + +namespace mlir::tt::op_model::ttnn { + +#ifdef TTMLIR_ENABLE_OPMODEL +// alias to a common tt_metal types +using DataType = ::tt::tt_metal::DataType; +using Layout = ::tt::tt_metal::Layout; +using CoreRange = ::tt::tt_metal::CoreRange; +using CoreRangeSet = ::tt::tt_metal::CoreRangeSet; +using CoreCoord = ::tt::tt_metal::CoreCoord; +using ShardSpec = ::tt::tt_metal::ShardSpec; +using ShardOrientation = ::tt::tt_metal::ShardOrientation; +using TensorMemoryLayout = ::tt::tt_metal::TensorMemoryLayout; +using MemoryConfig = ::tt::tt_metal::MemoryConfig; + +namespace detail { + +DataType getDataType(const mlir::MemRefType &memref) { + + auto dataType = elementTypeToDataType(memref.getElementType()); + + switch (dataType) { + case tt::DataType::Float32: + return DataType::FLOAT32; + case tt::DataType::BFloat16: + return DataType::BFLOAT16; + case tt::DataType::BFP_BFloat8: + return DataType::BFLOAT8_B; + case tt::DataType::BFP_BFloat4: + return DataType::BFLOAT4_B; + case tt::DataType::UInt32: + return DataType::UINT32; + case tt::DataType::UInt16: + return DataType::UINT16; + case tt::DataType::UInt8: + return DataType::UINT8; + default: + throw std::runtime_error("Invalid element type"); + } +} + +::ttnn::SimpleShape getTensorShape(const mlir::MemRefType &memref) { + ::tt::tt_metal::SmallVector small_vector_shape( + memref.getShape().begin(), memref.getShape().end()); + return ::ttnn::SimpleShape(small_vector_shape); +} + +const std::array +getShardShape(const mlir::tt::ttnn::TTNNLayoutAttr &layout) { + const auto layoutShardTile = layout.getShardShape(); + + if (layoutShardTile.size() != 2) { + llvm::errs() << "ERROR: layout_shard_tile.size() != 2\n"; + return {0, 0}; + } + + std::array shardShape; + shardShape[0] = layoutShardTile[0]; + shardShape[1] = layoutShardTile[1]; + return shardShape; +} + +Layout getTensorLayout(const mlir::tt::ttnn::TTNNLayoutAttr &layout) { + return layout.isTiled() ? Layout::TILE : Layout::ROW_MAJOR; +} + +CoreRangeSet getCoreRangeSet(const mlir::tt::ttnn::TTNNLayoutAttr &layout) { + // TODO(mbezulj): handle more complex grid shapes + // assuming grid shape is one rect starting at (0,0) + + const auto layoutGrid = layout.getGrid(); + + const auto layoutGridShape = layoutGrid.getShape(); + if (layoutGridShape.size() != 2) { + llvm::errs() << "ERROR: layout_grid.getShape().size() == 2\n"; + return {}; + } + + return CoreRangeSet(CoreRange(CoreCoord(0, layoutGridShape[0]), + CoreCoord(0, layoutGridShape[1]))); +} + +std::optional +layout_get_shard_spec(const mlir::tt::ttnn::TTNNLayoutAttr &layout) { + // tt_ShardOrientation is not part of ttnn::TTNNLayoutAttr; + // defaulting to ROW_MAJOR. TODO: figure out if we need to expose this + return isShardedMemoryLayout(layout.getMemLayout()) + ? std::make_optional(ShardSpec(getCoreRangeSet(layout), + getShardShape(layout), + ShardOrientation::ROW_MAJOR, false)) + : std::nullopt; +} + +::tt::tt_metal::BufferType getBufferType(const mlir::MemRefType &memref) { + auto memorySpace = + mlir::cast(memref.getMemorySpace()).getValue(); + + switch (memorySpace) { + case tt::MemorySpace::DeviceDRAM: + return ::tt::tt_metal::BufferType::DRAM; + case tt::MemorySpace::DeviceL1: + return ::tt::tt_metal::BufferType::L1; + default: // TODO(mbezulj): handle other memory spaces + throw std::runtime_error("Unsupported memory space"); + } +} + +::tt::tt_metal::TensorMemoryLayout +getTensorMemoryLayout(const mlir::tt::ttnn::TTNNLayoutAttr &layout) { + auto tensorMemoryLayout = layout.getMemLayout(); + + switch (tensorMemoryLayout) { + case mlir::tt::ttnn::TensorMemoryLayout::Interleaved: + return ::tt::tt_metal::TensorMemoryLayout::INTERLEAVED; + case mlir::tt::ttnn::TensorMemoryLayout::SingleBank: + return ::tt::tt_metal::TensorMemoryLayout::SINGLE_BANK; + case mlir::tt::ttnn::TensorMemoryLayout::HeightSharded: + return ::tt::tt_metal::TensorMemoryLayout::HEIGHT_SHARDED; + case mlir::tt::ttnn::TensorMemoryLayout::WidthSharded: + return ::tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED; + case mlir::tt::ttnn::TensorMemoryLayout::BlockSharded: + return ::tt::tt_metal::TensorMemoryLayout::BLOCK_SHARDED; + default: + throw std::runtime_error("Unsupported tensor memory layout"); + } +} + +::tt::tt_metal::MemoryConfig +getMemoryConfig(const mlir::tt::ttnn::TTNNLayoutAttr &layout) { + + auto tensorMemoryLayout = getTensorMemoryLayout(layout); + auto bufferType = getBufferType(layout.getMemref()); + + auto shardSpec = layout_get_shard_spec(layout); + return ::tt::tt_metal::MemoryConfig(tensorMemoryLayout, bufferType, + shardSpec); +} + +} // namespace detail +#endif // TTMLIR_ENABLE_OPMODEL + +//===----------------------------------------------------------------------===// +// ReluOp +//===----------------------------------------------------------------------===// + +bool ReluOpInterface::isLegal( + const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout, + const mlir::tt::ttnn::TTNNLayoutAttr &outputLayout) { + +#ifdef TTMLIR_ENABLE_OPMODEL + return true; // to wire into tt-metal with the next uplift +#else + return true; +#endif // TTMLIR_ENABLE_OPMODEL +} + +std::tuple ReluOpInterface::getOpL1Usage( + const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout, + const mlir::tt::ttnn::TTNNLayoutAttr &outputLayout) { +#ifdef TTMLIR_ENABLE_OPMODEL + return std::make_tuple(0, 0, 0); // to wire into tt-metal with the next uplift +#else + return std::make_tuple(0, 0, 0); +#endif // TTMLIR_ENABLE_OPMODEL +} + +} // namespace mlir::tt::op_model::ttnn diff --git a/lib/OpModel/TTNN/TTNNOpModelLib_Impl.h b/lib/OpModel/TTNN/TTNNOpModelLib_Impl.h new file mode 100644 index 0000000000..ed39d881a9 --- /dev/null +++ b/lib/OpModel/TTNN/TTNNOpModelLib_Impl.h @@ -0,0 +1,60 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_OPMODEL_TTNN_TTNNOPMODELLIB_IMPL_H +#define TTMLIR_OPMODEL_TTNN_TTNNOPMODELLIB_IMPL_H + +// This header resolves tt-metal warnings that would otherwise be treated as +// errors in the MLIR build. Ensure that this is the only place where tt-metal +// headers are included. + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wcast-qual" +#pragma clang diagnostic ignored "-Wctad-maybe-unsupported" +#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" +#pragma clang diagnostic ignored "-Wignored-qualifiers" +#pragma clang diagnostic ignored "-Wvla-extension" +#pragma clang diagnostic ignored "-Wcovered-switch-default" +#pragma clang diagnostic ignored "-Wsign-compare" +#pragma clang diagnostic ignored "-Wc++20-extensions" +#pragma clang diagnostic ignored "-Wc++20-designator" +#pragma clang diagnostic ignored "-Wnon-virtual-dtor" +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunknown-warning-option" +#pragma clang diagnostic ignored "-Wsuggest-override" +#pragma clang diagnostic ignored "-Wgnu-anonymous-struct" +#pragma clang diagnostic ignored "-Wnested-anon-types" +#pragma clang diagnostic ignored "-Wreorder-ctor" +#pragma clang diagnostic ignored "-Wmismatched-tags" +#pragma clang diagnostic ignored "-Wunused-lambda-capture" +#pragma clang diagnostic ignored "-Wmissing-field-initializers" +#pragma clang diagnostic ignored "-Wunused-private-field" +#pragma clang diagnostic ignored "-Wimplicit-fallthrough" +#pragma clang diagnostic ignored "-Wstring-conversion" +#pragma clang diagnostic ignored "-Wunneeded-internal-declaration" +#pragma clang diagnostic ignored "-Wunused-local-typedef" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wpessimizing-move" +#pragma clang diagnostic ignored "-Wparentheses" +#pragma clang diagnostic ignored "-Wdeprecated-volatile" +#pragma clang diagnostic ignored "-Wdeprecated-this-capture" +#pragma clang diagnostic ignored "-Wc++23-extensions" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" +#pragma clang diagnostic ignored "-Wlogical-op-parentheses" +#pragma clang diagnostic ignored "-Wundefined-inline" +#pragma clang diagnostic ignored "-Wc99-extensions" +#pragma clang diagnostic ignored "-Wc++11-narrowing" +#pragma clang diagnostic ignored "-Wzero-length-array" +#pragma clang diagnostic ignored "-Wdeprecated-declarations" + +#define FMT_HEADER_ONLY + +#include "tt_metal/common/core_coord.hpp" +#include "tt_metal/impl/buffers/buffer.hpp" +#include "ttnn/tensor/tensor.hpp" +#include "ttnn/tensor/types.hpp" + +#pragma clang diagnostic pop + +#endif // TTMLIR_OPMODEL_TTNN_TTNNOPMODELLIB_IMPL_H diff --git a/lib/Scheduler/Scheduler.cpp b/lib/Scheduler/Scheduler.cpp index 25923fffdf..52066c5e87 100644 --- a/lib/Scheduler/Scheduler.cpp +++ b/lib/Scheduler/Scheduler.cpp @@ -12,7 +12,8 @@ namespace mlir::tt::scheduler { -bool isTTNNOp(mlir::Operation *op) { +// TTNN op is scheduleable if it is not an EmptyOp and has at least one result. +bool isTTNNScheduleableOp(mlir::Operation *op) { return isa(op->getDialect()) && op->getNumResults() > 0 && !llvm::isa(op); } @@ -21,8 +22,8 @@ bool isTTIROp(mlir::Operation *op) { return isa(op->getDialect()); } -bool isTTShedulableOp(mlir::Operation *op) { - return isTTNNOp(op) || isTTIROp(op); +bool Scheduler::isTTShedulableOp(mlir::Operation *op) { + return isTTNNScheduleableOp(op) || isTTIROp(op); } // Init the dependencies map of all ops which are TTIR ops diff --git a/lib/SharedLib/CMakeLists.txt b/lib/SharedLib/CMakeLists.txt index e743cb0b06..7f32c8aa66 100644 --- a/lib/SharedLib/CMakeLists.txt +++ b/lib/SharedLib/CMakeLists.txt @@ -2,7 +2,7 @@ set(TTNN_RUNTIME_LIBS TTRuntime TTRuntimeTTNN TTBinary) # Dependency libs from tt-metal/ttnn project for ttnn runtime -set(TTNN_LIBS TTMETAL_LIBRARY TTNN_LIBRARY) +set(TTNN_LIBS TTMETAL_LIBRARY DEVICE_LIBRARY TTNN_LIBRARY) if (TT_RUNTIME_ENABLE_PERF_TRACE) list(APPEND TTNN_LIBS TRACY_LIBRARY) endif() diff --git a/lib/Target/TTMetal/TTMetalToFlatbuffer.cpp b/lib/Target/TTMetal/TTMetalToFlatbuffer.cpp index 47e15accf6..e82deaf633 100644 --- a/lib/Target/TTMetal/TTMetalToFlatbuffer.cpp +++ b/lib/Target/TTMetal/TTMetalToFlatbuffer.cpp @@ -62,18 +62,18 @@ memrefAttrToFlatbuffer(FlatbufferObjectCache &cache, MemRefType memref, toFlatbuffer(cache, memLayout), size); } -flatbuffers::Offset<::tt::target::LayoutDesc> -layoutAttrToFlatbuffer(FlatbufferObjectCache &cache, LayoutAttr layoutAttr, - ArrayRef logicalShape, DeviceAttr deviceAttr) { - auto strideInt64 = layoutAttr.getStride(logicalShape); +flatbuffers::Offset<::tt::target::LayoutDesc> metalLayoutAttrToFlatbuffer( + FlatbufferObjectCache &cache, MetalLayoutAttr metalLayoutAttr, + ArrayRef logicalShape, DeviceAttr deviceAttr) { + auto strideInt64 = metalLayoutAttr.getStride(logicalShape); std::vector stride(strideInt64.begin(), strideInt64.end()); - auto coreRangeSet = - toFlatbuffer(cache, layoutAttr.getGrid(), deviceAttr.getWorkerGrid()); + auto coreRangeSet = toFlatbuffer(cache, metalLayoutAttr.getGrid(), + deviceAttr.getWorkerGrid()); return ::tt::target::CreateLayoutDescDirect( - *cache.fbb, &stride, toFlatbuffer(cache, layoutAttr.getOobVal()), + *cache.fbb, &stride, toFlatbuffer(cache, metalLayoutAttr.getOobVal()), &coreRangeSet, - cache.getOrCreate(layoutAttr.getMemref(), memrefAttrToFlatbuffer, - layoutAttr.getMemLayout())); + cache.getOrCreate(metalLayoutAttr.getMemref(), memrefAttrToFlatbuffer, + metalLayoutAttr.getMemLayout())); } } // namespace mlir::tt @@ -277,7 +277,7 @@ static std::shared_ptr translateModuleToFlatbuffer( argumentAllocations[input.getArgNumber()]); assert( argAlloc.getMemorySpace() == - mlir::cast( + mlir::cast( mlir::cast(input.getType()).getEncoding()) .getMemorySpace() && "argument allocation memory space does not match tensor type " diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp index 30b83014d4..d0d65ad874 100644 --- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp +++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp @@ -28,6 +28,7 @@ #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Support/LogicalResult.h" +#include "types_generated.h" #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" @@ -39,8 +40,13 @@ namespace mlir::tt { ::tt::target::TensorMemoryLayout -toFlatbuffer(FlatbufferObjectCache &, ttnn::TensorMemoryLayout memLayout) { - switch (memLayout) { +toFlatbuffer(FlatbufferObjectCache &, + ttnn::TensorMemoryLayoutAttr memLayoutAttr) { + if (!memLayoutAttr) { + return ::tt::target::TensorMemoryLayout::None; + } + + switch (memLayoutAttr.getValue()) { case ttnn::TensorMemoryLayout::SingleBank: return ::tt::target::TensorMemoryLayout::SingleBank; case ttnn::TensorMemoryLayout::Interleaved: @@ -51,8 +57,6 @@ toFlatbuffer(FlatbufferObjectCache &, ttnn::TensorMemoryLayout memLayout) { return ::tt::target::TensorMemoryLayout::WidthSharded; case ttnn::TensorMemoryLayout::BlockSharded: return ::tt::target::TensorMemoryLayout::BlockSharded; - case ttnn::TensorMemoryLayout::None: - return ::tt::target::TensorMemoryLayout::None; } } @@ -72,7 +76,7 @@ ::tt::target::MemorySpace toFlatbuffer(FlatbufferObjectCache &, flatbuffers::Offset<::tt::target::MemoryDesc> memrefAttrToFlatbuffer(FlatbufferObjectCache &cache, mlir::MemRefType memref, - ttnn::TensorMemoryLayout memLayout) { + ttnn::TensorMemoryLayoutAttr memLayoutAttr) { auto shapeInt64 = memref.getShape(); std::vector shape(shapeInt64.begin(), shapeInt64.end()); DataType dtype = DataType::Float32; @@ -99,7 +103,7 @@ memrefAttrToFlatbuffer(FlatbufferObjectCache &cache, mlir::MemRefType memref, toFlatbuffer( cache, mlir::cast(memref.getMemorySpace()).getValue()), - toFlatbuffer(cache, memLayout), size); + toFlatbuffer(cache, memLayoutAttr), size); } flatbuffers::Offset<::tt::target::LayoutDesc> ttnnLayoutAttrToFlatbuffer( @@ -162,10 +166,10 @@ createDeviceRef(FlatbufferObjectCache &cache, Value device) { template ::flatbuffers::Offset<::tt::target::ttnn::Operation> createOperation(FlatbufferObjectCache &cache, ::flatbuffers::Offset op, - std::string const &debugString) { + std::string const &debugString, std::string const &locInfo) { return CreateOperationDirect( *cache.fbb, ::tt::target::ttnn::OpTypeTraits::enum_value, op.Union(), - debugString.c_str()); + debugString.c_str(), locInfo.c_str()); } ::flatbuffers::Offset<::tt::target::ttnn::GetDeviceOp> @@ -333,6 +337,46 @@ createOp(FlatbufferObjectCache &cache, FullOp op) { kHostAllocatedSize)); } +::flatbuffers::Offset<::tt::target::ttnn::ArangeOp> +createOp(FlatbufferObjectCache &cache, ArangeOp op) { + + std::optional<::tt::target::DataType> dtype = + op.getDtype().has_value() + ? std::make_optional(toFlatbuffer(cache, op.getDtype().value())) + : std::nullopt; + auto device = + op.getDevice() ? cache.at<::tt::target::DeviceRef>(op.getDevice()) : 0; + + auto memoryConfigDesc = op.getMemoryConfig().has_value() + ? cache.getOrCreate(op.getMemoryConfig().value(), + memoryConfigToFlatbuffer) + : 0; + + auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer, + kHostAllocatedAddress, kHostAllocatedSize); + + return ::tt::target::ttnn::CreateArangeOp( + *cache.fbb, static_cast(op.getStart()), + static_cast(op.getEnd()), static_cast(op.getStep()), + dtype /* optional */, device /* optional */, + memoryConfigDesc /* optional */, output); +} + +::flatbuffers::Offset<::tt::target::ttnn::LinearOp> +createOp(FlatbufferObjectCache &cache, LinearOp op) { + auto in0 = + cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getA())); + auto in1 = + cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getB())); + auto bias = op.getODSOperands(2).empty() + ? flatbuffers::Offset<::tt::target::TensorRef>() + : cache.at<::tt::target::TensorRef>( + getOperandThroughDPSOps(op.getBias())); + auto output = cache.at<::tt::target::TensorRef>( + getOperandThroughDPSOps(op.getResult())); + return ::tt::target::ttnn::CreateLinearOp(*cache.fbb, in0, in1, bias, output); +} + // ANCHOR: adding_an_op_matmul_serialize_to_binary ::flatbuffers::Offset<::tt::target::ttnn::MatmulOp> createOp(FlatbufferObjectCache &cache, MatmulOp op) { @@ -395,6 +439,30 @@ createEltwiseOpParams(FlatbufferObjectCache &cache, EltwiseOp op) { } } +::flatbuffers::Offset<::tt::target::ttnn::UpdateCacheOp> +createOp(FlatbufferObjectCache &cache, UpdateCacheOp op) { + auto cacheOperand = + cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getCache())); + auto input = + cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput())); + auto updateIndex = cache.at<::tt::target::TensorRef>( + getOperandThroughDPSOps(op.getUpdateIndex())); + + return ::tt::target::ttnn::CreateUpdateCacheOp( + *cache.fbb, cacheOperand, input, updateIndex, op.getBatchOffset()); +} + +::flatbuffers::Offset<::tt::target::ttnn::FillCacheOp> +createOp(FlatbufferObjectCache &cache, FillCacheOp op) { + auto cacheOperand = + cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getCache())); + auto input = + cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput())); + + return ::tt::target::ttnn::CreateFillCacheOp(*cache.fbb, cacheOperand, input, + op.getBatchOffset()); +} + template ::flatbuffers::Offset<::tt::target::ttnn::EltwiseOp> createNonDPSEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) { @@ -485,6 +553,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) { type = ::tt::target::ttnn::EltwiseOpType::Div; } else if constexpr (std::is_same_v) { type = ::tt::target::ttnn::EltwiseOpType::Sigmoid; + } else if constexpr (std::is_same_v) { + type = ::tt::target::ttnn::EltwiseOpType::Scatter; } else if constexpr (std::is_same_v) { type = ::tt::target::ttnn::EltwiseOpType::Log1p; } else if constexpr (std::is_same_v) { @@ -513,6 +583,10 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) { ::tt::target::ttnn::EltwiseOpWithFloatParams>( cache, op) .Union(); + } else if constexpr (std::is_same_v) { + type = ::tt::target::ttnn::EltwiseOpType::Tan; + } else if constexpr (std::is_same_v) { + type = ::tt::target::ttnn::EltwiseOpType::Tanh; } else { llvm_unreachable("unhandled EltwiseOp"); } @@ -554,7 +628,6 @@ createReductionOp(FlatbufferObjectCache &cache, ReductionOp op) { dim_arg, op.getKeepDim()); } -template ::flatbuffers::Offset<::tt::target::ttnn::TransposeOp> createTransposeOp(FlatbufferObjectCache &cache, TransposeOp op) { auto in = @@ -567,7 +640,6 @@ createTransposeOp(FlatbufferObjectCache &cache, TransposeOp op) { return ::tt::target::ttnn::CreateTransposeOp(*cache.fbb, in, out, dim0, dim1); } -template ::flatbuffers::Offset<::tt::target::ttnn::ConcatOp> createConcatOp(FlatbufferObjectCache &cache, ConcatOp op) { std::vector<::flatbuffers::Offset<::tt::target::TensorRef>> ins; @@ -582,7 +654,6 @@ createConcatOp(FlatbufferObjectCache &cache, ConcatOp op) { return ::tt::target::ttnn::CreateConcatOpDirect(*cache.fbb, &ins, out, dim); } -template ::flatbuffers::Offset<::tt::target::ttnn::EmbeddingOp> createEmbeddingOp(FlatbufferObjectCache &cache, EmbeddingOp op) { auto in0 = @@ -594,7 +665,6 @@ createEmbeddingOp(FlatbufferObjectCache &cache, EmbeddingOp op) { return ::tt::target::ttnn::CreateEmbeddingOp(*cache.fbb, in0, in1, output); } -template ::flatbuffers::Offset<::tt::target::ttnn::ReshapeOp> createReshapeOp(FlatbufferObjectCache &cache, ReshapeOp op) { auto in = @@ -607,7 +677,6 @@ createReshapeOp(FlatbufferObjectCache &cache, ReshapeOp op) { return ::tt::target::ttnn::CreateReshapeOp(*cache.fbb, in, out, shape); } -template ::flatbuffers::Offset<::tt::target::ttnn::SliceOp> createSliceOp(FlatbufferObjectCache &cache, SliceOp op) { auto in = @@ -625,7 +694,6 @@ createSliceOp(FlatbufferObjectCache &cache, SliceOp op) { step); } -template ::flatbuffers::Offset<::tt::target::ttnn::MaxPool2dOp> createMaxPool2dOp(FlatbufferObjectCache &cache, MaxPool2dOp op) { auto in = @@ -643,7 +711,6 @@ createMaxPool2dOp(FlatbufferObjectCache &cache, MaxPool2dOp op) { op.getPaddingWidth()); } -template ::flatbuffers::Offset<::tt::target::ttnn::SoftmaxOp> createSoftmaxOp(FlatbufferObjectCache &cache, SoftmaxOp op) { auto in = @@ -655,7 +722,6 @@ createSoftmaxOp(FlatbufferObjectCache &cache, SoftmaxOp op) { return ::tt::target::ttnn::CreateSoftmaxOp(*cache.fbb, in, out, dimension); } -template ::flatbuffers::Offset<::tt::target::ttnn::DeallocateOp> createDeallocateOp(FlatbufferObjectCache &cache, DeallocateOp op) { auto in = @@ -666,208 +732,279 @@ createDeallocateOp(FlatbufferObjectCache &cache, DeallocateOp op) { ::flatbuffers::Offset<::tt::target::ttnn::Operation> emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op, - std::string const &debugString) { + std::string const &debugString, std::string const &locInfo) { if (auto getDeviceOp = dyn_cast(op); getDeviceOp) { - return createOperation(cache, createOp(cache, getDeviceOp), debugString); + return createOperation(cache, createOp(cache, getDeviceOp), debugString, + locInfo); } if (auto toMemoryConfigOp = dyn_cast(op); toMemoryConfigOp) { return createOperation(cache, createOp(cache, toMemoryConfigOp), - debugString); + debugString, locInfo); } if (auto toLayoutOp = dyn_cast(op); toLayoutOp) { - return createOperation(cache, createOp(cache, toLayoutOp), debugString); + return createOperation(cache, createOp(cache, toLayoutOp), debugString, + locInfo); } if (auto typecastOp = dyn_cast(op); typecastOp) { - return createOperation(cache, createOp(cache, typecastOp), debugString); + return createOperation(cache, createOp(cache, typecastOp), debugString, + locInfo); } if (auto toDeviceOp = dyn_cast(op); toDeviceOp) { - return createOperation(cache, createOp(cache, toDeviceOp), debugString); + return createOperation(cache, createOp(cache, toDeviceOp), debugString, + locInfo); } if (auto fromDeviceOp = dyn_cast(op); fromDeviceOp) { - return createOperation(cache, createOp(cache, fromDeviceOp), debugString); + return createOperation(cache, createOp(cache, fromDeviceOp), debugString, + locInfo); } if (auto emptyOp = dyn_cast(op); emptyOp) { - return createOperation(cache, createOp(cache, emptyOp), debugString); + return createOperation(cache, createOp(cache, emptyOp), debugString, + locInfo); } if (auto fullOp = dyn_cast(op); fullOp) { - return createOperation(cache, createOp(cache, fullOp), debugString); + return createOperation(cache, createOp(cache, fullOp), debugString, + locInfo); } if (auto absOp = dyn_cast(op); absOp) { - return createOperation(cache, createEltwiseOp(cache, absOp), debugString); + return createOperation(cache, createEltwiseOp(cache, absOp), debugString, + locInfo); } if (auto addOp = dyn_cast(op); addOp) { - return createOperation(cache, createEltwiseOp(cache, addOp), debugString); + return createOperation(cache, createEltwiseOp(cache, addOp), debugString, + locInfo); } if (auto floorOp = dyn_cast(op); floorOp) { - return createOperation(cache, createEltwiseOp(cache, floorOp), debugString); + return createOperation(cache, createEltwiseOp(cache, floorOp), debugString, + locInfo); } if (auto isFiniteOp = dyn_cast(op); isFiniteOp) { return createOperation(cache, createEltwiseOp(cache, isFiniteOp), - debugString); + debugString, locInfo); } if (auto andOp = dyn_cast(op); andOp) { - return createOperation(cache, createEltwiseOp(cache, andOp), debugString); + return createOperation(cache, createEltwiseOp(cache, andOp), debugString, + locInfo); } if (auto cbrtOp = dyn_cast(op); cbrtOp) { - return createOperation(cache, createEltwiseOp(cache, cbrtOp), debugString); + return createOperation(cache, createEltwiseOp(cache, cbrtOp), debugString, + locInfo); } if (auto notOp = dyn_cast(op); notOp) { - return createOperation(cache, createEltwiseOp(cache, notOp), debugString); + return createOperation(cache, createEltwiseOp(cache, notOp), debugString, + locInfo); } if (auto orOp = dyn_cast(op); orOp) { - return createOperation(cache, createEltwiseOp(cache, orOp), debugString); + return createOperation(cache, createEltwiseOp(cache, orOp), debugString, + locInfo); } if (auto xorOp = dyn_cast(op); xorOp) { - return createOperation(cache, createEltwiseOp(cache, xorOp), debugString); + return createOperation(cache, createEltwiseOp(cache, xorOp), debugString, + locInfo); } if (auto multiplyOp = dyn_cast(op); multiplyOp) { return createOperation(cache, createEltwiseOp(cache, multiplyOp), - debugString); + debugString, locInfo); } if (auto negOp = dyn_cast(op); negOp) { - return createOperation(cache, createEltwiseOp(cache, negOp), debugString); + return createOperation(cache, createEltwiseOp(cache, negOp), debugString, + locInfo); } if (auto subtractOp = dyn_cast(op); subtractOp) { return createOperation(cache, createEltwiseOp(cache, subtractOp), - debugString); + debugString, locInfo); } if (auto eqOp = dyn_cast(op); eqOp) { - return createOperation(cache, createEltwiseOp(cache, eqOp), debugString); + return createOperation(cache, createEltwiseOp(cache, eqOp), debugString, + locInfo); } if (auto neOp = dyn_cast(op); neOp) { - return createOperation(cache, createEltwiseOp(cache, neOp), debugString); + return createOperation(cache, createEltwiseOp(cache, neOp), debugString, + locInfo); } if (auto geOp = dyn_cast(op); geOp) { - return createOperation(cache, createEltwiseOp(cache, geOp), debugString); + return createOperation(cache, createEltwiseOp(cache, geOp), debugString, + locInfo); } if (auto gtOp = dyn_cast(op); gtOp) { - return createOperation(cache, createEltwiseOp(cache, gtOp), debugString); + return createOperation(cache, createEltwiseOp(cache, gtOp), debugString, + locInfo); } if (auto leOp = dyn_cast(op); leOp) { - return createOperation(cache, createEltwiseOp(cache, leOp), debugString); + return createOperation(cache, createEltwiseOp(cache, leOp), debugString, + locInfo); } if (auto ltOp = dyn_cast(op); ltOp) { - return createOperation(cache, createEltwiseOp(cache, ltOp), debugString); + return createOperation(cache, createEltwiseOp(cache, ltOp), debugString, + locInfo); } if (auto maximumOp = dyn_cast(op); maximumOp) { return createOperation(cache, createEltwiseOp(cache, maximumOp), - debugString); + debugString, locInfo); } if (auto minimumOp = dyn_cast(op); minimumOp) { return createOperation(cache, createEltwiseOp(cache, minimumOp), - debugString); + debugString, locInfo); } if (auto reluOp = dyn_cast(op); reluOp) { - return createOperation(cache, createEltwiseOp(cache, reluOp), debugString); + return createOperation(cache, createEltwiseOp(cache, reluOp), debugString, + locInfo); } if (auto sqrtOp = dyn_cast(op); sqrtOp) { - return createOperation(cache, createEltwiseOp(cache, sqrtOp), debugString); + return createOperation(cache, createEltwiseOp(cache, sqrtOp), debugString, + locInfo); } if (auto rsqrtOp = dyn_cast(op); rsqrtOp) { - return createOperation(cache, createEltwiseOp(cache, rsqrtOp), debugString); + return createOperation(cache, createEltwiseOp(cache, rsqrtOp), debugString, + locInfo); } if (auto signOp = dyn_cast(op); signOp) { - return createOperation(cache, createEltwiseOp(cache, signOp), debugString); + return createOperation(cache, createEltwiseOp(cache, signOp), debugString, + locInfo); } if (auto expOp = dyn_cast(op); expOp) { - return createOperation(cache, createEltwiseOp(cache, expOp), debugString); + return createOperation(cache, createEltwiseOp(cache, expOp), debugString, + locInfo); } if (auto logOp = dyn_cast(op); logOp) { - return createOperation(cache, createEltwiseOp(cache, logOp), debugString); + return createOperation(cache, createEltwiseOp(cache, logOp), debugString, + locInfo); } if (auto expm1Op = dyn_cast(op); expm1Op) { - return createOperation(cache, createEltwiseOp(cache, expm1Op), debugString); + return createOperation(cache, createEltwiseOp(cache, expm1Op), debugString, + locInfo); } if (auto sigmoidOp = dyn_cast(op); sigmoidOp) { return createOperation(cache, createEltwiseOp(cache, sigmoidOp), - debugString); + debugString, locInfo); } if (auto log1pOp = dyn_cast(op); log1pOp) { - return createOperation(cache, createEltwiseOp(cache, log1pOp), debugString); + return createOperation(cache, createEltwiseOp(cache, log1pOp), debugString, + locInfo); + } + if (auto scatterOp = dyn_cast(op); scatterOp) { + return createOperation(cache, createEltwiseOp(cache, scatterOp), + debugString, locInfo); } if (auto reciprocalOp = dyn_cast(op); reciprocalOp) { return createOperation(cache, createEltwiseOp(cache, reciprocalOp), - debugString); + debugString, locInfo); } if (auto divOp = dyn_cast(op); divOp) { - return createOperation(cache, createEltwiseOp(cache, divOp), debugString); + return createOperation(cache, createEltwiseOp(cache, divOp), debugString, + locInfo); } if (auto remainderOp = dyn_cast(op); remainderOp) { return createOperation(cache, createEltwiseOp(cache, remainderOp), - debugString); + debugString, locInfo); } if (auto leakyReluOp = dyn_cast(op); leakyReluOp) { return createOperation(cache, createEltwiseOp(cache, leakyReluOp), - debugString); + debugString, locInfo); + } + if (auto linearOp = dyn_cast(op); linearOp) { + return createOperation(cache, createOp(cache, linearOp), debugString, + locInfo); } if (auto matmulOp = dyn_cast(op); matmulOp) { - return createOperation(cache, createOp(cache, matmulOp), debugString); + return createOperation(cache, createOp(cache, matmulOp), debugString, + locInfo); } if (auto sumOp = dyn_cast(op); sumOp) { - return createOperation(cache, createReductionOp(cache, sumOp), debugString); + return createOperation(cache, createReductionOp(cache, sumOp), debugString, + locInfo); } if (auto meanOp = dyn_cast(op); meanOp) { - return createOperation(cache, createReductionOp(cache, meanOp), - debugString); + return createOperation(cache, createReductionOp(cache, meanOp), debugString, + locInfo); } if (auto maxOp = dyn_cast(op); maxOp) { - return createOperation(cache, createReductionOp(cache, maxOp), debugString); + return createOperation(cache, createReductionOp(cache, maxOp), debugString, + locInfo); } if (auto embeddingOp = dyn_cast(op); embeddingOp) { return createOperation(cache, createEmbeddingOp(cache, embeddingOp), - debugString); + debugString, locInfo); } if (auto softmaxOp = dyn_cast(op); softmaxOp) { return createOperation(cache, createSoftmaxOp(cache, softmaxOp), - debugString); + debugString, locInfo); } if (auto transposeOp = dyn_cast(op); transposeOp) { return createOperation(cache, createTransposeOp(cache, transposeOp), - debugString); + debugString, locInfo); } if (auto clampOp = dyn_cast(op); clampOp) { return createOperation(cache, createNonDPSEltwiseOp(cache, clampOp), - debugString); + debugString, locInfo); } if (auto conv2dOp = dyn_cast(op); conv2dOp) { - return createOperation(cache, createOp(cache, conv2dOp), debugString); + return createOperation(cache, createOp(cache, conv2dOp), debugString, + locInfo); } if (auto allGatherOp = dyn_cast(op); allGatherOp) { - return createOperation(cache, createOp(cache, allGatherOp), debugString); + return createOperation(cache, createOp(cache, allGatherOp), debugString, + locInfo); } if (auto concatOp = dyn_cast(op); concatOp) { - return createOperation(cache, createConcatOp(cache, concatOp), debugString); + return createOperation(cache, createConcatOp(cache, concatOp), debugString, + locInfo); } if (auto reshapeOp = dyn_cast(op); reshapeOp) { return createOperation(cache, createReshapeOp(cache, reshapeOp), - debugString); + debugString, locInfo); } if (auto sliceOp = dyn_cast(op); sliceOp) { - return createOperation(cache, createSliceOp(cache, sliceOp), debugString); + return createOperation(cache, createSliceOp(cache, sliceOp), debugString, + locInfo); } if (auto max_pool2dOp = dyn_cast(op); max_pool2dOp) { return createOperation(cache, createMaxPool2dOp(cache, max_pool2dOp), - debugString); + debugString, locInfo); } if (auto deallocateOp = dyn_cast(op); deallocateOp) { return createOperation(cache, createDeallocateOp(cache, deallocateOp), - debugString); + debugString, locInfo); } if (auto ceilOp = dyn_cast(op); ceilOp) { - return createOperation(cache, createEltwiseOp(cache, ceilOp), debugString); + return createOperation(cache, createEltwiseOp(cache, ceilOp), debugString, + locInfo); } if (auto cosOp = dyn_cast(op); cosOp) { - return createOperation(cache, createEltwiseOp(cache, cosOp), debugString); + return createOperation(cache, createEltwiseOp(cache, cosOp), debugString, + locInfo); } if (auto sinOp = dyn_cast(op); sinOp) { - return createOperation(cache, createEltwiseOp(cache, sinOp), debugString); + return createOperation(cache, createEltwiseOp(cache, sinOp), debugString, + locInfo); } if (auto whereOp = dyn_cast(op); whereOp) { - return createOperation(cache, createEltwiseOp(cache, whereOp), debugString); + return createOperation(cache, createEltwiseOp(cache, whereOp), debugString, + locInfo); } if (auto geluOp = dyn_cast(op); geluOp) { - return createOperation(cache, createEltwiseOp(cache, geluOp), debugString); + return createOperation(cache, createEltwiseOp(cache, geluOp), debugString, + locInfo); + } + if (auto arangeOp = dyn_cast(op); arangeOp) { + return createOperation(cache, createOp(cache, arangeOp), debugString, + locInfo); + } + if (auto tanOp = dyn_cast(op); tanOp) { + return createOperation(cache, createEltwiseOp(cache, tanOp), debugString, + locInfo); + } + if (auto tanhOp = dyn_cast(op); tanhOp) { + return createOperation(cache, createEltwiseOp(cache, tanhOp), debugString, + locInfo); + } + if (auto updateCacheOp = dyn_cast(op); updateCacheOp) { + return createOperation(cache, createOp(cache, updateCacheOp), debugString, + locInfo); + } + if (auto fillCacheOp = dyn_cast(op); fillCacheOp) { + return createOperation(cache, createOp(cache, fillCacheOp), debugString, + locInfo); } llvm_unreachable("unhandled op in emitTTNNOperation"); diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index e43cb858d4..cbfc3bf95f 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -63,6 +63,12 @@ declare_mlir_python_sources(TTMLIRPythonSources.Overrides SOURCES overrides.py ) +declare_mlir_python_sources(TTMLIRPythonSources.OptimizerOverrides + ROOT_DIR "${TTMLIR_PYTHON_ROOT_DIR}" + ADD_TO_PARENT TTMLIRPythonSources + SOURCES optimizer_overrides.py +) + declare_mlir_python_sources(TTMLIRPythonSources.Passes ROOT_DIR "${TTMLIR_PYTHON_ROOT_DIR}" ADD_TO_PARENT TTMLIRPythonSources @@ -87,6 +93,7 @@ declare_mlir_python_extension(TTMLIRPythonExtensions.Main TTKernelModule.cpp TTNNModule.cpp Overrides.cpp + OptimizerOverrides.cpp Passes.cpp EMBED_CAPI_LINK_LIBS MLIRCAPITransforms diff --git a/python/OptimizerOverrides.cpp b/python/OptimizerOverrides.cpp new file mode 100644 index 0000000000..bd5ce94f43 --- /dev/null +++ b/python/OptimizerOverrides.cpp @@ -0,0 +1,153 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h" +#include "ttmlir/Bindings/Python/TTMLIRModule.h" + +namespace mlir::ttmlir::python { + +void populateOptimizerOverridesModule(py::module &m) { + + py::class_(m, + "OptimizerOverridesHandler") + .def(py::init<>()) + + .def("set_enable_optimizer", + &tt::ttnn::OptimizerOverridesHandler::setEnableOptimizer) + .def("get_enable_optimizer", + &tt::ttnn::OptimizerOverridesHandler::getEnableOptimizer) + + .def("set_memory_reconfig", + &tt::ttnn::OptimizerOverridesHandler::setMemoryReconfig) + .def("get_memory_reconfig", + &tt::ttnn::OptimizerOverridesHandler::getMemoryReconfig) + + .def("set_enable_memory_layout_analysis", + &tt::ttnn::OptimizerOverridesHandler::setEnableMemoryLayoutAnalysis) + .def("get_enable_memory_layout_analysis", + &tt::ttnn::OptimizerOverridesHandler::getEnableMemoryLayoutAnalysis) + + .def("set_enable_memory_layout_analysis_policy", + &tt::ttnn::OptimizerOverridesHandler:: + setEnableMemoryLayoutAnalysisPolicy) + .def("get_enable_memory_layout_analysis_policy", + &tt::ttnn::OptimizerOverridesHandler:: + getEnableMemoryLayoutAnalysisPolicy) + + .def("set_memory_layout_analysis_policy", + &tt::ttnn::OptimizerOverridesHandler::setMemoryLayoutAnalysisPolicy) + .def("get_memory_layout_analysis_policy", + &tt::ttnn::OptimizerOverridesHandler::getMemoryLayoutAnalysisPolicy) + + .def("set_system_desc_path", + &tt::ttnn::OptimizerOverridesHandler::setSystemDescPath) + .def("get_system_desc_path", + &tt::ttnn::OptimizerOverridesHandler::getSystemDescPath) + + .def("set_max_legal_layouts", + &tt::ttnn::OptimizerOverridesHandler::setMaxLegalLayouts) + .def("get_max_legal_layouts", + &tt::ttnn::OptimizerOverridesHandler::getMaxLegalLayouts) + + .def("set_mesh_shape", &tt::ttnn::OptimizerOverridesHandler::setMeshShape) + .def("get_mesh_shape", &tt::ttnn::OptimizerOverridesHandler::getMeshShape) + + .def("get_input_layout_overrides", + &tt::ttnn::OptimizerOverridesHandler:: + getInputLayoutOverridesPybindWrapper) + .def("get_output_layout_overrides", + &tt::ttnn::OptimizerOverridesHandler:: + getOutputLayoutOverridesPybindWrapper) + + .def("add_input_layout_override", &tt::ttnn::OptimizerOverridesHandler:: + addInputLayoutOverridePybindWrapper) + .def("add_output_layout_override", + &tt::ttnn::OptimizerOverridesHandler:: + addOutputLayoutOverridePybindWrapper) + + .def("to_string", &tt::ttnn::OptimizerOverridesHandler::toString); + + py::enum_( + m, "MemoryLayoutAnalysisPolicyType") + .value("DFSharding", mlir::tt::MemoryLayoutAnalysisPolicyType::DFSharding) + .value("L1Interleaved", + mlir::tt::MemoryLayoutAnalysisPolicyType::L1Interleaved); + + py::enum_(m, "BufferType") + .value("DRAM", mlir::tt::ttnn::BufferType::DRAM) + .value("L1", mlir::tt::ttnn::BufferType::L1) + .value("SystemMemory", mlir::tt::ttnn::BufferType::SystemMemory) + .value("L1Small", mlir::tt::ttnn::BufferType::L1Small) + .value("Trace", mlir::tt::ttnn::BufferType::Trace); + + py::enum_(m, "Layout") + .value("RowMajor", mlir::tt::ttnn::Layout::RowMajor) + .value("Tile", mlir::tt::ttnn::Layout::Tile) + .value("Invalid", mlir::tt::ttnn::Layout::Invalid); + + py::enum_(m, "TensorMemoryLayout") + .value("Interleaved", mlir::tt::ttnn::TensorMemoryLayout::Interleaved) + .value("SingleBank", mlir::tt::ttnn::TensorMemoryLayout::SingleBank) + .value("HeightSharded", mlir::tt::ttnn::TensorMemoryLayout::HeightSharded) + .value("WidthSharded", mlir::tt::ttnn::TensorMemoryLayout::WidthSharded) + .value("BlockSharded", mlir::tt::ttnn::TensorMemoryLayout::BlockSharded); + + py::enum_(m, "DataType") + .value("Float32", mlir::tt::DataType::Float32) + .value("Float16", mlir::tt::DataType::Float16) + .value("BFloat16", mlir::tt::DataType::BFloat16) + .value("BFP_Float8", mlir::tt::DataType::BFP_Float8) + .value("BFP_BFloat8", mlir::tt::DataType::BFP_BFloat8) + .value("BFP_Float4", mlir::tt::DataType::BFP_Float4) + .value("BFP_BFloat4", mlir::tt::DataType::BFP_BFloat4) + .value("BFP_Float2", mlir::tt::DataType::BFP_Float2) + .value("BFP_BFloat2", mlir::tt::DataType::BFP_BFloat2) + .value("UInt32", mlir::tt::DataType::UInt32) + .value("UInt16", mlir::tt::DataType::UInt16) + .value("UInt8", mlir::tt::DataType::UInt8); + + py::class_( + m, "InputLayoutOverrideParams") + .def(py::init<>()) + .def_property( + "operand_idxes", + [](const mlir::tt::ttnn::InputLayoutOverrideParams &obj) { + // Getter: Convert SmallVector to std::vector + return std::vector(obj.operandIdxes.begin(), + obj.operandIdxes.end()); + }, + [](mlir::tt::ttnn::InputLayoutOverrideParams &obj, + const std::vector &input) { + // Setter: Convert std::vector to SmallVector + obj.operandIdxes.clear(); + obj.operandIdxes.append(input.begin(), input.end()); + }); + + py::class_( + m, "OutputLayoutOverrideParams") + .def(py::init<>()) + .def_property( + "grid", + [](const mlir::tt::ttnn::OutputLayoutOverrideParams &obj) { + // Getter: Convert SmallVector to std::vector + return std::vector(obj.grid->begin(), obj.grid->end()); + }, + [](mlir::tt::ttnn::OutputLayoutOverrideParams &obj, + const std::vector &input) { + // Setter: Convert std::vector to SmallVector + obj.grid->clear(); + obj.grid->append(input.begin(), input.end()); + }) + .def_readwrite("buffer_type", + &mlir::tt::ttnn::OutputLayoutOverrideParams::bufferType) + .def_readwrite( + "tensor_memory_layout", + &mlir::tt::ttnn::OutputLayoutOverrideParams::tensorMemoryLayout) + .def_readwrite("memory_layout", + &mlir::tt::ttnn::OutputLayoutOverrideParams::memoryLayout) + .def_readwrite("data_type", + &mlir::tt::ttnn::OutputLayoutOverrideParams::dataType); +} + +} // namespace mlir::ttmlir::python diff --git a/python/TTMLIRModule.cpp b/python/TTMLIRModule.cpp index 9c4a4c81b5..0347da75b5 100644 --- a/python/TTMLIRModule.cpp +++ b/python/TTMLIRModule.cpp @@ -40,4 +40,7 @@ PYBIND11_MODULE(_ttmlir, m) { auto passes = m.def_submodule("passes", "Python-Bound Passes & Transformations"); mlir::ttmlir::python::populatePassesModule(passes); + auto optimizer_overrides = m.def_submodule( + "optimizer_overrides", "Python-Bound Optimizer Overrides"); + mlir::ttmlir::python::populateOptimizerOverridesModule(optimizer_overrides); } diff --git a/python/TTModule.cpp b/python/TTModule.cpp index f631b01169..b8d543410c 100644 --- a/python/TTModule.cpp +++ b/python/TTModule.cpp @@ -16,14 +16,14 @@ namespace mlir::ttmlir::python { void populateTTModule(py::module &m) { - tt_attribute_class(m, "LayoutAttr") + tt_attribute_class(m, "MetalLayoutAttr") .def_static("get", [](MlirContext ctx, MlirType rankedTensorType, uint32_t memorySpaceValue, MlirAttribute grid, std::vector> collapseIntervals, uint32_t oobValValue, uint32_t memLayoutValue) { - return wrap(tt::LayoutAttr::get( + return wrap(tt::MetalLayoutAttr::get( unwrap(ctx), mlir::cast(unwrap(rankedTensorType)), static_cast(memorySpaceValue), @@ -37,7 +37,7 @@ void populateTTModule(py::module &m) { std::vector> collapseIntervals) { return wrap( - mlir::cast(unwrap(self)) + mlir::cast(unwrap(self)) .withGrid(unwrap(ctx), tensorShape, mlir::cast(unwrap(grid)), collapseIntervals)); @@ -47,7 +47,7 @@ void populateTTModule(py::module &m) { std::vector tensorShape, MlirAttribute grid, std::vector> collapseIntervals) { - return mlir::cast(unwrap(self)) + return mlir::cast(unwrap(self)) .withGrid(unwrap(ctx), tensorShape, mlir::cast(unwrap(grid)), collapseIntervals); @@ -55,13 +55,13 @@ void populateTTModule(py::module &m) { .def_static( "with_element_type", [](MlirContext ctx, MlirAttribute self, MlirType elementType) { - return wrap(mlir::cast(unwrap(self)) + return wrap(mlir::cast(unwrap(self)) .withElementType(unwrap(ctx), unwrap(elementType))); }) .def_static( "with_element_type_", [](MlirContext ctx, MlirAttribute self, MlirType elementType) { - return mlir::cast(unwrap(self)) + return mlir::cast(unwrap(self)) .withElementType(unwrap(ctx), unwrap(elementType)); }) .def("getLayout", @@ -73,35 +73,45 @@ void populateTTModule(py::module &m) { mlir::cast(unwrap(type)); assert(tensor.getEncoding()); // Make sure that this Tensor has an // encoding value - tt::LayoutAttr layout = - mlir::cast(tensor.getEncoding()); + tt::MetalLayoutAttr layout = + mlir::cast(tensor.getEncoding()); return layout; }) - .def("wrapped", [](tt::LayoutAttr const &self) { return wrap(self); }) - .def_property_readonly( - "stride", - [](tt::LayoutAttr const &self, std::vector logicalShape) { - auto stride = self.getStride(logicalShape); - return std::vector(stride.begin(), stride.end()); - }) - .def_property_readonly("oobval", &tt::LayoutAttr::getOobVal) + .def("wrapped", + [](tt::MetalLayoutAttr const &self) { return wrap(self); }) + .def_property_readonly("stride", + [](tt::MetalLayoutAttr const &self, + std::vector logicalShape) { + auto stride = self.getStride(logicalShape); + return std::vector(stride.begin(), + stride.end()); + }) + .def_property_readonly("oobval", &tt::MetalLayoutAttr::getOobVal) .def_property_readonly("oobval_as_int", - [](tt::LayoutAttr la) { + [](tt::MetalLayoutAttr la) { return static_cast(la.getOobVal()); }) - .def_property_readonly("grid_attr", &tt::LayoutAttr::getGrid) - .def_property_readonly("memref", &tt::LayoutAttr::getMemref) - .def_property_readonly("memory_space", &tt::LayoutAttr::getMemorySpace) + .def_property_readonly("grid_attr", &tt::MetalLayoutAttr::getGrid) + .def_property_readonly( + "memref", + [](tt::MetalLayoutAttr self) { return wrap(self.getMemref()); }) + .def_property_readonly("memory_space", + &tt::MetalLayoutAttr::getMemorySpace) .def_property_readonly("memory_space_as_int", - [](tt::LayoutAttr la) { + [](tt::MetalLayoutAttr la) { return static_cast( la.getMemorySpace()); }) - .def_property_readonly("shard_shape", &tt::LayoutAttr::getShardShape) - .def_property_readonly("memory_layout", &tt::LayoutAttr::getMemLayout) - .def_property_readonly("memory_layout_as_int", [](tt::LayoutAttr la) { - return static_cast(la.getMemLayout()); - }); + .def_property_readonly("shard_shape", &tt::MetalLayoutAttr::getShardShape) + .def_property_readonly("memory_layout", + &tt::MetalLayoutAttr::getMemLayout) + .def_property_readonly( + "linear", + [](tt::MetalLayoutAttr self) { return wrap(self.getLinear()); }) + .def_property_readonly("memory_layout_as_int", + [](tt::MetalLayoutAttr la) { + return static_cast(la.getMemLayout()); + }); tt_attribute_class(m, "GridAttr") .def_static("get", @@ -236,6 +246,14 @@ void populateTTModule(py::module &m) { return self.getEthInactive().vec(); }); + tt_attribute_class(m, "CoreCoordAttr") + .def_static("get", + [](MlirContext ctx, int64_t y, int64_t x) { + return wrap(tt::CoreCoordAttr::get(unwrap(ctx), y, x)); + }) + .def_property_readonly("y", &tt::CoreCoordAttr::getY) + .def_property_readonly("x", &tt::CoreCoordAttr::getX); + tt_attribute_class(m, "ChipCoordAttr") .def_static("get", [](MlirContext ctx, unsigned rack, unsigned shelf, unsigned y, @@ -276,29 +294,29 @@ void populateTTModule(py::module &m) { }) .def_static( "get", - [](MlirContext ctx, std::vector cpuDescs, - std::vector chipDescs, - std::vector chipDescIndices, - std::vector chipCapabilities, - std::vector chipCoords, - std::vector chipChannels) { + [](MlirContext ctx, const std::vector &cpuDescs, + const std::vector &chipDescs, + const std::vector &chipDescIndices, + const std::vector &chipCapabilities, + const std::vector &chipCoords, + const std::vector &chipChannels) { std::vector chipDescsUnwrapped; - for (auto chipDesc : chipDescs) { + for (const auto &chipDesc : chipDescs) { chipDescsUnwrapped.push_back( mlir::cast(unwrap(chipDesc))); } std::vector chipCapabilitiesUnwrapped; - for (auto chipCapability : chipCapabilities) { + for (const auto &chipCapability : chipCapabilities) { chipCapabilitiesUnwrapped.push_back( mlir::cast(unwrap(chipCapability))); } std::vector chipCoordsUnwrapped; - for (auto chipCoord : chipCoords) { + for (const auto &chipCoord : chipCoords) { chipCoordsUnwrapped.push_back( mlir::cast(unwrap(chipCoord))); } std::vector chipChannelsUnwrapped; - for (auto chipChannel : chipChannels) { + for (const auto &chipChannel : chipChannels) { chipChannelsUnwrapped.push_back( mlir::cast(unwrap(chipChannel))); } @@ -430,8 +448,11 @@ void populateTTModule(py::module &m) { return mlir::cast(unwrap(self)); }) .def_property_readonly("grid_attr", &tt::DeviceAttr::getWorkerGrid) - .def_property_readonly("l1_map", &tt::DeviceAttr::getL1Map) - .def_property_readonly("dram_map", &tt::DeviceAttr::getDramMap) + .def_property_readonly( + "l1_map", [](tt::DeviceAttr self) { return wrap(self.getL1Map()); }) + .def_property_readonly( + "dram_map", + [](tt::DeviceAttr self) { return wrap(self.getDramMap()); }) .def_property_readonly( "mesh_shape", [](tt::DeviceAttr const &self) { return self.getMeshShape().vec(); }) @@ -447,7 +468,10 @@ void populateTTModule(py::module &m) { unwrap(ctx), SmallVector{height, width}, static_cast(dataType))); }) - .def_property_readonly("data_type", &tt::TileType::getDataType) + .def_property_readonly("data_type_as_int", + [](tt::TileType self) { + return static_cast(self.getDataType()); + }) .def_property_readonly("shape", [](tt::TileType const &tile) { return std::vector({tile.getHeight(), tile.getWidth()}); }); diff --git a/python/TTNNModule.cpp b/python/TTNNModule.cpp index 24bd05c8f9..a2bdb6e041 100644 --- a/python/TTNNModule.cpp +++ b/python/TTNNModule.cpp @@ -2,6 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 +#include "mlir/CAPI/AffineMap.h" #include "ttmlir/Bindings/Python/TTMLIRModule.h" namespace mlir::ttmlir::python { @@ -84,23 +85,26 @@ void populateTTNNModule(py::module &m) { tt::ttnn::BufferTypeAttr bufferTypeAttr, tt::ttnn::ShardSpecAttr shardSpecAttr) { return wrap(tt::ttnn::MemoryConfigAttr::get( - unwrap(ctx), tensorMemoryLayoutAttr, bufferTypeAttr, - shardSpecAttr)); + unwrap(ctx), bufferTypeAttr, shardSpecAttr, + tensorMemoryLayoutAttr)); }) .def_static( "get_by_value", [](MlirContext ctx, uint32_t tensorMemoryLayout, uint32_t bufferType, std::vector shardShape) { - return wrap(tt::ttnn::MemoryConfigAttr::get( - unwrap(ctx), + tt::ttnn::TensorMemoryLayoutAttr layoutAttr = tt::ttnn::TensorMemoryLayoutAttr::get( unwrap(ctx), static_cast( - tensorMemoryLayout)), + tensorMemoryLayout)); + + return wrap(tt::ttnn::MemoryConfigAttr::get( + unwrap(ctx), tt::ttnn::BufferTypeAttr::get( unwrap(ctx), static_cast(bufferType)), tt::ttnn::ShardSpecAttr::get( unwrap(ctx), - tt::ttnn::ShapeAttr::get(unwrap(ctx), shardShape)))); + tt::ttnn::ShapeAttr::get(unwrap(ctx), shardShape)), + layoutAttr)); }) .def_property_readonly("tensor_memory_layout", &tt::ttnn::MemoryConfigAttr::getTensorMemoryLayout) @@ -127,5 +131,37 @@ void populateTTNNModule(py::module &m) { }) .def_property_readonly("y", &tt::ttnn::MeshShapeAttr::getY) .def_property_readonly("x", &tt::ttnn::MeshShapeAttr::getX); + + tt_attribute_class(m, "TTNNLayoutAttr") + .def_static( + "get", + [](MlirContext ctx, MlirAffineMap linear, MlirAttribute grid, + MlirType memref, + std::optional memLayout = std::nullopt) { + tt::ttnn::TensorMemoryLayoutAttr memLayoutAttr; + if (memLayout.has_value()) { + memLayoutAttr = tt::ttnn::TensorMemoryLayoutAttr::get( + unwrap(ctx), + static_cast(memLayout.value())); + } + return wrap(tt::ttnn::TTNNLayoutAttr::get( + unwrap(ctx), mlir::cast(unwrap(linear)), + mlir::cast(unwrap(grid)), + mlir::cast(unwrap(memref)), memLayoutAttr)); + }) + .def_property_readonly( + "linear", + [](tt::ttnn::TTNNLayoutAttr self) { return wrap(self.getLinear()); }) + .def_property_readonly("grid_attr", &tt::ttnn::TTNNLayoutAttr::getGrid) + .def_property_readonly( + "memref", + [](tt::ttnn::TTNNLayoutAttr self) { return wrap(self.getMemref()); }) + .def_property_readonly( + "memory_layout_as_int", [](tt::ttnn::TTNNLayoutAttr self) { + if (!self.getMemLayout()) { + assert(false && "Memory layout is not set"); + } + return static_cast(self.getMemLayout().getValue()); + }); } } // namespace mlir::ttmlir::python diff --git a/python/test_infra/test_optimizer_overrides.py b/python/test_infra/test_optimizer_overrides.py new file mode 100644 index 0000000000..68ea33c8e7 --- /dev/null +++ b/python/test_infra/test_optimizer_overrides.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + + +import ttmlir.optimizer_overrides as oo + +from ttmlir.optimizer_overrides import MemoryLayoutAnalysisPolicyType +from ttmlir.optimizer_overrides import BufferType +from ttmlir.optimizer_overrides import Layout +from ttmlir.optimizer_overrides import TensorMemoryLayout +from ttmlir.optimizer_overrides import DataType + + +def main(): + + print("\n\n ================ TESTING START ================ \n\n") + + # ----------------------------------------------------------------------------- # + # Instantiate the OptimizerOverridesHandler + # ----------------------------------------------------------------------------- # + obj = oo.OptimizerOverridesHandler() + + # ----------------------------------------------------------------------------- # + # Test setters and getters + # ----------------------------------------------------------------------------- # + + # Enable Optimizer + obj.set_enable_optimizer(True) + print(f"Enable optimizer: {obj.get_enable_optimizer()}") + obj.set_enable_optimizer(False) + print(f"Enable optimizer: {obj.get_enable_optimizer()}") + + # Memory Reconfig + obj.set_memory_reconfig(True) + print(f"Memory Reconfig: {obj.get_memory_reconfig()}") + obj.set_memory_reconfig(False) + print(f"Memory Reconfig: {obj.get_memory_reconfig()}") + + # Enable Memory Layout Analysis + obj.set_enable_memory_layout_analysis(True) + print(f"Enable Memory Layout Analysis: {obj.get_enable_memory_layout_analysis()}") + obj.set_enable_memory_layout_analysis(False) + print(f"Enable Memory Layout Analysis: {obj.get_enable_memory_layout_analysis()}") + + # Enable Memory Layout Analysis Policy + obj.set_enable_memory_layout_analysis_policy(True) + print( + f"Enable Memory Layout Analysis Policy: {obj.get_enable_memory_layout_analysis_policy()}" + ) + obj.set_enable_memory_layout_analysis_policy(False) + print( + f"Enable Memory Layout Analysis Policy: {obj.get_enable_memory_layout_analysis_policy()}" + ) + + # Memory Layout Analysis Policy + obj.set_memory_layout_analysis_policy(MemoryLayoutAnalysisPolicyType.DFSharding) + print(f"Memory Layout Analysis Policy: {obj.get_memory_layout_analysis_policy()}") + obj.set_memory_layout_analysis_policy(MemoryLayoutAnalysisPolicyType.L1Interleaved) + print(f"Memory Layout Analysis Policy: {obj.get_memory_layout_analysis_policy()}") + + # System Descriptor Path + obj.set_system_desc_path("System Descriptor Path") + print(f"System Descriptor Path: {obj.get_system_desc_path()}") + + # Max Legal Layouts + obj.set_max_legal_layouts(10) + print(f"Max Legal Layouts: {obj.get_max_legal_layouts()}") + + # Mesh Shape + obj.set_mesh_shape([1, 2, 3]) + print(f"Mesh Shape: {obj.get_mesh_shape()}") + + # ----------------------------------------------------------------------------- # + # Test Input Layout and Output Layout + # ----------------------------------------------------------------------------- # + + # Input Layout + obj.add_input_layout_override("add", [1, 2]) + obj.add_input_layout_override("mul", [0, 1]) + obj.add_input_layout_override("sub", [2, 3]) + print(f"Input Layout: {obj.get_input_layout_overrides()}\n") + + # Output Layout + obj.add_output_layout_override( + "add", + [0, 1], + BufferType.DRAM, + TensorMemoryLayout.HeightSharded, + Layout.RowMajor, + DataType.Float16, + ) + obj.add_output_layout_override( + "mul", + [1, 2], + BufferType.L1, + TensorMemoryLayout.WidthSharded, + Layout.RowMajor, + DataType.BFloat16, + ) + obj.add_output_layout_override( + "sub", + [2, 3], + BufferType.SystemMemory, + TensorMemoryLayout.BlockSharded, + Layout.Tile, + DataType.UInt16, + ) + print(f"Output Layout: {obj.get_output_layout_overrides()}\n") + + # ----------------------------------------------------------------------------- # + # Test string method + # ----------------------------------------------------------------------------- # + print(f"Optimizer Override string: {obj.to_string()}") + + print("\n\n ================ TESTING END ================ \n\n") + + +if __name__ == "__main__": + main() diff --git a/python/ttmlir/dialects/ttnn.py b/python/ttmlir/dialects/ttnn.py index d81f58111a..659938cf66 100644 --- a/python/ttmlir/dialects/ttnn.py +++ b/python/ttmlir/dialects/ttnn.py @@ -3,4 +3,5 @@ # SPDX-License-Identifier: Apache-2.0 from ._ttnn_ops_gen import * +from ._ttnn_enum_gen import * from .._mlir_libs._ttmlir import register_dialect, ttnn_ir as ir diff --git a/python/ttmlir/optimizer_overrides.py b/python/ttmlir/optimizer_overrides.py new file mode 100644 index 0000000000..880ab85cfc --- /dev/null +++ b/python/ttmlir/optimizer_overrides.py @@ -0,0 +1,5 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from ._mlir_libs._ttmlir.optimizer_overrides import * diff --git a/runtime/CMakeLists.txt b/runtime/CMakeLists.txt index c9dce10946..0a23c6ddac 100644 --- a/runtime/CMakeLists.txt +++ b/runtime/CMakeLists.txt @@ -14,6 +14,7 @@ set(TT_RUNTIME_OPTIONS TT_RUNTIME_DEBUG TT_RUNTIME_ENABLE_PERF_TRACE TT_RUNTIME_WORKAROUNDS + TTMLIR_ENABLE_RUNTIME_TESTS ) foreach(OPTION ${TT_RUNTIME_OPTIONS}) @@ -24,6 +25,4 @@ endforeach() add_subdirectory(lib) add_subdirectory(tools) -if (TTMLIR_ENABLE_RUNTIME_TESTS) - add_subdirectory(test) -endif() +add_subdirectory(test) diff --git a/runtime/include/tt/runtime/detail/ttmetal.h b/runtime/include/tt/runtime/detail/ttmetal.h index 7a68a7e944..1b043f6e58 100644 --- a/runtime/include/tt/runtime/detail/ttmetal.h +++ b/runtime/include/tt/runtime/detail/ttmetal.h @@ -39,14 +39,20 @@ void closeDevice(Device device); void deallocateBuffers(Device device); -Event submit(Device device, Binary executable, std::uint32_t programIndex, - std::vector const &inputs, - std::vector const &outputs); - void wait(Event event); +void wait(Tensor tensor); + +void wait(std::vector const &tensors); + +Event submit(Device deviceHandle, Binary executableHandle, + std::uint32_t programIndex, std::vector const &inputs, + std::vector const &outputs); + std::string getOpDebugString(OpContext opContextHandle); +std::string getOpLocInfo(OpContext opContextHandle); + Tensor getOpOutputTensor(OpContext opContextHandle, CallbackContext programContextHandle); diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h index 6c55ac1de7..268959e8a2 100644 --- a/runtime/include/tt/runtime/detail/ttnn.h +++ b/runtime/include/tt/runtime/detail/ttnn.h @@ -15,6 +15,7 @@ #include "ttnn/operations/copy.hpp" #include "ttnn/operations/core/core.hpp" #include "ttnn/operations/creation.hpp" +#include "ttnn/operations/data_movement/clone/clone.hpp" #include "ttnn/operations/data_movement/concat/concat.hpp" #include "ttnn/operations/data_movement/permute/permute.hpp" #include "ttnn/operations/data_movement/transpose/transpose.hpp" @@ -23,11 +24,14 @@ #include "ttnn/operations/eltwise/ternary/where.hpp" #include "ttnn/operations/eltwise/unary/unary.hpp" #include "ttnn/operations/embedding/embedding.hpp" +#include "ttnn/operations/kv_cache/kv_cache.hpp" #include "ttnn/operations/matmul/matmul.hpp" #include "ttnn/operations/normalization/softmax/softmax.hpp" -#include "ttnn/operations/pool/maxpool/max_pool2d.hpp" +#include "ttnn/operations/pool/generic/generic_pools.hpp" #include "ttnn/operations/reduction/generic/generic_reductions.hpp" #include "ttnn/tensor/host_buffer/functions.hpp" +#include "ttnn/tensor/host_buffer/owned_buffer.hpp" +#include "ttnn/tensor/shape/shape.hpp" #include "ttnn/tensor/tensor.hpp" #include "ttnn/tensor/types.hpp" @@ -53,16 +57,27 @@ createTensor(std::vector> &data, ::tt::target::DataType dataType, std::unordered_map const &strategy); +Tensor createTensor(Device device, Layout layout, + std::vector const &shape, + std::vector const &stride, + std::uint32_t itemsize); + inline Tensor createTensor(std::shared_ptr data, TensorDesc const &desc) { - return createTensor(data, desc.shape, desc.stride, desc.itemsize, - desc.dataType); + return ::tt::runtime::ttnn::createTensor(data, desc.shape, desc.stride, + desc.itemsize, desc.dataType); } inline Tensor createTensor(std::vector> &data, TensorDesc const &desc, std::unordered_map const &strategy) { - return createTensor(data, desc.shape, desc.stride, desc.itemsize, - desc.dataType, strategy); + return ::tt::runtime::ttnn::createTensor( + data, desc.shape, desc.stride, desc.itemsize, desc.dataType, strategy); +} + +inline Tensor createTensor(Device device, Layout layout, + TensorDesc const &desc) { + return ::tt::runtime::ttnn::createTensor(device, layout, desc.shape, + desc.stride, desc.itemsize); } tt::target::DataType getTensorDataType(Tensor tensor); @@ -75,23 +90,55 @@ void closeDevice(Device device); void deallocateBuffers(Device device); -Event submit(Device device, Binary executable, std::uint32_t programIndex, - std::vector const &inputs, - std::vector const &outputs); - void wait(Event event); +void wait(Tensor tensor); + +void wait(std::vector const &tensors); + +Tensor toHost(Tensor tensor, bool untilize = false); + +Tensor toLayout(Tensor tensor, Device device, Layout layout); + +Layout getLayout(Binary executableHandle, std::uint32_t programIndex, + std::uint32_t inputIndex); + +void memcpy(void *dst, Tensor src); + +void memcpy(Tensor dst, Tensor src); + +void deallocateTensor(Tensor &tensor, bool force = false); + std::string getOpDebugString(OpContext opContextHandle); +std::string getOpLocInfo(OpContext opContextHandle); + Tensor getOpOutputTensor(OpContext opContextHandle, CallbackContext programContextHandle); std::vector getTensorData(Tensor tensor); +namespace legacy { +/* Will be deprecated soon once FEs migrate to new API */ + +Event submit(Device deviceHandle, Binary executableHandle, + std::uint32_t programIndex, std::vector const &inputs, + std::vector const &outputs); + void runProgram(::ttnn::MeshDevice &meshDevice, Binary &executableHandle, std::uint32_t programIndex, std::vector<::ttnn::Tensor *> const &inputs, std::vector<::ttnn::Tensor *> const &outputs); +} // namespace legacy + +std::vector submit(Device deviceHandle, Binary executableHandle, + std::uint32_t programIndex, + std::vector const &inputs); + +std::vector runProgram(::ttnn::MeshDevice &meshDevice, + Binary executableHandle, + std::uint32_t programIndex, + std::vector<::ttnn::Tensor *> const &inputs); } // namespace tt::runtime::ttnn diff --git a/runtime/include/tt/runtime/detail/workarounds.h b/runtime/include/tt/runtime/detail/workarounds.h index 38d8c08cf3..a586757522 100644 --- a/runtime/include/tt/runtime/detail/workarounds.h +++ b/runtime/include/tt/runtime/detail/workarounds.h @@ -15,29 +15,15 @@ struct Env { #else constexpr static Env #endif - get(bool ignoreTileShape = true, bool emptyOpForceRowMajor = true, - bool fullOpForceRowMajor = true, bool maxpool2dPreshard = true, - bool swapBinaryOperands = true) + get(bool maxpool2dPreshard = true, bool swapBinaryOperands = true, + bool readUpdateIndexFromDeviceForKVCache = true) #if defined(TT_RUNTIME_WORKAROUNDS) && TT_RUNTIME_WORKAROUNDS == 1 ; #else { - return Env(true, true, true, true, true); + return Env(true, true, true); } #endif - // TODO(bug #272), determine correct layout by tile shape in the future - // currently tile shape is not set correctly, so as a workaround, hardcode - // layout - bool ignoreTileShape; - - // TODO(bug #582): ttnn::empty doesn't work properly with tile layout, - // using ROW_MAJOR until we fix it - bool emptyOpForceRowMajor; - - // TODO(bug #582): ttnn::full doesn't work properly with tile layout, - // using ROW_MAJOR until we fix it - bool fullOpForceRowMajor; - // TODO(bug #855): Ideally we should have an op that preshards for maxpool2d // instead of adding a method in runtime bool maxpool2dPreshard; @@ -47,29 +33,30 @@ struct Env { // rhs operand). We should add this check in the compiler. bool swapBinaryOperands; + // TODO(bug #1510) ttnn::update_cache takes a single update index as a uint32 + // as a function argument. The tt-torch frontend and likely others model this + // as a tensor with integer elements. For now, to get this op to work we need + // to be able to pluck this update index from a runtime tensor. + bool readUpdateIndexFromDeviceForKVCache; + private: - constexpr Env(bool ignoreTileShape, bool emptyOpForceRowMajor, - bool fullOpForceRowMajor, bool maxpool2dPreshard, - bool swapBinaryOperands) - : ignoreTileShape(ignoreTileShape), - emptyOpForceRowMajor(emptyOpForceRowMajor), - fullOpForceRowMajor(fullOpForceRowMajor), - maxpool2dPreshard(maxpool2dPreshard), - swapBinaryOperands(swapBinaryOperands) {} + constexpr Env(bool maxpool2dPreshard, bool swapBinaryOperands, + bool readUpdateIndexFromDeviceForKVCache) + : maxpool2dPreshard(maxpool2dPreshard), + swapBinaryOperands(swapBinaryOperands), + readUpdateIndexFromDeviceForKVCache( + readUpdateIndexFromDeviceForKVCache) {} }; inline std::ostream &operator<<(std::ostream &os, const Env &env) { os << "workaround::Env{\n"; - os << "\t" - << "ignoreTileShape: " << env.ignoreTileShape << ",\n"; - os << "\t" - << "emptyOpForceRowMajor: " << env.emptyOpForceRowMajor << ",\n"; - os << "\t" - << "fullOpForceRowMajor: " << env.fullOpForceRowMajor << ",\n"; os << "\t" << "maxpool2dPreshard: " << env.maxpool2dPreshard << ",\n"; os << "\t" - << "swapBinaryOperands: " << env.swapBinaryOperands << "\n"; + << "swapBinaryOperands: " << env.swapBinaryOperands << ",\n"; + os << "\t" + << "readUpdateIndexFromDeviceForKVCache: " + << env.readUpdateIndexFromDeviceForKVCache << "\n"; os << "}"; return os; } diff --git a/runtime/include/tt/runtime/runtime.h b/runtime/include/tt/runtime/runtime.h index 1dc721f662..c3b725e0f9 100644 --- a/runtime/include/tt/runtime/runtime.h +++ b/runtime/include/tt/runtime/runtime.h @@ -43,16 +43,27 @@ createTensor(std::vector> &data, ::tt::target::DataType dataType, std::unordered_map const &strategy); +Tensor createTensor(Device device, Layout layout, + std::vector const &shape, + std::vector const &stride, + std::uint32_t itemsize); + inline Tensor createTensor(std::shared_ptr data, TensorDesc const &desc) { - return createTensor(data, desc.shape, desc.stride, desc.itemsize, - desc.dataType); + return ::tt::runtime::createTensor(data, desc.shape, desc.stride, + desc.itemsize, desc.dataType); } inline Tensor createTensor(std::vector> &data, TensorDesc const &desc, std::unordered_map const &strategy) { - return createTensor(data, desc.shape, desc.stride, desc.itemsize, - desc.dataType, strategy); + return ::tt::runtime::createTensor(data, desc.shape, desc.stride, + desc.itemsize, desc.dataType, strategy); +} + +inline Tensor createTensor(Device device, Layout layout, + TensorDesc const &desc) { + return ::tt::runtime::createTensor(device, layout, desc.shape, desc.stride, + desc.itemsize); } tt::target::DataType getTensorDataType(Tensor tensor); @@ -63,19 +74,42 @@ Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs = 1); void closeDevice(Device device); -Event submit(Device device, Binary executable, std::uint32_t programIndex, - std::vector const &inputs, - std::vector const &outputs); - void wait(Event event); +void wait(Tensor tensor); + +void wait(std::vector const &tensors); + +Tensor toHost(Tensor tensor, bool untilize = false); + +Tensor toLayout(Tensor tensor, Device device, Layout layout); + +Layout getLayout(Binary executableHandle, std::uint32_t programIndex, + std::uint32_t inputIndex); + +void memcpy(void *dst, Tensor src); + +void memcpy(Tensor dst, Tensor src); + +void deallocateTensor(Tensor &tensor, bool force = false); + std::string getOpDebugString(OpContext opContextHandle); +std::string getOpLocInfo(OpContext opContextHandle); + Tensor getOpOutputTensor(OpContext opContextHandle, CallbackContext programContextHandle); std::vector getTensorData(Tensor tensor); +std::vector submit(Device deviceHandle, Binary executableHandle, + std::uint32_t programIndex, + std::vector const &inputs); + +Event submit(Device deviceHandle, Binary executableHandle, + std::uint32_t programIndex, std::vector const &inputs, + std::vector const &outputs); + } // namespace tt::runtime #endif diff --git a/runtime/include/tt/runtime/test/utils.h b/runtime/include/tt/runtime/test/utils.h new file mode 100644 index 0000000000..e4323cc165 --- /dev/null +++ b/runtime/include/tt/runtime/test/utils.h @@ -0,0 +1,17 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TT_RUNTIME_TEST_UTILS_H +#define TT_RUNTIME_TEST_UTILS_H + +#include "tt/runtime/types.h" + +// Utility functions for testing TTNN runtime +namespace tt::runtime::ttnn::test { +Layout getDramInterleavedTileLayout(::tt::target::DataType dataType); +Layout getDramInterleavedRowMajorLayout(::tt::target::DataType dataType); +Layout getHostRowMajorLayout(::tt::target::DataType dataType); +} // namespace tt::runtime::ttnn::test + +#endif // TT_RUNTIME_TEST_UTILS_H diff --git a/runtime/include/tt/runtime/types.h b/runtime/include/tt/runtime/types.h index 8fd641195f..cc2791e237 100644 --- a/runtime/include/tt/runtime/types.h +++ b/runtime/include/tt/runtime/types.h @@ -122,10 +122,20 @@ struct Event : public detail::RuntimeCheckedObjectImpl { struct Tensor : public detail::RuntimeCheckedObjectImpl { std::shared_ptr data; - + Event event; Tensor(std::shared_ptr handle, std::shared_ptr data, DeviceRuntime runtime) - : detail::RuntimeCheckedObjectImpl(handle, runtime), data(data) {} + : detail::RuntimeCheckedObjectImpl(handle, runtime), data(data), + event(nullptr, runtime) {} + + Tensor(std::shared_ptr handle, std::shared_ptr data, + std::shared_ptr eventHandle, DeviceRuntime runtime) + : detail::RuntimeCheckedObjectImpl(handle, runtime), data(data), + event(eventHandle, runtime) {} +}; + +struct Layout : public detail::RuntimeCheckedObjectImpl { + using detail::RuntimeCheckedObjectImpl::RuntimeCheckedObjectImpl; }; struct CallbackContext : public detail::RuntimeCheckedObjectImpl { diff --git a/runtime/lib/binary.cpp b/runtime/lib/binary.cpp index 92be39d27f..1d8cbf38b2 100644 --- a/runtime/lib/binary.cpp +++ b/runtime/lib/binary.cpp @@ -27,15 +27,12 @@ static std::string asJson(void const *fbb, uint8_t const *binarySchema, flatbuffers::Parser parser(opts); if (not parser.Deserialize(binarySchema, schemaSize)) { - throw std::runtime_error("Failed to deserialize schema"); + LOG_FATAL("Failed to deserialize schema"); } std::string text; const char *err = ::flatbuffers::GenerateText(parser, fbb, &text); - if (err) { - throw std::runtime_error("Failed to generate JSON: " + std::string(err)); - } - + LOG_ASSERT(not err, "Failed to generate JSON: ", err); return text; } @@ -44,9 +41,7 @@ namespace ttnn { ::tt::target::ttnn::TTNNBinary const *getBinary(Flatbuffer binary) { bool isTTNN = ::tt::target::ttnn::SizePrefixedTTNNBinaryBufferHasIdentifier( binary.handle.get()); - if (not isTTNN) { - throw std::runtime_error("Unsupported binary format"); - } + LOG_ASSERT(isTTNN, "Unsupported binary format"); return ::tt::target::ttnn::GetSizePrefixedTTNNBinary(binary.handle.get()); } @@ -128,9 +123,7 @@ ::tt::target::metal::TTMetalBinary const *getBinary(Flatbuffer binary) { bool isTTMetal = ::tt::target::metal::SizePrefixedTTMetalBinaryBufferHasIdentifier( binary.handle.get()); - if (not isTTMetal) { - throw std::runtime_error("Unsupported binary format"); - } + LOG_ASSERT(isTTMetal, "Unsupported binary format"); return ::tt::target::metal::GetSizePrefixedTTMetalBinary(binary.handle.get()); } @@ -207,7 +200,7 @@ namespace system_desc { ::tt::target::SystemDescRoot const *getBinary(Flatbuffer binary) { if (!::tt::target::SizePrefixedSystemDescRootBufferHasIdentifier( binary.handle.get())) { - throw std::runtime_error("Unsupported binary format"); + LOG_FATAL("Unsupported binary format"); } return ::tt::target::GetSizePrefixedSystemDescRoot(binary.handle.get()); } @@ -234,10 +227,7 @@ std::string asJson(Flatbuffer binary) { Flatbuffer Flatbuffer::loadFromPath(char const *path) { // load a flatbuffer from path std::ifstream fbb(path, std::ios::binary | std::ios::ate); - if (!fbb.is_open()) { - throw std::runtime_error("Failed to open file: " + std::string(path)); - } - + LOG_ASSERT(fbb.is_open(), "Failed to open file: ", path); std::streampos size = fbb.tellg(); fbb.seekg(0, std::ios::beg); auto buffer = ::tt::runtime::utils::malloc_shared(size); @@ -269,7 +259,7 @@ std::string_view Flatbuffer::getFileIdentifier() const { return ::tt::target::SystemDescRootIdentifier(); } - throw std::runtime_error("Unsupported binary format"); + LOG_FATAL("Unsupported binary format"); } std::string Flatbuffer::getVersion() const { @@ -288,7 +278,7 @@ std::string Flatbuffer::getVersion() const { return system_desc::getVersion(*this); } - throw std::runtime_error("Unsupported binary format"); + LOG_FATAL("Unsupported binary format"); } std::string_view Flatbuffer::getTTMLIRGitHash() const { @@ -307,7 +297,7 @@ std::string_view Flatbuffer::getTTMLIRGitHash() const { return system_desc::getTTMLIRGitHash(*this); } - throw std::runtime_error("Unsupported binary format"); + LOG_FATAL("Unsupported binary format"); } std::string Flatbuffer::asJson() const { @@ -326,7 +316,7 @@ std::string Flatbuffer::asJson() const { return system_desc::asJson(*this); } - throw std::runtime_error("Unsupported binary format"); + LOG_FATAL("Unsupported binary format"); } SystemDesc SystemDesc::loadFromPath(char const *path) { @@ -349,7 +339,7 @@ Binary::getProgramInputs(std::uint32_t programIndex) const { return metal::getProgramInputs(*this, programIndex); } - throw std::runtime_error("Unsupported binary format"); + LOG_FATAL("Unsupported binary format"); } std::vector @@ -364,7 +354,7 @@ Binary::getProgramOutputs(std::uint32_t programIndex) const { return metal::getProgramOutputs(*this, programIndex); } - throw std::runtime_error("Unsupported binary format"); + LOG_FATAL("Unsupported binary format"); } const ::tt::target::GoldenTensor * @@ -379,8 +369,7 @@ Binary::getDebugInfoGolden(std::string &loc) const { return metal::getDebugInfoGolden(*this, loc); } - throw std::runtime_error( - "Unsupported binary format for obtaining golden information"); + LOG_FATAL("Unsupported binary format for obtaining golden information"); } } // namespace tt::runtime diff --git a/runtime/lib/common/system_desc.cpp b/runtime/lib/common/system_desc.cpp index f1210d00aa..cf0c6196d7 100644 --- a/runtime/lib/common/system_desc.cpp +++ b/runtime/lib/common/system_desc.cpp @@ -12,8 +12,10 @@ #define FMT_HEADER_ONLY #include "distributed/mesh_device.hpp" +#include "eth_l1_address_map.h" #include "host_api.hpp" #include "hostdevcommon/common_values.hpp" +#include "noc/noc_parameters.h" namespace tt::runtime::system_desc { static ::tt::target::Dim2d toFlatbuffer(const CoreCoord &coreCoord) { @@ -32,7 +34,7 @@ static ::tt::target::Arch toFlatbuffer(::tt::ARCH arch) { break; } - throw std::runtime_error("Unsupported arch"); + LOG_FATAL("Unsupported arch"); } static std::vector<::tt::target::ChipChannel> @@ -246,7 +248,7 @@ static std::unique_ptr<::tt::runtime::SystemDesc> getCurrentSystemDescImpl( ::tt::target::FinishSizePrefixedSystemDescRootBuffer(fbb, root); ::flatbuffers::Verifier verifier(fbb.GetBufferPointer(), fbb.GetSize()); if (!::tt::target::VerifySizePrefixedSystemDescRootBuffer(verifier)) { - throw std::runtime_error("Failed to verify system desc root buffer"); + LOG_FATAL("Failed to verify system desc root buffer"); } uint8_t *buf = fbb.GetBufferPointer(); auto size = fbb.GetSize(); diff --git a/runtime/lib/common/workarounds.cpp b/runtime/lib/common/workarounds.cpp index cd2795d023..a9dbf7564a 100644 --- a/runtime/lib/common/workarounds.cpp +++ b/runtime/lib/common/workarounds.cpp @@ -6,12 +6,10 @@ namespace tt::runtime::workaround { #if defined(TT_RUNTIME_WORKAROUNDS) && TT_RUNTIME_WORKAROUNDS == 1 -const Env &Env::get(bool ignoreTileShape, bool emptyOpForceRowMajor, - bool fullOpForceRowMajor, bool maxpool2dPreshard, - bool swapBinaryOperands) { - static const Env config(ignoreTileShape, emptyOpForceRowMajor, - fullOpForceRowMajor, maxpool2dPreshard, - swapBinaryOperands); +const Env &Env::get(bool maxpool2dPreshard, bool swapBinaryOperands, + bool readUpdateIndexFromDeviceForKVCache) { + static const Env config(maxpool2dPreshard, swapBinaryOperands, + readUpdateIndexFromDeviceForKVCache); return config; } #endif diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp index 586b8394ea..2da673ad19 100644 --- a/runtime/lib/runtime.cpp +++ b/runtime/lib/runtime.cpp @@ -42,7 +42,7 @@ void deallocateBuffers(Device device) { return ::tt::runtime::ttmetal::deallocateBuffers(device); } #endif - throw std::runtime_error("runtime is not enabled"); + LOG_FATAL("runtime is not enabled"); } } // namespace detail @@ -91,15 +91,14 @@ void setCompatibleRuntime(const Binary &binary) { return setCurrentRuntime(DeviceRuntime::TTMetal); } #endif - throw std::runtime_error( - "Unsupported binary file identifier or runtime not enabled"); + LOG_FATAL("Unsupported binary file identifier or runtime not enabled"); } std::pair getCurrentSystemDesc() { #if defined(TT_RUNTIME_ENABLE_TTNN) || defined(TT_RUNTIME_ENABLE_TTMETAL) return system_desc::getCurrentSystemDesc(); #endif - throw std::runtime_error("runtime is not enabled"); + LOG_FATAL("runtime is not enabled"); } Tensor createTensor(std::shared_ptr data, @@ -122,7 +121,7 @@ Tensor createTensor(std::shared_ptr data, dataType); } #endif - throw std::runtime_error("runtime is not enabled"); + LOG_FATAL("runtime is not enabled"); } Tensor @@ -143,10 +142,32 @@ createTensor(std::vector> &data, #if defined(TT_RUNTIME_ENABLE_TTMETAL) if (getCurrentRuntime() == DeviceRuntime::TTMetal) { - throw std::runtime_error("Not implemented"); + LOG_FATAL("Not implemented"); } #endif - throw std::runtime_error("runtime is not enabled"); + LOG_FATAL("runtime is not enabled"); +} + +Tensor createTensor(Device device, Layout layout, + std::vector const &shape, + std::vector const &stride, + std::uint32_t itemsize) { + LOG_ASSERT(not shape.empty()); + LOG_ASSERT(not stride.empty()); + LOG_ASSERT(itemsize > 0); +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::createTensor(device, layout, shape, stride, + itemsize); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + LOG_FATAL("Not implemented"); + } +#endif + LOG_FATAL("runtime is not enabled"); } tt::target::DataType getTensorDataType(Tensor tensor) { @@ -161,7 +182,7 @@ tt::target::DataType getTensorDataType(Tensor tensor) { return ::tt::runtime::ttmetal::getTensorDataType(tensor); } #endif - throw std::runtime_error("runtime is not enabled"); + LOG_FATAL("runtime is not enabled"); } size_t getNumAvailableDevices() { @@ -176,7 +197,7 @@ size_t getNumAvailableDevices() { return ::tt::runtime::ttmetal::getNumAvailableDevices(); } #endif - throw std::runtime_error("runtime is not enabled"); + LOG_FATAL("runtime is not enabled"); } Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs) { @@ -191,7 +212,7 @@ Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs) { return ::tt::runtime::ttmetal::openDevice(deviceIds, numHWCQs); } #endif - throw std::runtime_error("runtime is not enabled"); + LOG_FATAL("runtime is not enabled"); } void closeDevice(Device device) { @@ -206,44 +227,145 @@ void closeDevice(Device device) { return ::tt::runtime::ttmetal::closeDevice(device); } #endif - throw std::runtime_error("runtime is not enabled"); + LOG_FATAL("runtime is not enabled"); } -Event submit(Device deviceHandle, Binary executableHandle, - std::uint32_t programIndex, - std::vector const &inputHandles, - std::vector const &outputHandles) { +void wait(Event event) { #if defined(TT_RUNTIME_ENABLE_TTNN) if (getCurrentRuntime() == DeviceRuntime::TTNN) { - return ::tt::runtime::ttnn::submit(deviceHandle, executableHandle, - programIndex, inputHandles, - outputHandles); + LOG_WARNING("wait API will be deprecated for TTNN runtime."); + return ::tt::runtime::ttnn::wait(event); } #endif #if defined(TT_RUNTIME_ENABLE_TTMETAL) if (getCurrentRuntime() == DeviceRuntime::TTMetal) { - return ::tt::runtime::ttmetal::submit(deviceHandle, executableHandle, - programIndex, inputHandles, - outputHandles); + return ::tt::runtime::ttmetal::wait(event); } #endif - throw std::runtime_error("runtime is not enabled"); + LOG_FATAL("runtime is not enabled"); } -void wait(Event event) { +void wait(Tensor tensor) { #if defined(TT_RUNTIME_ENABLE_TTNN) if (getCurrentRuntime() == DeviceRuntime::TTNN) { - return ::tt::runtime::ttnn::wait(event); + return ::tt::runtime::ttnn::wait(tensor); } #endif #if defined(TT_RUNTIME_ENABLE_TTMETAL) if (getCurrentRuntime() == DeviceRuntime::TTMetal) { - return ::tt::runtime::ttmetal::wait(event); + return ::tt::runtime::ttmetal::wait(tensor); } #endif - throw std::runtime_error("runtime is not enabled"); + LOG_FATAL("runtime is not enabled"); +} + +void wait(std::vector const &tensors) { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::wait(tensors); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + return ::tt::runtime::ttmetal::wait(tensors); + } +#endif + LOG_FATAL("runtime is not enabled"); +} + +Tensor toHost(Tensor tensor, bool untilize) { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::toHost(tensor, untilize); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + LOG_FATAL("not implemented"); + } +#endif + LOG_FATAL("runtime is not enabled"); +} + +Tensor toLayout(Tensor tensor, Device device, Layout layout) { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::toLayout(tensor, device, layout); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + LOG_FATAL("not implemented"); + } +#endif + LOG_FATAL("runtime is not enabled"); +} + +Layout getLayout(Binary executableHandle, std::uint32_t programIndex, + std::uint32_t inputIndex) { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::getLayout(executableHandle, programIndex, + inputIndex); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + LOG_FATAL("not implemented"); + } +#endif + LOG_FATAL("runtime is not enabled"); +} + +void memcpy(void *dst, Tensor src) { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::memcpy(dst, src); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + LOG_FATAL("not implemented"); + } +#endif + LOG_FATAL("runtime is not enabled"); +} + +void memcpy(Tensor dst, Tensor src) { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::memcpy(dst, src); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + LOG_FATAL("not implemented"); + } +#endif + LOG_FATAL("runtime is not enabled"); +} + +void deallocateTensor(Tensor &tensor, bool force) { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::deallocateTensor(tensor, force); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + LOG_FATAL("not implemented"); + } +#endif + LOG_FATAL("runtime is not enabled"); } std::string getOpDebugString(OpContext opContextHandle) { @@ -257,6 +379,21 @@ std::string getOpDebugString(OpContext opContextHandle) { if (getCurrentRuntime() == DeviceRuntime::TTMetal) { return ::tt::runtime::ttmetal::getOpDebugString(opContextHandle); } +#endif + LOG_FATAL("runtime is not enabled"); +} + +std::string getOpLocInfo(OpContext opContextHandle) { +#ifdef TT_RUNTIME_ENABLE_TTNN + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::getOpLocInfo(opContextHandle); + } +#endif + +#ifdef TT_RUNTIME_ENABLE_TTMETAL + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + return ::tt::runtime::ttmetal::getOpLocInfo(opContextHandle); + } #endif throw std::runtime_error("runtime is not enabled"); } @@ -276,7 +413,7 @@ Tensor getOpOutputTensor(OpContext opContextHandle, programContextHandle); } #endif - throw std::runtime_error("runtime is not enabled"); + LOG_FATAL("runtime is not enabled"); } std::vector getTensorData(Tensor tensor) { @@ -292,7 +429,48 @@ std::vector getTensorData(Tensor tensor) { } #endif - throw std::runtime_error("runtime is not enabled"); + LOG_FATAL("runtime is not enabled"); } +std::vector submit(Device deviceHandle, Binary executableHandle, + std::uint32_t programIndex, + std::vector const &inputHandles) { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::submit(deviceHandle, executableHandle, + programIndex, inputHandles); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + LOG_FATAL("not implemented"); + } +#endif + LOG_FATAL("runtime is not enabled"); +} + +Event submit(Device deviceHandle, Binary executableHandle, + std::uint32_t programIndex, + std::vector const &inputHandles, + std::vector const &outputHandles) { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + LOG_WARNING("This submit API will soon be deprecated. Please switch to the " + "new API."); + return ::tt::runtime::ttnn::legacy::submit(deviceHandle, executableHandle, + programIndex, inputHandles, + outputHandles); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + return ::tt::runtime::ttmetal::submit(deviceHandle, executableHandle, + programIndex, inputHandles, + outputHandles); + } +#endif + LOG_FATAL("runtime is not enabled"); +} } // namespace tt::runtime diff --git a/runtime/lib/ttmetal/CMakeLists.txt b/runtime/lib/ttmetal/CMakeLists.txt index 3706d74333..f31fad2530 100644 --- a/runtime/lib/ttmetal/CMakeLists.txt +++ b/runtime/lib/ttmetal/CMakeLists.txt @@ -10,7 +10,7 @@ target_include_directories(TTRuntimeTTMetal PUBLIC ${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common ) target_include_directories(TTRuntimeTTMetal SYSTEM PUBLIC "$") -target_link_libraries(TTRuntimeTTMetal PUBLIC TTMETAL_LIBRARY) -add_dependencies(TTRuntimeTTMetal TTMETAL_LIBRARY tt-metal FBS_GENERATION) +target_link_libraries(TTRuntimeTTMetal PUBLIC TTMETAL_LIBRARY DEVICE_LIBRARY) +add_dependencies(TTRuntimeTTMetal TTMETAL_LIBRARY DEVICE_LIBRARY tt-metal FBS_GENERATION) # Optionally compile profiling code and link tracy client for perf profiling. diff --git a/runtime/lib/ttmetal/command_queue.cpp b/runtime/lib/ttmetal/command_queue.cpp index 9a408a66b1..3480458e6a 100644 --- a/runtime/lib/ttmetal/command_queue.cpp +++ b/runtime/lib/ttmetal/command_queue.cpp @@ -137,7 +137,7 @@ void CQExecutor::execute(::tt::target::metal::Command const *command) { break; } default: - throw std::runtime_error("Unsupported command type"); + LOG_FATAL("Unsupported command type"); break; } } @@ -328,7 +328,7 @@ createKernelConfig(::tt::target::metal::KernelSource const *kernelSource) { break; } } - throw std::runtime_error("Unsupported kernel source type"); + LOG_FATAL("Unsupported kernel source type"); } static ::tt::DataFormat toDataFormat(::tt::target::DataType dataType) { @@ -346,7 +346,7 @@ static ::tt::DataFormat toDataFormat(::tt::target::DataType dataType) { case ::tt::target::DataType::UInt8: return ::tt::DataFormat::UInt8; default: - throw std::runtime_error("Unsupported data type"); + LOG_FATAL("Unsupported data type"); } } @@ -358,7 +358,7 @@ static CoreType toCoreType(::tt::target::metal::CoreType coreType) { case ::tt::target::metal::CoreType::ETH: return CoreType::ETH; } - throw std::runtime_error("Unsupported core type"); + LOG_FATAL("Unsupported core type"); } static ::tt::tt_metal::CircularBufferConfig createCircularBufferConfig( @@ -427,7 +427,7 @@ static void processRuntimeArgs( break; } case ::tt::target::metal::RuntimeArg::NONE: - throw std::runtime_error("Unsupported runtime arg type"); + LOG_FATAL("Unsupported runtime arg type"); } } @@ -516,7 +516,7 @@ void CQExecutor::execute( break; } default: - throw std::runtime_error("Unsupported HostBuffer type"); + LOG_FATAL("Unsupported HostBuffer type"); } } @@ -524,7 +524,7 @@ void CQExecutor::execute( ::tt::target::metal::EnqueueReadBufferCommand const *command) { ZoneScopedN("EnqueueReadBufferCommand"); // Maybe we will need this in the future, like paging to system mem? - throw std::runtime_error("Unsupported EnqueueReadBufferCommand"); + LOG_FATAL("Unsupported EnqueueReadBufferCommand"); } void CQExecutor::execute( diff --git a/runtime/lib/ttmetal/runtime.cpp b/runtime/lib/ttmetal/runtime.cpp index ab343554ed..2a66aa5e65 100644 --- a/runtime/lib/ttmetal/runtime.cpp +++ b/runtime/lib/ttmetal/runtime.cpp @@ -24,7 +24,7 @@ static ::tt::target::metal::TTMetalBinary const *getBinary(Flatbuffer binary) { ::tt::target::metal::SizePrefixedTTMetalBinaryBufferHasIdentifier( binary.handle.get()); if (not isTTMetal) { - throw std::runtime_error("Unsupported binary format"); + LOG_FATAL("Unsupported binary format"); } return ::tt::target::metal::GetSizePrefixedTTMetalBinary(binary.handle.get()); } @@ -56,7 +56,7 @@ tt::target::DataType getTensorDataType(Tensor tensor) { } if (std::holds_alternative>( metalTensor)) { - throw std::runtime_error("Datatype mapping from buffer not supported yet."); + LOG_FATAL("Datatype mapping from buffer not supported yet."); } LOG_ASSERT(false, "Unsupported tensor type"); return ::tt::target::DataType::Float32; @@ -96,6 +96,21 @@ void deallocateBuffers(Device deviceHandle) { } } +void wait(Event event) { + Events events = event.as(DeviceRuntime::TTMetal); + for (auto e : events) { + ::tt::tt_metal::EventSynchronize(e); + } +} + +void wait(Tensor tensor) { ::tt::runtime::ttmetal::wait(tensor.event); } + +void wait(std::vector const &tensors) { + for (Tensor tensor : tensors) { + ::tt::runtime::ttmetal::wait(tensor); + } +} + static std::pair, std::shared_ptr<::tt::tt_metal::Event>> prepareInput(::tt::tt_metal::Device *device, MetalTensor const &metalTensor, @@ -117,7 +132,7 @@ prepareInput(::tt::tt_metal::Device *device, MetalTensor const &metalTensor, metalTensor)) { std::shared_ptr<::tt::tt_metal::Buffer> buffer = std::get>(metalTensor); - throw std::runtime_error("Input from buffer not supported yet"); + LOG_FATAL("Input from buffer not supported yet"); } LOG_ASSERT(false, "Unsupported tensor type"); return std::make_pair(nullptr, nullptr); @@ -249,19 +264,18 @@ Event submit(Device deviceHandle, Binary executableHandle, return Event(static_pointer_cast(events), DeviceRuntime::TTMetal); } -void wait(Event event) { - Events events = event.as(DeviceRuntime::TTMetal); - for (auto e : events) { - ::tt::tt_metal::EventSynchronize(e); - } -} - std::string getOpDebugString(OpContext opContextHandle) { // Not implemented LOG_WARNING("obtaining op debug string for metal runtime not implemented"); return ""; } +std::string getOpLocInfo(OpContext opContextHandle) { + // Not implemented + LOG_WARNING("obtaining op location info for metal runtime not implemented"); + return ""; +} + Tensor getOpOutputTensor(OpContext opContextHandle, CallbackContext programContextHandle) { // Not implemented diff --git a/runtime/lib/ttnn/CMakeLists.txt b/runtime/lib/ttnn/CMakeLists.txt index 92581cf46f..6a68c4c7b9 100644 --- a/runtime/lib/ttnn/CMakeLists.txt +++ b/runtime/lib/ttnn/CMakeLists.txt @@ -1,4 +1,22 @@ +add_library(TTRuntimeTTNNHelpers + STATIC + ${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/types.cpp +) +set_property(TARGET TTRuntimeTTNNHelpers PROPERTY CXX_STANDARD 20) +target_compile_options(TTRuntimeTTNNHelpers PUBLIC -mavx -mavx2 -fsized-deallocation) +target_include_directories(TTRuntimeTTNNHelpers PUBLIC + ${PROJECT_SOURCE_DIR}/runtime/include + ${PROJECT_SOURCE_DIR}/runtime/lib/ttnn/include + ${PROJECT_SOURCE_DIR}/runtime/lib/ttnn/operations/include + ${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common +) +target_include_directories(TTRuntimeTTNNHelpers SYSTEM PUBLIC "$") +add_dependencies(TTRuntimeTTNNHelpers TTNN_LIBRARY tt-metal FBS_GENERATION) +target_link_libraries(TTRuntimeTTNNHelpers PUBLIC TTNN_LIBRARY) + add_subdirectory(operations) + add_library(TTRuntimeTTNN STATIC runtime.cpp @@ -11,5 +29,5 @@ target_include_directories(TTRuntimeTTNN PUBLIC ${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common ) target_include_directories(TTRuntimeTTNN SYSTEM PUBLIC "$") -target_link_libraries(TTRuntimeTTNN PUBLIC TTRuntimeTTNNOps) +target_link_libraries(TTRuntimeTTNN PUBLIC TTRuntimeTTNNOps TTRuntimeTTNNHelpers) add_dependencies(TTRuntimeTTNN TTRuntimeTTNNOps) diff --git a/runtime/lib/ttnn/include/tt/runtime/ttnn/types.cpp b/runtime/lib/ttnn/include/tt/runtime/ttnn/types.cpp new file mode 100644 index 0000000000..87d0815992 --- /dev/null +++ b/runtime/lib/ttnn/include/tt/runtime/ttnn/types.cpp @@ -0,0 +1,437 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "tt/runtime/ttnn/types.h" +#include "tt/runtime/detail/logger.h" +#include "tt/runtime/ttnn/utils.h" + +namespace tt::runtime::ttnn { + +// +// LayoutConverter APIs +// +LayoutConverter::LayoutConverter(const LayoutDesc &inputDesc, + const LayoutDesc &outputDesc) + : inputDesc(inputDesc), outputDesc(outputDesc) { + shouldTilize = (inputDesc.layout == ::ttnn::Layout::ROW_MAJOR and + outputDesc.layout == ::ttnn::Layout::TILE); + shouldUntilize = (inputDesc.layout == ::ttnn::Layout::TILE and + outputDesc.layout == ::ttnn::Layout::ROW_MAJOR); + shouldTypecast = (inputDesc.dataType != outputDesc.dataType); + shouldToDevice = (inputDesc.isOnHost() and outputDesc.isOnDevice()); + shouldToMemoryConfig = (not shouldToDevice and outputDesc.isOnDevice() and + (inputDesc.memoryConfig != outputDesc.memoryConfig)); + shouldFromDevice = (inputDesc.isOnDevice() and outputDesc.isOnHost()); +} + +::ttnn::Tensor LayoutConverter::convertTensorLayout( + const ::ttnn::Tensor &input, std::optional targetDevice) { + if (inputDesc.isOnHost()) { + return convertHostTensorLayout(input, targetDevice); + } + return convertDeviceTensorLayout(input); +} + +::ttnn::Tensor LayoutConverter::toLayoutIfNeeded(const ::ttnn::Tensor &input) { + if (shouldTilize) { + return ::ttnn::to_layout(input, ::ttnn::Layout::TILE, std::nullopt, + std::nullopt, + static_cast<::ttnn::Device *>(nullptr)); + } + if (shouldUntilize) { + return ::ttnn::to_layout(input, ::ttnn::Layout::ROW_MAJOR, std::nullopt, + std::nullopt, + static_cast<::ttnn::Device *>(nullptr)); + } + return input; +} + +::ttnn::Tensor LayoutConverter::typecastIfNeeded(const ::ttnn::Tensor &input) { + if (shouldTypecast) { + return ::ttnn::typecast(input, outputDesc.dataType); + } + return input; +} + +::ttnn::Tensor +LayoutConverter::toDeviceIfNeeded(const ::ttnn::Tensor &input, + std::optional targetDevice, + bool force) { + if (shouldToDevice or force) { + LOG_ASSERT(targetDevice.has_value()); + return std::visit( + [&](auto &&targetDevice) -> ::ttnn::Tensor { + return ::ttnn::to_device(input, &(targetDevice.get()), + outputDesc.memoryConfig); + }, + targetDevice.value()); + } + return input; +} + +::ttnn::Tensor +LayoutConverter::toMemoryConfigIfNeeded(const ::ttnn::Tensor &input) { + if (shouldToMemoryConfig) { + LOG_ASSERT(outputDesc.memoryConfig.has_value()); + return ::ttnn::to_memory_config(input, outputDesc.memoryConfig.value()); + } + return input; +} + +::ttnn::Tensor +LayoutConverter::fromDeviceIfNeeded(const ::ttnn::Tensor &input) { + if (shouldFromDevice) { + return ::ttnn::from_device(input); + } + return input; +} + +::ttnn::Tensor LayoutConverter::handleHostInputNoLayoutNoTypecast( + const ::ttnn::Tensor &input, std::optional targetDevice) { + ::ttnn::Tensor out = toDeviceIfNeeded(input, targetDevice); + out = toMemoryConfigIfNeeded(out); + return out; +} + +::ttnn::Tensor LayoutConverter::handleHostInputLayoutNoTypecast( + const ::ttnn::Tensor &input, std::optional targetDevice) { + if (shouldUntilize) { + ::ttnn::Tensor out = toLayoutIfNeeded(input); + out = toDeviceIfNeeded(out, targetDevice); + out = toMemoryConfigIfNeeded(out); + return out; + } + + if (shouldTilize and outputDesc.dataType == ::ttnn::DataType::BFLOAT16) { + ::ttnn::Tensor out = toDeviceIfNeeded(input, targetDevice); + out = toLayoutIfNeeded(out); + out = toMemoryConfigIfNeeded(out); + return out; + } + + if (shouldTilize and outputDesc.dataType != ::ttnn::DataType::BFLOAT16) { + ::ttnn::Tensor out = toLayoutIfNeeded(input); + out = toDeviceIfNeeded(out, targetDevice); + out = toMemoryConfigIfNeeded(out); + return out; + } + LOG_FATAL("Unreachable code path"); +} + +::ttnn::Tensor LayoutConverter::handleHostInputNoLayoutTypecast( + const ::ttnn::Tensor &input, std::optional targetDevice) { + if (outputDesc.layout == ::ttnn::Layout::TILE) { + ::ttnn::Tensor out = toDeviceIfNeeded(input, targetDevice); + out = typecastIfNeeded(out); + out = toMemoryConfigIfNeeded(out); + return out; + } + + if (outputDesc.layout != ::ttnn::Layout::TILE) { + ::ttnn::Tensor out = typecastIfNeeded(input); + out = toDeviceIfNeeded(out, targetDevice); + out = toMemoryConfigIfNeeded(out); + return out; + } + LOG_FATAL("Unreachable code path"); +} + +::ttnn::Tensor LayoutConverter::handleHostInputLayoutTypecast( + const ::ttnn::Tensor &input, std::optional targetDevice) { + if (shouldUntilize) { + ::ttnn::Tensor out = typecastIfNeeded(input); + out = toLayoutIfNeeded(out); + out = toDeviceIfNeeded(out, targetDevice); + out = toMemoryConfigIfNeeded(out); + return out; + } + + if (shouldTilize and inputDesc.dataType == ::ttnn::DataType::BFLOAT16) { + ::ttnn::Tensor out = toDeviceIfNeeded(input, targetDevice); + out = toLayoutIfNeeded(out); + out = typecastIfNeeded(out); + out = toMemoryConfigIfNeeded(out); + return out; + } + + if (shouldTilize and outputDesc.dataType == ::ttnn::DataType::BFLOAT16) { + ::ttnn::Tensor out = typecastIfNeeded(input); + out = toDeviceIfNeeded(out, targetDevice); + out = toLayoutIfNeeded(input); + out = toMemoryConfigIfNeeded(out); + return out; + } + + if (shouldTilize and inputDesc.dataType != ::ttnn::DataType::BFLOAT16 and + outputDesc.dataType != ::ttnn::DataType::BFLOAT16) { + ::ttnn::Tensor out = typecastIfNeeded(input); + out = toLayoutIfNeeded(out); + out = toDeviceIfNeeded(out, targetDevice); + out = toMemoryConfigIfNeeded(out); + return out; + } + + LOG_FATAL("Unreachable code path"); +} + +::ttnn::Tensor LayoutConverter::convertHostTensorLayout( + const ::ttnn::Tensor &input, std::optional targetDevice) { + bool shouldToLayout = (shouldTilize or shouldUntilize); + LOG_ASSERT(not shouldToDevice or targetDevice.has_value(), + "Target device must be provided for ToDevice"); + if (not shouldToLayout and not shouldTypecast) { + return handleHostInputNoLayoutNoTypecast(input, targetDevice); + } + if (shouldToLayout and not shouldTypecast) { + return handleHostInputLayoutNoTypecast(input, targetDevice); + } + if (not shouldToLayout and shouldTypecast) { + return handleHostInputNoLayoutTypecast(input, targetDevice); + } + if (shouldToLayout and shouldTypecast) { + return handleHostInputLayoutTypecast(input, targetDevice); + } + LOG_FATAL("Unreachable code path"); +} + +::ttnn::Tensor LayoutConverter::handleDeviceInputNoLayoutNoTypecast( + const ::ttnn::Tensor &input) { + ::ttnn::Tensor out = toMemoryConfigIfNeeded(input); + out = fromDeviceIfNeeded(out); + return out; +} + +::ttnn::Tensor LayoutConverter::handleDeviceInputLayoutNoTypecast( + const ::ttnn::Tensor &input) { + if (shouldUntilize and shouldFromDevice) { + ::ttnn::Tensor out = fromDeviceIfNeeded(input); + out = toLayoutIfNeeded(out); + return out; + } + + if (shouldUntilize and not shouldFromDevice) { + LOG_WARNING("Currently no constraint checking for on-device untilize."); + ::ttnn::Tensor out = toLayoutIfNeeded(input); + out = toMemoryConfigIfNeeded(out); + return out; + } + + /* If we should tilize and the input data type is bfloat16, tilize on device + */ + if (shouldTilize and inputDesc.dataType == ::ttnn::DataType::BFLOAT16) { + ::ttnn::Tensor out = toLayoutIfNeeded(input); + out = toMemoryConfigIfNeeded(out); + out = fromDeviceIfNeeded(out); + return out; + } + + /* If we should tilize and the input data type is not bfloat16, tilize on + * host */ + if (shouldTilize and inputDesc.dataType != ::ttnn::DataType::BFLOAT16 and + shouldFromDevice) { + ::ttnn::Tensor out = fromDeviceIfNeeded(input); + out = toLayoutIfNeeded(out); + return out; + } + + if (shouldTilize and inputDesc.dataType != ::ttnn::DataType::BFLOAT16 and + not shouldFromDevice) { + LOG_WARNING("Currently no constraint checking for on-device tilize."); + ::ttnn::Tensor out = toLayoutIfNeeded(input); + out = toMemoryConfigIfNeeded(out); + return out; + } + + LOG_FATAL("Unreachable code path"); +} + +::ttnn::Tensor LayoutConverter::handleDeviceInputNoLayoutTypecast( + const ::ttnn::Tensor &input) { + if (inputDesc.isTilized()) { + ::ttnn::Tensor out = typecastIfNeeded(input); + out = toMemoryConfigIfNeeded(out); + out = fromDeviceIfNeeded(input); + return out; + } + + if (not inputDesc.isTilized() and shouldFromDevice) { + ::ttnn::Tensor out = fromDeviceIfNeeded(input); + out = typecastIfNeeded(out); + return out; + } + + if (not inputDesc.isTilized() and not shouldFromDevice) { + LOG_WARNING("Currently no constraint checking for on-device typecast."); + ::ttnn::Tensor out = typecastIfNeeded(input); + out = toMemoryConfigIfNeeded(out); + return out; + } + LOG_FATAL("Unreachable code path"); +} + +::ttnn::Tensor +LayoutConverter::handleDeviceInputLayoutTypecast(const ::ttnn::Tensor &input) { + if (shouldUntilize and shouldFromDevice) { + ::ttnn::Tensor out = typecastIfNeeded(input); + out = fromDeviceIfNeeded(input); + out = toLayoutIfNeeded(out); + return out; + } + + if (shouldUntilize and not shouldFromDevice) { + LOG_WARNING("Currently no constraint checking for on-device untilize."); + ::ttnn::Tensor out = typecastIfNeeded(input); + out = toLayoutIfNeeded(input); + out = toMemoryConfigIfNeeded(out); + return out; + } + + if (shouldTilize and inputDesc.dataType == ::ttnn::DataType::BFLOAT16) { + ::ttnn::Tensor out = toLayoutIfNeeded(input); + out = typecastIfNeeded(out); + out = toMemoryConfigIfNeeded(out); + out = fromDeviceIfNeeded(out); + return out; + } + + if (shouldTilize and inputDesc.dataType != ::ttnn::DataType::BFLOAT16 and + shouldFromDevice) { + ::ttnn::Tensor out = fromDeviceIfNeeded(input); + out = toLayoutIfNeeded(out); + out = typecastIfNeeded(out); + return out; + } + + if (shouldTilize and inputDesc.dataType != ::ttnn::DataType::BFLOAT16 and + not shouldFromDevice) { + LOG_WARNING("Currently no constraint checking for on-device tilize."); + ::ttnn::Tensor out = toLayoutIfNeeded(input); + out = typecastIfNeeded(out); + out = toMemoryConfigIfNeeded(out); + return out; + } + + LOG_FATAL("Unreachable code path"); +} + +::ttnn::Tensor +LayoutConverter::convertDeviceTensorLayout(const ::ttnn::Tensor &input) { + bool shouldToLayout = (shouldTilize or shouldUntilize); + if (not shouldToLayout and not shouldTypecast) { + return handleDeviceInputNoLayoutNoTypecast(input); + } + if (shouldToLayout and not shouldTypecast) { + return handleDeviceInputLayoutNoTypecast(input); + } + if (not shouldToLayout and shouldTypecast) { + return handleDeviceInputNoLayoutTypecast(input); + } + if (shouldToLayout and shouldTypecast) { + return handleDeviceInputLayoutTypecast(input); + } + LOG_FATAL("Unreachable code path"); +} + +// +// ProgramTensorPool APIs +// +std::pair::iterator, bool> +ProgramTensorPool::try_emplace(std::uint32_t globalId, + const ::ttnn::Tensor &tensor) { + auto it = liveTensors.find(globalId); + if (it != liveTensors.end()) { + return std::make_pair(it, false); + } + LOG_ASSERT(!intermedTensors.contains(globalId)); + intermedTensors.try_emplace(globalId, tensor); + return liveTensors.try_emplace(globalId, &intermedTensors.at(globalId)); +} + +std::pair::iterator, bool> +ProgramTensorPool::insert_or_assign(std::uint32_t globalId, + const ::ttnn::Tensor &tensor) { + intermedTensors.insert_or_assign(globalId, tensor); + return liveTensors.insert_or_assign(globalId, &intermedTensors.at(globalId)); +} + +::ttnn::Tensor &ProgramTensorPool::at(std::uint32_t globalId) { + LOG_ASSERT(liveTensors.contains(globalId)); + return *liveTensors.at(globalId); +} + +const ::ttnn::Tensor &ProgramTensorPool::at(std::uint32_t globalId) const { + LOG_ASSERT(liveTensors.contains(globalId)); + return *liveTensors.at(globalId); +} + +size_t ProgramTensorPool::erase(std::uint32_t globalId) { + LOG_ASSERT(liveTensors.contains(globalId) && + intermedTensors.contains(globalId)); + intermedTensors.erase(globalId); + return liveTensors.erase(globalId); +} + +std::vector ProgramTensorPool::gatherOutputTensors() { + std::vector outputTensors; + outputTensors.reserve(programOutputs.size()); + std::transform( + programOutputs.begin(), programOutputs.end(), + std::back_inserter(outputTensors), [this](uint32_t outputGlobalId) { + return utils::createRuntimeTensorFromTTNN(this->at(outputGlobalId)); + }); + return outputTensors; +} + +// +// ProgramContext APIs +// +ProgramContext::ProgramContext( + const std::unordered_map &liveTensors, + const std::vector &programInputs, + const std::vector &programOutputs, ::ttnn::MeshDevice *parentMesh) + : tensorPool(ProgramTensorPool(liveTensors, programInputs, programOutputs)), + parentMesh(parentMesh) { + LOG_ASSERT(parentMesh, "Parent mesh cannot be null"); +} + +void ProgramContext::addSubMesh(uint32_t meshId, + std::shared_ptr<::ttnn::MeshDevice> subMesh) { + auto [it, inserted] = subMeshes.try_emplace(meshId, subMesh); + LOG_ASSERT(inserted, "Submesh already exists"); +} + +::ttnn::MeshDevice &ProgramContext::getSubMesh(uint32_t meshId) { + LOG_ASSERT(subMeshes.contains(meshId)); + return *subMeshes.at(meshId); +} + +size_t ProgramContext::subMeshSize(uint32_t meshId) const { + LOG_ASSERT(subMeshes.contains(meshId)); + return subMeshes.at(meshId)->num_devices(); +} + +::ttnn::Device &ProgramContext::getDeviceFromSubMesh(uint32_t meshId, + int physicalDeviceId) { + LOG_ASSERT(subMeshes.contains(meshId)); + auto &subMesh = *subMeshes.at(meshId); + return *subMesh.get_device(physicalDeviceId); +} + +::ttnn::Device &ProgramContext::getDeviceIndexFromSubMesh(uint32_t meshId, + int deviceIndex) { + LOG_ASSERT(subMeshes.contains(meshId)); + auto &subMesh = *subMeshes.at(meshId); + return *subMesh.get_device_index(deviceIndex); +} + +DeviceVariant ProgramContext::getTargetDevice(uint32_t meshId) { + LOG_ASSERT(subMeshes.contains(meshId)); + auto &subMesh = *subMeshes.at(meshId); + if (subMesh.num_devices() == 1) { + return std::ref(*subMesh.get_device_index(0)); + } + return std::ref(subMesh); +} + +} // namespace tt::runtime::ttnn diff --git a/runtime/lib/ttnn/include/tt/runtime/ttnn/types.h b/runtime/lib/ttnn/include/tt/runtime/ttnn/types.h index 5cd08c7ed0..a5ca800c33 100644 --- a/runtime/lib/ttnn/include/tt/runtime/ttnn/types.h +++ b/runtime/lib/ttnn/include/tt/runtime/ttnn/types.h @@ -6,18 +6,88 @@ #define TT_RUNTIME_TTNN_TYPES_H #include "tt/runtime/detail/ttnn.h" +#include "tt/runtime/types.h" +#include +#include namespace tt::runtime::ttnn { - -using TensorMap = std::unordered_map; using DeviceVariant = std::variant, std::reference_wrapper<::ttnn::MeshDevice>>; +struct LayoutDesc { + ::ttnn::BufferType bufferType; + ::ttnn::Layout layout; + ::ttnn::DataType dataType; + std::optional<::ttnn::MemoryConfig> memoryConfig; + + LayoutDesc(const ::ttnn::BufferType &bufferType, const ::ttnn::Layout &layout, + const ::ttnn::DataType &dataType, + const std::optional<::ttnn::MemoryConfig> &memoryConfig) + : bufferType(bufferType), layout(layout), dataType(dataType), + memoryConfig(memoryConfig) {} + + bool isOnHost() const { + return bufferType == ::ttnn::BufferType::SYSTEM_MEMORY; + } + bool isOnDevice() const { return !isOnHost(); } + + bool isTilized() const { return layout == ::ttnn::Layout::TILE; } +}; + +class LayoutConverter { +public: + LayoutDesc inputDesc; + LayoutDesc outputDesc; + bool shouldTilize = false; + bool shouldUntilize = false; + bool shouldTypecast = false; + bool shouldToDevice = false; + bool shouldToMemoryConfig = false; + bool shouldFromDevice = false; + + LayoutConverter(const LayoutDesc &inputDesc, const LayoutDesc &outputDesc); + ::ttnn::Tensor convertTensorLayout(const ::ttnn::Tensor &input, + std::optional targetDevice); + +private: + ::ttnn::Tensor toLayoutIfNeeded(const ::ttnn::Tensor &input); + ::ttnn::Tensor typecastIfNeeded(const ::ttnn::Tensor &input); + ::ttnn::Tensor toDeviceIfNeeded(const ::ttnn::Tensor &input, + std::optional targetDevice, + bool force = false); + ::ttnn::Tensor toMemoryConfigIfNeeded(const ::ttnn::Tensor &input); + ::ttnn::Tensor fromDeviceIfNeeded(const ::ttnn::Tensor &input); + + ::ttnn::Tensor + handleHostInputNoLayoutNoTypecast(const ::ttnn::Tensor &input, + std::optional targetDevice); + ::ttnn::Tensor + handleHostInputLayoutNoTypecast(const ::ttnn::Tensor &input, + std::optional targetDevice); + ::ttnn::Tensor + handleHostInputNoLayoutTypecast(const ::ttnn::Tensor &input, + std::optional targetDevice); + ::ttnn::Tensor + handleHostInputLayoutTypecast(const ::ttnn::Tensor &input, + std::optional targetDevice); + ::ttnn::Tensor + convertHostTensorLayout(const ::ttnn::Tensor &input, + std::optional targetDevice); + + ::ttnn::Tensor + handleDeviceInputNoLayoutNoTypecast(const ::ttnn::Tensor &input); + ::ttnn::Tensor handleDeviceInputLayoutNoTypecast(const ::ttnn::Tensor &input); + ::ttnn::Tensor handleDeviceInputNoLayoutTypecast(const ::ttnn::Tensor &input); + ::ttnn::Tensor handleDeviceInputLayoutTypecast(const ::ttnn::Tensor &input); + ::ttnn::Tensor convertDeviceTensorLayout(const ::ttnn::Tensor &input); +}; + class ProgramTensorPool { public: - ProgramTensorPool(const TensorMap &liveTensors, - const std::unordered_set &programInputs, - const std::unordered_set &programOutputs) + ProgramTensorPool( + const std::unordered_map &liveTensors, + const std::vector &programInputs, + const std::vector &programOutputs) : programInputs(programInputs), programOutputs(programOutputs), liveTensors(liveTensors) {} ProgramTensorPool(const ProgramTensorPool &) = delete; @@ -25,72 +95,38 @@ class ProgramTensorPool { ProgramTensorPool(ProgramTensorPool &&) = default; ProgramTensorPool &operator=(ProgramTensorPool &&) = default; - auto try_emplace(std::uint32_t globalId, const ::ttnn::Tensor &tensor) { - auto it = liveTensors.find(globalId); - if (it != liveTensors.end()) { - return std::make_pair(it, false); - } - assert(!intermedTensors.contains(globalId)); - intermedTensors.try_emplace(globalId, tensor); - return liveTensors.try_emplace(globalId, &intermedTensors.at(globalId)); - } + std::pair::iterator, bool> + try_emplace(std::uint32_t globalId, const ::ttnn::Tensor &tensor); - auto insert_or_assign(std::uint32_t globalId, const ::ttnn::Tensor &tensor) { - intermedTensors.insert_or_assign(globalId, tensor); - return liveTensors.insert_or_assign(globalId, - &intermedTensors.at(globalId)); - } + std::pair::iterator, bool> + insert_or_assign(std::uint32_t globalId, const ::ttnn::Tensor &tensor); - ::ttnn::Tensor &at(std::uint32_t globalId) { - assert(liveTensors.contains(globalId)); - return *liveTensors.at(globalId); - } + ::ttnn::Tensor &at(std::uint32_t globalId); - const ::ttnn::Tensor &at(std::uint32_t globalId) const { - assert(liveTensors.contains(globalId)); - return *liveTensors.at(globalId); - } + const ::ttnn::Tensor &at(std::uint32_t globalId) const; - size_t erase(std::uint32_t globalId) { - assert(liveTensors.contains(globalId) && - intermedTensors.contains(globalId)); - intermedTensors.erase(globalId); - return liveTensors.erase(globalId); - } + size_t erase(std::uint32_t globalId); - void copyTensorToUserOutput(std::uint32_t outputGlobalId, - const ::ttnn::Tensor &srcTensor) { - assert(liveTensors.contains(outputGlobalId)); - assert(isUserOutput(outputGlobalId)); - ::ttnn::Tensor &outputTensor = *liveTensors.at(outputGlobalId); - void *src = ::tt::tt_metal::get_raw_host_data_ptr(srcTensor); - void *dst = ::tt::tt_metal::get_raw_host_data_ptr(outputTensor); - size_t size = outputTensor.volume() * outputTensor.element_size(); - std::memcpy(dst, src, size); - } + std::vector gatherOutputTensors(); bool contains(std::uint32_t globalId) const { return liveTensors.contains(globalId); } - bool isUserOutput(std::uint32_t globalId) const { - return programOutputs.contains(globalId); - } - - const std::unordered_set &getProgramInputs() const { + const std::vector &getProgramInputs() const { return programInputs; } - const std::unordered_set &getProgramOutputs() const { + const std::vector &getProgramOutputs() const { return programOutputs; } private: - std::unordered_set programInputs; - std::unordered_set programOutputs; + std::vector programInputs; + std::vector programOutputs; // A superset of intermedTensors, containing pointers to all tensors created - // by the program and the input/output tensors passed in by the user - TensorMap liveTensors; + // by the program and the input tensors passed in by the user + std::unordered_map liveTensors; // A subset of liveTensors, containing values of any intermediate tensors // created by the program @@ -99,15 +135,11 @@ class ProgramTensorPool { class ProgramContext { public: - ProgramContext(const TensorMap &liveTensors, - const std::unordered_set &programInputs, - const std::unordered_set &programOutputs, - ::ttnn::MeshDevice *parentMesh) - : tensorPool( - ProgramTensorPool(liveTensors, programInputs, programOutputs)), - parentMesh(parentMesh) { - assert(parentMesh && "Parent mesh cannot be null"); - } + ProgramContext( + const std::unordered_map &liveTensors, + const std::vector &programInputs, + const std::vector &programOutputs, + ::ttnn::MeshDevice *parentMesh); ProgramContext(const ProgramContext &) = delete; ProgramContext &operator=(const ProgramContext &) = delete; ProgramContext(ProgramContext &&) = default; @@ -125,42 +157,17 @@ class ProgramContext { // // Sub Mesh Operations // - void addSubMesh(uint32_t meshId, - std::shared_ptr<::ttnn::MeshDevice> subMesh) { - auto [it, inserted] = subMeshes.try_emplace(meshId, subMesh); - assert(inserted && "Submesh already exists"); - } + void addSubMesh(uint32_t meshId, std::shared_ptr<::ttnn::MeshDevice> subMesh); - ::ttnn::MeshDevice &getSubMesh(uint32_t meshId) { - assert(subMeshes.contains(meshId)); - return *subMeshes.at(meshId); - } + ::ttnn::MeshDevice &getSubMesh(uint32_t meshId); - size_t subMeshSize(uint32_t meshId) const { - assert(subMeshes.contains(meshId)); - return subMeshes.at(meshId)->num_devices(); - } + size_t subMeshSize(uint32_t meshId) const; - ::ttnn::Device &getDeviceFromSubMesh(uint32_t meshId, int physicalDeviceId) { - assert(subMeshes.contains(meshId)); - auto &subMesh = *subMeshes.at(meshId); - return *subMesh.get_device(physicalDeviceId); - } + ::ttnn::Device &getDeviceFromSubMesh(uint32_t meshId, int physicalDeviceId); - ::ttnn::Device &getDeviceIndexFromSubMesh(uint32_t meshId, int deviceIndex) { - assert(subMeshes.contains(meshId)); - auto &subMesh = *subMeshes.at(meshId); - return *subMesh.get_device_index(deviceIndex); - } + ::ttnn::Device &getDeviceIndexFromSubMesh(uint32_t meshId, int deviceIndex); - DeviceVariant getTargetDevice(uint32_t meshId) { - assert(subMeshes.contains(meshId)); - auto &subMesh = *subMeshes.at(meshId); - if (subMesh.num_devices() == 1) { - return std::ref(*subMesh.get_device_index(0)); - } - return std::ref(subMesh); - } + DeviceVariant getTargetDevice(uint32_t meshId); // // Tensor Pool Operations diff --git a/runtime/lib/ttnn/include/tt/runtime/ttnn/utils.cpp b/runtime/lib/ttnn/include/tt/runtime/ttnn/utils.cpp new file mode 100644 index 0000000000..fa8aa82ed2 --- /dev/null +++ b/runtime/lib/ttnn/include/tt/runtime/ttnn/utils.cpp @@ -0,0 +1,222 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "tt/runtime/ttnn/utils.h" +#include "tt/runtime/detail/logger.h" + +namespace tt::runtime::ttnn::utils { + +// TODO (bug #701) +// Currently the memory layout/location in flatbuffer is incorrect +// These methods are workarounds for operations such that we query the info +// directly from the TTNN tensor. Ideally, we should be able to get all of this +// info directly from the flatbuffer +bool isOnHost(const ::ttnn::StorageType &storageType) { + return storageType == ::tt::tt_metal::StorageType::BORROWED or + storageType == ::tt::tt_metal::StorageType::OWNED or + storageType == ::tt::tt_metal::StorageType::MULTI_DEVICE_HOST; +} + +bool isOnDevice(const ::ttnn::StorageType &storageType) { + return storageType == ::tt::tt_metal::StorageType::DEVICE or + storageType == ::tt::tt_metal::StorageType::MULTI_DEVICE; +} + +bool isValidTileShape(const ::tt::target::Dim2d *shape) { + return (shape->x() == 1 and shape->y() == 1) or + (shape->x() == 32 and shape->y() == 32); +} + +::ttnn::DataType toTTNNDataType(::tt::target::DataType dataType) { + switch (dataType) { + case ::tt::target::DataType::Float32: + return ::ttnn::DataType::FLOAT32; + case ::tt::target::DataType::BFloat16: + return ::ttnn::DataType::BFLOAT16; + case ::tt::target::DataType::BFP_BFloat8: + return ::ttnn::DataType::BFLOAT8_B; + case ::tt::target::DataType::BFP_BFloat4: + return ::ttnn::DataType::BFLOAT4_B; + case ::tt::target::DataType::UInt32: + return ::ttnn::DataType::UINT32; + case ::tt::target::DataType::UInt16: + return ::ttnn::DataType::UINT16; + + default: + LOG_FATAL("Unsupported data type"); + } +} + +::tt::target::DataType fromTTNNDataType(::ttnn::DataType dataType) { + switch (dataType) { + case ::ttnn::DataType::FLOAT32: + return ::tt::target::DataType::Float32; + case ::ttnn::DataType::BFLOAT16: + return ::tt::target::DataType::BFloat16; + case ::ttnn::DataType::BFLOAT8_B: + return ::tt::target::DataType::BFP_BFloat8; + case ::ttnn::DataType::BFLOAT4_B: + return ::tt::target::DataType::BFP_BFloat4; + case ::ttnn::DataType::UINT32: + return ::tt::target::DataType::UInt32; + case ::ttnn::DataType::UINT16: + return ::tt::target::DataType::UInt16; + + default: + LOG_FATAL("Unsupported data type"); + } +} + +::ttnn::Layout toTTNNLayout(::tt::target::TensorLayout layout) { + switch (layout) { + case ::tt::target::TensorLayout::Tile: + return ::ttnn::Layout::TILE; + case ::tt::target::TensorLayout::RowMajor: + return ::ttnn::Layout::ROW_MAJOR; + default: + LOG_FATAL("Unsupported layout"); + } +} + +::ttnn::TensorMemoryLayout +toTTNNTensorMemoryLayout(::tt::target::TensorMemoryLayout tensorMemoryLayout) { + + switch (tensorMemoryLayout) { + case ::tt::target::TensorMemoryLayout::Interleaved: + return ::ttnn::TensorMemoryLayout::INTERLEAVED; + case ::tt::target::TensorMemoryLayout::SingleBank: + return ::ttnn::TensorMemoryLayout::SINGLE_BANK; + case ::tt::target::TensorMemoryLayout::HeightSharded: + return ::ttnn::TensorMemoryLayout::HEIGHT_SHARDED; + case ::tt::target::TensorMemoryLayout::WidthSharded: + return ::ttnn::TensorMemoryLayout::WIDTH_SHARDED; + case ::tt::target::TensorMemoryLayout::BlockSharded: + return ::ttnn::TensorMemoryLayout::BLOCK_SHARDED; + case ::tt::target::TensorMemoryLayout::None: + LOG_FATAL("Unsupported tensor memory layout None"); + } +} + +// This method will be deprecated in favor of method below +// +::tt::tt_metal::BufferType +toTTNNBufferType(::tt::target::MemorySpace memorySpace) { + switch (memorySpace) { + case ::tt::target::MemorySpace::System: + case ::tt::target::MemorySpace::SystemMMIO: + return ::tt::tt_metal::BufferType::SYSTEM_MEMORY; + case ::tt::target::MemorySpace::DeviceDRAM: + return ::tt::tt_metal::BufferType::DRAM; + case ::tt::target::MemorySpace::DeviceL1: + return ::tt::tt_metal::BufferType::L1; + } +} + +// Prefer to use this method +// +::ttnn::BufferType toTTNNBufferType(::tt::target::BufferType bufferType) { + + switch (bufferType) { + case ::tt::target::BufferType::DRAM: + return ::ttnn::BufferType::DRAM; + case ::tt::target::BufferType::L1: + return ::ttnn::BufferType::L1; + case ::tt::target::BufferType::SystemMemory: + return ::ttnn::BufferType::SYSTEM_MEMORY; + case ::tt::target::BufferType::L1Small: + return ::ttnn::BufferType::L1_SMALL; + case ::tt::target::BufferType::Trace: + return ::ttnn::BufferType::TRACE; + } +}; + +std::vector +toShapeFromFBShape(const flatbuffers::Vector &vec) { + return std::vector(vec.begin(), vec.end()); +} + +::ttnn::Layout +inferLayoutFromTileShape(const ::tt::target::TensorRef *tensorRef) { + const ::tt::target::Dim2d *tileShape = + tensorRef->desc()->layout()->memory_desc()->tile_shape(); + LOG_ASSERT(isValidTileShape(tileShape)); + if (tileShape->x() == 1 and tileShape->y() == 1) { + return ::ttnn::Layout::ROW_MAJOR; + } + return ::ttnn::Layout::TILE; +} + +CoreRangeSet +toCoreRangeSet(const ::flatbuffers::Vector + *coreRangeSet) { + std::set coreRanges; + for (::tt::target::Dim2dRange const *coreRange : *coreRangeSet) { + CoreCoord start(coreRange->loc().x(), coreRange->loc().y()); + // End is inclusive + CoreCoord end(coreRange->loc().x() + coreRange->size().x() - 1, + coreRange->loc().y() + coreRange->size().y() - 1); + + coreRanges.emplace(start, end); + } + return CoreRangeSet(coreRanges); +} + +::tt::tt_metal::MemoryConfig +createMemoryConfig(const ::tt::target::TensorRef *tensorRef) { + const ::tt::target::LayoutDesc *layout = tensorRef->desc()->layout(); + const ::tt::target::TensorMemoryLayout targetMemoryLayout = + layout->memory_desc()->memory_layout(); + const ::tt::target::MemorySpace targetMemorySpace = + layout->memory_desc()->memory_space(); + const ::flatbuffers::Vector + *targetCoreRangeSet = layout->core_range_set(); + const ::flatbuffers::Vector *targetShardShape = + layout->memory_desc()->shape(); + const ::tt::target::Dim2d *tileShape = layout->memory_desc()->tile_shape(); + + LOG_ASSERT(targetCoreRangeSet->size() == 1, + "Currently only single core range/grid is supported"); + + LOG_ASSERT(targetShardShape->size() == 2, + "Only 2D shard shape is supported in TTNN backend"); + + LOG_ASSERT(::tt::runtime::ttnn::utils::isValidTileShape(tileShape), + "Invalid tile shape"); + + CoreRangeSet ttnnCoreRangeSet = toCoreRangeSet(targetCoreRangeSet); + std::array ttnnShardShape; + std::copy(targetShardShape->begin(), targetShardShape->end(), + ttnnShardShape.begin()); + + ttnnShardShape[0] *= tileShape->y(); + ttnnShardShape[1] *= tileShape->x(); + + ::tt::tt_metal::TensorMemoryLayout ttnnMemLayout = + toTTNNTensorMemoryLayout(targetMemoryLayout); + + ::tt::tt_metal::BufferType ttnnBufferType = + toTTNNBufferType(targetMemorySpace); + + ::tt::tt_metal::ShardSpec shardSpec( + ttnnCoreRangeSet, ttnnShardShape, + ::tt::tt_metal::ShardOrientation::ROW_MAJOR, false); + + std::optional<::tt::tt_metal::ShardSpec> shardSpecOpt = + ttnnMemLayout == tt_metal::TensorMemoryLayout::INTERLEAVED + ? std::nullopt + : std::make_optional(shardSpec); + + ::tt::tt_metal::MemoryConfig memoryConfig{.memory_layout = ttnnMemLayout, + .buffer_type = ttnnBufferType, + .shard_spec = shardSpecOpt}; + return memoryConfig; +} + +Tensor createRuntimeTensorFromTTNN(const ::ttnn::Tensor &tensor) { + auto tensorPtr = std::make_shared<::ttnn::Tensor>(tensor); + return Tensor(std::static_pointer_cast(tensorPtr), nullptr, + DeviceRuntime::TTNN); +} + +} // namespace tt::runtime::ttnn::utils diff --git a/runtime/lib/ttnn/include/tt/runtime/ttnn/utils.h b/runtime/lib/ttnn/include/tt/runtime/ttnn/utils.h index ca50ad58b3..353195b8df 100644 --- a/runtime/lib/ttnn/include/tt/runtime/ttnn/utils.h +++ b/runtime/lib/ttnn/include/tt/runtime/ttnn/utils.h @@ -6,126 +6,50 @@ #define TT_RUNTIME_TTNN_UTILS_H #include "flatbuffers/vector.h" +#include "tt/runtime/detail/ttnn.h" #include "ttmlir/Target/Common/types_generated.h" #include "ttmlir/Target/TTNN/Target.h" -#include "ttnn/types.hpp" namespace tt::runtime::ttnn::utils { -inline bool isValidTileShape(const ::tt::target::Dim2d *shape) { - return (shape->x() == 1 and shape->y() == 1) or - (shape->x() == 32 and shape->y() == 32); -} - -inline ::ttnn::DataType toTTNNDataType(::tt::target::DataType dataType) { - switch (dataType) { - case ::tt::target::DataType::Float32: - return ::ttnn::DataType::FLOAT32; - case ::tt::target::DataType::BFloat16: - return ::ttnn::DataType::BFLOAT16; - case ::tt::target::DataType::BFP_BFloat8: - return ::ttnn::DataType::BFLOAT8_B; - case ::tt::target::DataType::BFP_BFloat4: - return ::ttnn::DataType::BFLOAT4_B; - case ::tt::target::DataType::UInt32: - return ::ttnn::DataType::UINT32; - case ::tt::target::DataType::UInt16: - return ::ttnn::DataType::UINT16; - - default: - throw std::runtime_error("Unsupported data type"); - } -} - -inline ::tt::target::DataType fromTTNNDataType(::ttnn::DataType dataType) { - switch (dataType) { - case ::ttnn::DataType::FLOAT32: - return ::tt::target::DataType::Float32; - case ::ttnn::DataType::BFLOAT16: - return ::tt::target::DataType::BFloat16; - case ::ttnn::DataType::BFLOAT8_B: - return ::tt::target::DataType::BFP_BFloat8; - case ::ttnn::DataType::BFLOAT4_B: - return ::tt::target::DataType::BFP_BFloat4; - case ::ttnn::DataType::UINT32: - return ::tt::target::DataType::UInt32; - case ::ttnn::DataType::UINT16: - return ::tt::target::DataType::UInt16; - - default: - throw std::runtime_error("Unsupported data type"); - } -} - -inline ::ttnn::Layout toTTNNLayout(::tt::target::TensorLayout layout) { - switch (layout) { - case ::tt::target::TensorLayout::Tile: - return ::ttnn::Layout::TILE; - case ::tt::target::TensorLayout::RowMajor: - return ::ttnn::Layout::ROW_MAJOR; - default: - throw std::runtime_error("Unsupported layout"); - } -} - -inline ::ttnn::TensorMemoryLayout -toTTNNTensorMemoryLayout(::tt::target::TensorMemoryLayout tensorMemoryLayout) { - - switch (tensorMemoryLayout) { - case ::tt::target::TensorMemoryLayout::Interleaved: - return ::ttnn::TensorMemoryLayout::INTERLEAVED; - case ::tt::target::TensorMemoryLayout::SingleBank: - return ::ttnn::TensorMemoryLayout::SINGLE_BANK; - case ::tt::target::TensorMemoryLayout::HeightSharded: - return ::ttnn::TensorMemoryLayout::HEIGHT_SHARDED; - case ::tt::target::TensorMemoryLayout::WidthSharded: - return ::ttnn::TensorMemoryLayout::WIDTH_SHARDED; - case ::tt::target::TensorMemoryLayout::BlockSharded: - return ::ttnn::TensorMemoryLayout::BLOCK_SHARDED; - case ::tt::target::TensorMemoryLayout::None: - assert(false && - "Unsupported tensor memory layout TensorMemoryLayout::None"); - } -} +bool isOnHost(const ::ttnn::StorageType &storageType); + +bool isOnDevice(const ::ttnn::StorageType &storageType); + +bool isValidTileShape(const ::tt::target::Dim2d *shape); + +::ttnn::DataType toTTNNDataType(::tt::target::DataType dataType); + +::tt::target::DataType fromTTNNDataType(::ttnn::DataType dataType); + +::ttnn::Layout toTTNNLayout(::tt::target::TensorLayout layout); + +::ttnn::TensorMemoryLayout +toTTNNTensorMemoryLayout(::tt::target::TensorMemoryLayout tensorMemoryLayout); // This method will be deprecated in favor of method below // -inline ::tt::tt_metal::BufferType -toTTNNBufferType(::tt::target::MemorySpace memorySpace) { - switch (memorySpace) { - case ::tt::target::MemorySpace::System: - case ::tt::target::MemorySpace::SystemMMIO: - return ::tt::tt_metal::BufferType::SYSTEM_MEMORY; - case ::tt::target::MemorySpace::DeviceDRAM: - return ::tt::tt_metal::BufferType::DRAM; - case ::tt::target::MemorySpace::DeviceL1: - return ::tt::tt_metal::BufferType::L1; - } -} +::tt::tt_metal::BufferType +toTTNNBufferType(::tt::target::MemorySpace memorySpace); // Prefer to use this method // -inline ::ttnn::BufferType -toTTNNBufferType(::tt::target::BufferType bufferType) { - - switch (bufferType) { - case ::tt::target::BufferType::DRAM: - return ::ttnn::BufferType::DRAM; - case ::tt::target::BufferType::L1: - return ::ttnn::BufferType::L1; - case ::tt::target::BufferType::SystemMemory: - return ::ttnn::BufferType::SYSTEM_MEMORY; - case ::tt::target::BufferType::L1Small: - return ::ttnn::BufferType::L1_SMALL; - case ::tt::target::BufferType::Trace: - return ::ttnn::BufferType::TRACE; - } -}; - -inline std::vector -toShapeFromFBShape(const flatbuffers::Vector &vec) { - return std::vector(vec.begin(), vec.end()); -} +::ttnn::BufferType toTTNNBufferType(::tt::target::BufferType bufferType); + +std::vector +toShapeFromFBShape(const flatbuffers::Vector &vec); + +::ttnn::Layout +inferLayoutFromTileShape(const ::tt::target::TensorRef *tensorRef); + +CoreRangeSet +toCoreRangeSet(const ::flatbuffers::Vector + *coreRangeSet); + +::tt::tt_metal::MemoryConfig +createMemoryConfig(const ::tt::target::TensorRef *tensorRef); + +Tensor createRuntimeTensorFromTTNN(const ::ttnn::Tensor &tensor); } // namespace tt::runtime::ttnn::utils diff --git a/runtime/lib/ttnn/operations/CMakeLists.txt b/runtime/lib/ttnn/operations/CMakeLists.txt index 4edc4780b9..d7d9357b5f 100644 --- a/runtime/lib/ttnn/operations/CMakeLists.txt +++ b/runtime/lib/ttnn/operations/CMakeLists.txt @@ -5,6 +5,7 @@ set(TTNN_OPS_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/eltwise/ternary/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ccl/all_gather.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv/conv2d.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/creation/arange.cpp ${CMAKE_CURRENT_SOURCE_DIR}/creation/empty.cpp ${CMAKE_CURRENT_SOURCE_DIR}/creation/full.cpp ${CMAKE_CURRENT_SOURCE_DIR}/data_movement/concat.cpp @@ -18,6 +19,8 @@ set(TTNN_OPS_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/eltwise/unary/unary_composite.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eltwise/ternary/ternary.cpp ${CMAKE_CURRENT_SOURCE_DIR}/embedding/embedding.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/kv_cache/fill_cache.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/kv_cache/update_cache.cpp ${CMAKE_CURRENT_SOURCE_DIR}/layout/to_device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/layout/from_device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/layout/to_layout.cpp @@ -45,12 +48,13 @@ target_include_directories(TTRuntimeTTNNOps PUBLIC ${PROJECT_SOURCE_DIR}/runtime/lib/ttnn/operations/include ${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common ) + target_include_directories(TTRuntimeTTNNOps SYSTEM PUBLIC "$") -target_link_libraries(TTRuntimeTTNNOps PUBLIC TTNN_LIBRARY) +target_link_libraries(TTRuntimeTTNNOps PUBLIC TTNN_LIBRARY TTRuntimeTTNNHelpers) if (TT_RUNTIME_ENABLE_PERF_TRACE) target_link_libraries(TTRuntimeTTNNOps PUBLIC TRACY_LIBRARY) endif() -add_dependencies(TTRuntimeTTNNOps TTNN_LIBRARY tt-metal FBS_GENERATION) +add_dependencies(TTRuntimeTTNNOps TTNN_LIBRARY tt-metal FBS_GENERATION TTRuntimeTTNNHelpers) diff --git a/runtime/lib/ttnn/operations/ccl/all_gather.cpp b/runtime/lib/ttnn/operations/ccl/all_gather.cpp index 37bf7427bf..eee27e7bab 100644 --- a/runtime/lib/ttnn/operations/ccl/all_gather.cpp +++ b/runtime/lib/ttnn/operations/ccl/all_gather.cpp @@ -5,6 +5,7 @@ #include "all_gather.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/utils.h" +#include "tt/runtime/ttnn/utils.h" namespace tt::runtime::ttnn::operations::ccl { void run(const ::tt::target::ttnn::AllGatherOp *op, ProgramContext &context) { @@ -13,7 +14,7 @@ void run(const ::tt::target::ttnn::AllGatherOp *op, ProgramContext &context) { int32_t dim = op->dim(); int32_t num_links = op->num_links(); ::tt::tt_metal::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); ::ttnn::Tensor out = ::ttnn::all_gather(input, dim, num_links, outputMemoryConfig); tensorPool.insert_or_assign(op->out()->global_id(), out); diff --git a/runtime/lib/ttnn/operations/conv/conv2d.cpp b/runtime/lib/ttnn/operations/conv/conv2d.cpp index e6670c1131..dfc60d4445 100644 --- a/runtime/lib/ttnn/operations/conv/conv2d.cpp +++ b/runtime/lib/ttnn/operations/conv/conv2d.cpp @@ -6,6 +6,7 @@ #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/utils.h" +#include "tt/runtime/ttnn/utils.h" #include "ttmlir/Target/TTNN/program_generated.h" #include "ttnn/types.hpp" @@ -20,10 +21,11 @@ void run(const ::tt::target::ttnn::Conv2dOp *op, ProgramContext &context) { std::optional<::ttnn::Tensor> bias = op->bias() ? std::make_optional(tensorPool.at(op->bias()->global_id())) : std::nullopt; - auto config = ::ttnn::operations::conv::conv2d::Conv2dConfig(); + auto config = ::ttnn::operations::conv::Conv2dConfig(); config.dtype = utils::getDataType(op->input()); config.weights_dtype = utils::getDataType(op->weight()); - ::ttnn::MemoryConfig outMemConfig = utils::createMemoryConfig(op->out()); + ::ttnn::MemoryConfig outMemConfig = + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); DeviceVariant targetDevice = context.getTargetDevice(op->device()->global_id()); ::ttnn::Tensor out = std::visit( diff --git a/runtime/lib/ttnn/operations/creation/arange.cpp b/runtime/lib/ttnn/operations/creation/arange.cpp new file mode 100644 index 0000000000..f51937462a --- /dev/null +++ b/runtime/lib/ttnn/operations/creation/arange.cpp @@ -0,0 +1,47 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "arange.h" +#include "tt/runtime/detail/logger.h" +#include "tt/runtime/ttnn/operations/utils.h" +#include "tt/runtime/ttnn/utils.h" +#include "ttnn/types.hpp" + +#include +#include + +namespace tt::runtime::ttnn::operations::creation { +void run(const ::tt::target::ttnn::ArangeOp *op, ProgramContext &context) { + ProgramTensorPool &tensorPool = context.getTensorPool(); + ::ttnn::DataType dtype = + ::ttnn::DataType::BFLOAT16; // Default in arange implementation + std::optional> device = std::nullopt; + ::ttnn::MemoryConfig memoryConfig = + ::ttnn::DRAM_MEMORY_CONFIG; // Default in arange implementation + + if (op->dtype()) { + dtype = ::tt::runtime::ttnn::utils::toTTNNDataType(*(op->dtype())); + } + + if (op->memcfg()) { + memoryConfig = utils::createMemoryConfig(op->memcfg(), op->out()); + } + + if (op->device()) { + // ttnn::arange supports no device (host) and single device + DeviceVariant targetDevice = + context.getTargetDevice(op->device()->global_id()); + + LOG_ASSERT(std::holds_alternative>( + targetDevice), + "ttnn::arange does not support MeshDevice."); + device = std::make_optional( + std::get>(targetDevice)); + } + ::ttnn::Tensor out = ::ttnn::arange(op->start(), op->end(), op->step(), dtype, + device, memoryConfig); + + tensorPool.insert_or_assign(op->out()->global_id(), out); +} +} // namespace tt::runtime::ttnn::operations::creation diff --git a/runtime/lib/ttnn/operations/creation/arange.h b/runtime/lib/ttnn/operations/creation/arange.h new file mode 100644 index 0000000000..157ee2dc61 --- /dev/null +++ b/runtime/lib/ttnn/operations/creation/arange.h @@ -0,0 +1,17 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef RUNTIME_LIB_TTNN_OPERATIONS_CREATION_ARANGE_H +#define RUNTIME_LIB_TTNN_OPERATIONS_CREATION_ARANGE_H + +#include "tt/runtime/ttnn/types.h" +#include "ttmlir/Target/TTNN/program_generated.h" + +namespace tt::runtime::ttnn::operations::creation { + +void run(const ::tt::target::ttnn::ArangeOp *op, ProgramContext &context); + +} // namespace tt::runtime::ttnn::operations::creation + +#endif diff --git a/runtime/lib/ttnn/operations/creation/empty.cpp b/runtime/lib/ttnn/operations/creation/empty.cpp index bed68e0f14..85eacef23d 100644 --- a/runtime/lib/ttnn/operations/creation/empty.cpp +++ b/runtime/lib/ttnn/operations/creation/empty.cpp @@ -24,11 +24,6 @@ struct EmptyTensorConfig { dtype(::tt::runtime::ttnn::operations::utils::getDataType(op->out())), numShards(op->num_shards()), strategy(op->strategy()) { layout = ::tt::runtime::ttnn::utils::toTTNNLayout(op->layout()); - // TODO(bug #582): ttnn::empty doesn't work properly with tile layout, - // using ROW_MAJOR until we fix it - if (workaround::Env::get().emptyOpForceRowMajor) { - layout = ::ttnn::Layout::ROW_MAJOR; - } if (op->device()) { LOG_ASSERT(op->memcfg(), "Memory config must be provided when device is provided"); @@ -62,11 +57,12 @@ createEmptyOnMultiDevice(ProgramContext &context, EmptyTensorConfig &config, ::tt::tt_metal::DistributedTensorConfig strategy = config.distributedTensorConfig(); std::vector<::ttnn::Tensor> tensorShards; - tensorShards.resize(config.numShards); - std::generate_n( - tensorShards.begin(), config.numShards, [&config]() -> ::ttnn::Tensor { - return ::ttnn::zeros(config.shape, config.dtype, config.layout); - }); + tensorShards.reserve(config.numShards); + std::generate_n(std::back_inserter(tensorShards), config.numShards, + [&config]() -> ::ttnn::Tensor { + return ::ttnn::zeros(config.shape, config.dtype, + config.layout); + }); ::ttnn::Tensor out = ::ttnn::distributed::api::create_multi_device_tensor( tensorShards, ::tt::tt_metal::StorageType::MULTI_DEVICE_HOST, strategy); if (deviceRef) { @@ -101,6 +97,6 @@ void run(const ::tt::target::ttnn::EmptyOp *op, ProgramContext &context) { } else { LOG_FATAL("Unsupported num shards"); } - utils::updateTensorPool(tensorPool, out, op->out()->global_id()); + tensorPool.insert_or_assign(op->out()->global_id(), out); } } // namespace tt::runtime::ttnn::operations::creation diff --git a/runtime/lib/ttnn/operations/creation/full.cpp b/runtime/lib/ttnn/operations/creation/full.cpp index 6a224f935d..b8536e0a86 100644 --- a/runtime/lib/ttnn/operations/creation/full.cpp +++ b/runtime/lib/ttnn/operations/creation/full.cpp @@ -26,24 +26,10 @@ struct FullTensorConfig { fillValue(op->fill_value()), numShards(op->num_shards()), strategy(op->strategy()) { - layout = utils::inferLayoutFromTileShape(op->out()); - - // TODO(bug #272), determine correct layout by tile shape in the future - // currently tile shape is not set correctly, so as a workaround, hardcode - // layout - if (workaround::Env::get().ignoreTileShape) { - layout = ::ttnn::Layout::TILE; - } - - // TODO(bug #582): ttnn::empty doesn't work properly with tile layout, - // using ROW_MAJOR until we fix it - if (workaround::Env::get().fullOpForceRowMajor) { - layout = ::ttnn::Layout::ROW_MAJOR; - } + layout = ::tt::runtime::ttnn::utils::inferLayoutFromTileShape(op->out()); if (!utils::inSystemMemory(op->out())) { - memoryConfig = - ::tt::runtime::ttnn::operations::utils::createMemoryConfig(op->out()); + memoryConfig = ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); } validate(); } @@ -72,8 +58,8 @@ createFullOnMultiDevice(ProgramContext &context, FullTensorConfig &config, ::tt::tt_metal::DistributedTensorConfig strategy = config.distributedTensorConfig(); std::vector<::ttnn::Tensor> tensorShards; - tensorShards.resize(config.numShards); - std::generate_n(tensorShards.begin(), config.numShards, + tensorShards.reserve(config.numShards); + std::generate_n(std::back_inserter(tensorShards), config.numShards, [&config]() -> ::ttnn::Tensor { return ::ttnn::full(config.shape, config.fillValue, config.dtype, config.layout); @@ -116,6 +102,6 @@ void run(const ::tt::target::ttnn::FullOp *op, ProgramContext &context) { } else { LOG_FATAL("Unsupported num shards"); } - utils::updateTensorPool(tensorPool, out, op->out()->global_id()); + tensorPool.insert_or_assign(op->out()->global_id(), out); } } // namespace tt::runtime::ttnn::operations::creation diff --git a/runtime/lib/ttnn/operations/data_movement/transpose.cpp b/runtime/lib/ttnn/operations/data_movement/transpose.cpp index ef8dcf1b13..c86c0ee10a 100644 --- a/runtime/lib/ttnn/operations/data_movement/transpose.cpp +++ b/runtime/lib/ttnn/operations/data_movement/transpose.cpp @@ -6,6 +6,7 @@ #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/utils.h" +#include "tt/runtime/ttnn/utils.h" namespace tt::runtime::ttnn::operations::data_movement { void run(const ::tt::target::ttnn::TransposeOp *op, ProgramContext &context) { @@ -15,7 +16,7 @@ void run(const ::tt::target::ttnn::TransposeOp *op, ProgramContext &context) { int32_t dim0 = op->dim0(); int32_t dim1 = op->dim1(); ::tt::tt_metal::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); ::ttnn::Tensor out = ::ttnn::transpose(in, dim0, dim1, outputMemoryConfig); tensorPool.insert_or_assign(op->out()->global_id(), out); } diff --git a/runtime/lib/ttnn/operations/deletion/deallocate.cpp b/runtime/lib/ttnn/operations/deletion/deallocate.cpp index 6204945b34..e871a9ea64 100644 --- a/runtime/lib/ttnn/operations/deletion/deallocate.cpp +++ b/runtime/lib/ttnn/operations/deletion/deallocate.cpp @@ -11,13 +11,6 @@ void run(const ::tt::target::ttnn::DeallocateOp *op, ProgramContext &context) { ::ttnn::Tensor &tensor = tensorPool.at(op->in()->global_id()); DEBUG_ASSERT(tensor.is_allocated()); ::ttnn::deallocate(tensor, op->force()); - - // The tensor should be deallocated after the deallocate call. - // Still this assert may be hit in the future for multidevice/async ttnn - // support. In that case, we will reevaluate the assert/dealloc behaviour and - // adjust it accordingly. - // - DEBUG_ASSERT(!tensor.is_allocated()); tensorPool.erase(op->in()->global_id()); } } // namespace tt::runtime::ttnn::operations::deletion diff --git a/runtime/lib/ttnn/operations/eltwise/binary/binary.cpp b/runtime/lib/ttnn/operations/eltwise/binary/binary.cpp index 5913971192..ff47bdcdd8 100644 --- a/runtime/lib/ttnn/operations/eltwise/binary/binary.cpp +++ b/runtime/lib/ttnn/operations/eltwise/binary/binary.cpp @@ -6,6 +6,7 @@ #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/eltwise/binary/utils.h" #include "tt/runtime/ttnn/operations/utils.h" +#include "tt/runtime/ttnn/utils.h" #include "ttnn/operations/eltwise/binary/binary_composite.hpp" namespace tt::runtime::ttnn::operations::binary { @@ -26,7 +27,7 @@ static void runEltwiseBinaryOp( ::ttnn::DataType outputDataType = utils::getDataType(op->out()); ::tt::tt_metal::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); ::ttnn::Tensor out = ttnnOp(*lhs, *rhs, outputDataType, outputMemoryConfig, std::nullopt, std::nullopt, std::nullopt); diff --git a/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp b/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp index 2a05d6246f..921b542ed2 100644 --- a/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp +++ b/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp @@ -6,6 +6,7 @@ #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/eltwise/binary/utils.h" #include "tt/runtime/ttnn/operations/utils.h" +#include "tt/runtime/ttnn/utils.h" namespace tt::runtime::ttnn::operations::binary::composite { @@ -20,7 +21,7 @@ static void runEltwiseBinaryCompositeOp( getEltwiseBinaryOpInputTensors(op, tensorPool, &lhs, &rhs); ::tt::tt_metal::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); ::ttnn::Tensor out = ttnnOp(*lhs, *rhs, outputMemoryConfig); tensorPool.insert_or_assign(op->out()->global_id(), out); @@ -41,6 +42,10 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) { runEltwiseBinaryCompositeOp(op, tensorPool, ::ttnn::remainder); break; } + case ::tt::target::ttnn::EltwiseOpType::Scatter: { + runEltwiseBinaryCompositeOp(op, tensorPool, ::ttnn::scatter); + break; + } default: LOG_FATAL("Unsupported Eltwise Binary Composite operation"); } diff --git a/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.h b/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.h index 9be8bc6b7e..bd497fe98c 100644 --- a/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.h +++ b/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.h @@ -15,6 +15,7 @@ inline bool isBinaryCompositeOp(const ::tt::target::ttnn::EltwiseOp *op) { case ::tt::target::ttnn::EltwiseOpType::Maximum: case ::tt::target::ttnn::EltwiseOpType::Minimum: case ::tt::target::ttnn::EltwiseOpType::Remainder: + case ::tt::target::ttnn::EltwiseOpType::Scatter: return true; default: return false; diff --git a/runtime/lib/ttnn/operations/eltwise/ternary/ternary.cpp b/runtime/lib/ttnn/operations/eltwise/ternary/ternary.cpp index 6afde5d663..44f1413898 100644 --- a/runtime/lib/ttnn/operations/eltwise/ternary/ternary.cpp +++ b/runtime/lib/ttnn/operations/eltwise/ternary/ternary.cpp @@ -22,7 +22,7 @@ static void runEltwiseTernaryWhereOp( getEltwiseTernaryOpInputTensors(op, tensorPool, &first, &second, &third); ::tt::tt_metal::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); ::ttnn::Tensor out = ttnnOp(*first, *second, *third, outputMemoryConfig); tensorPool.insert_or_assign(op->out()->global_id(), out); diff --git a/runtime/lib/ttnn/operations/eltwise/unary/unary.cpp b/runtime/lib/ttnn/operations/eltwise/unary/unary.cpp index 5a09b43a9f..d24dc24f8d 100644 --- a/runtime/lib/ttnn/operations/eltwise/unary/unary.cpp +++ b/runtime/lib/ttnn/operations/eltwise/unary/unary.cpp @@ -6,6 +6,7 @@ #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/eltwise/unary/utils.h" #include "tt/runtime/ttnn/operations/utils.h" +#include "tt/runtime/ttnn/utils.h" #include "ttmlir/Target/TTNN/program_generated.h" #include "ttnn/operations/copy.hpp" @@ -22,7 +23,7 @@ static void runEltwiseUnaryOp( getEltwiseUnaryOpInputTensor(op, tensorPool, &in); ::tt::tt_metal::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); ::ttnn::Tensor out = ttnnOp(*in, outputMemoryConfig, std::nullopt); tensorPool.insert_or_assign(op->out()->global_id(), out); @@ -39,7 +40,7 @@ static void runEltwiseUnaryWithFastAndApproximateModeOp( getEltwiseUnaryOpInputTensor(op, tensorPool, &in); ::tt::tt_metal::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); ::ttnn::Tensor out = ttnnOp(*in, false /* parameter */, outputMemoryConfig, std::nullopt); @@ -56,7 +57,7 @@ static void runEltwiseUnaryWithFloatParameterOp( float parameter = op->params_as_EltwiseOpWithFloatParams()->parameter(); ::tt::tt_metal::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); ::ttnn::Tensor out = ttnnOp(*in, parameter, outputMemoryConfig); tensorPool.insert_or_assign(op->out()->global_id(), out); } @@ -126,6 +127,14 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) { runEltwiseUnaryOp(op, tensorPool, ::ttnn::sign); break; } + case ::tt::target::ttnn::EltwiseOpType::Tan: { + runEltwiseUnaryOp(op, tensorPool, ::ttnn::tan); + break; + } + case ::tt::target::ttnn::EltwiseOpType::Tanh: { + runEltwiseUnaryOp(op, tensorPool, ::ttnn::tanh); + break; + } case ::tt::target::ttnn::EltwiseOpType::Exp: { runEltwiseUnaryWithFastAndApproximateModeOp(op, tensorPool, ::ttnn::exp); break; diff --git a/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp b/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp index fd378d5a26..31514f0fe5 100644 --- a/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp +++ b/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp @@ -6,6 +6,7 @@ #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/eltwise/unary/utils.h" #include "tt/runtime/ttnn/operations/utils.h" +#include "tt/runtime/ttnn/utils.h" #include "ttnn/operations/eltwise/unary/unary_composite.hpp" namespace tt::runtime::ttnn::operations::unary::composite { @@ -20,27 +21,26 @@ static void runEltwiseUnaryCompositeOp( getEltwiseUnaryOpInputTensor(op, tensorPool, &in); ::tt::tt_metal::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); ::ttnn::Tensor out = ttnnOp(*in, outputMemoryConfig); tensorPool.insert_or_assign(op->out()->global_id(), out); } -static void runEltwiseUnaryCompositeClampOP( +static void runEltwiseUnaryCompositeClampOp( const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool, - std::function<::ttnn::Tensor(const ::ttnn::Tensor &, float, float, - const ::tt::tt_metal::MemoryConfig &)> - ttnnOp) { + const std::function<::ttnn::Tensor(const ::ttnn::Tensor &, float, float, + const ::tt::tt_metal::MemoryConfig &)> + &ttnnOp) { ::ttnn::Tensor *in = nullptr; getEltwiseUnaryOpInputTensor(op, tensorPool, &in); float min = op->params_as_ClampOpParams()->min(); float max = op->params_as_ClampOpParams()->max(); ::tt::tt_metal::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); ::ttnn::Tensor out = ttnnOp(*in, min, max, outputMemoryConfig); tensorPool.insert_or_assign(op->out()->global_id(), out); - return; } void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) { @@ -51,7 +51,7 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) { break; } case ::tt::target::ttnn::EltwiseOpType::Clamp: { - runEltwiseUnaryCompositeClampOP(op, tensorPool, ::ttnn::clamp); + runEltwiseUnaryCompositeClampOp(op, tensorPool, ::ttnn::clamp); break; } case ::tt::target::ttnn::EltwiseOpType::Log1p: { diff --git a/runtime/lib/ttnn/operations/embedding/embedding.cpp b/runtime/lib/ttnn/operations/embedding/embedding.cpp index 47b27ca9ac..511d8256de 100644 --- a/runtime/lib/ttnn/operations/embedding/embedding.cpp +++ b/runtime/lib/ttnn/operations/embedding/embedding.cpp @@ -6,6 +6,7 @@ #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/utils.h" +#include "tt/runtime/ttnn/utils.h" namespace tt::runtime::ttnn::operations::embedding { void run(const ::tt::target::ttnn::EmbeddingOp *op, ProgramContext &context) { @@ -24,7 +25,7 @@ void run(const ::tt::target::ttnn::EmbeddingOp *op, ProgramContext &context) { auto embeddingsType = ::ttnn::operations::embedding::EmbeddingsType::GENERIC; ::ttnn::DataType outputDataType = utils::getDataType(op->out()); ::ttnn::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); ::ttnn::Tensor out = ::ttnn::embedding(input, weight, padToken, layout, embeddingsType, outputDataType, outputMemoryConfig); diff --git a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/binary/utils.cpp b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/binary/utils.cpp index a54777ab28..f97f71e403 100644 --- a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/binary/utils.cpp +++ b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/binary/utils.cpp @@ -7,6 +7,15 @@ namespace tt::runtime::ttnn::operations::binary { +bool shouldSwapBinaryOperands(const ::tt::target::ttnn::EltwiseOp *op, + ::ttnn::Tensor **lhs, ::ttnn::Tensor **rhs) { + // For scatter, we expect the left-hand side operator to be lesser or equal in + // volume to the right hand side, so we omit the swap. + return (op->type() != ::tt::target::ttnn::EltwiseOpType::Scatter && + workaround::Env::get().swapBinaryOperands && + (*lhs)->volume() < (*rhs)->volume()); +} + void getEltwiseBinaryOpInputTensors(const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool, ::ttnn::Tensor **lhs, @@ -21,8 +30,7 @@ void getEltwiseBinaryOpInputTensors(const ::tt::target::ttnn::EltwiseOp *op, // TODO(bug #1124): We're currently swapping the operands for binary ops // in runtime if the lhs operand is smaller (and requires broadcast onto the // rhs operand). We should add this check in the compiler. - if (workaround::Env::get().swapBinaryOperands && - (*lhs)->volume() < (*rhs)->volume()) { + if (shouldSwapBinaryOperands(op, lhs, rhs)) { std::swap(*lhs, *rhs); } } diff --git a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.cpp b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.cpp index 435607b87e..60ee2ddc2b 100644 --- a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.cpp +++ b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.cpp @@ -7,25 +7,6 @@ namespace tt::runtime::ttnn::operations::utils { -// TODO (bug #701) -// Currently the memory layout/location in flatbuffer is incorrect -// These methods are workarounds such that we query the info directly from the -// TTNN tensor Ideally, we should be able to get all of this info directly from -// the flatbuffer -bool isOnHost(const ::ttnn::Tensor &tensor) { - // Currently only supports borrowed or owned host storage - return tensor.storage_type() == ::tt::tt_metal::StorageType::BORROWED or - tensor.storage_type() == ::tt::tt_metal::StorageType::OWNED or - tensor.storage_type() == - ::tt::tt_metal::StorageType::MULTI_DEVICE_HOST; -} - -bool isOnDevice(const ::ttnn::Tensor &tensor) { - // Currently only supports single device storage - return tensor.storage_type() == ::tt::tt_metal::StorageType::DEVICE or - tensor.storage_type() == ::tt::tt_metal::StorageType::MULTI_DEVICE; -} - bool isTilized(const ::tt::target::TensorRef *tensorRef) { const ::tt::target::Dim2d *tileShape = tensorRef->desc()->layout()->memory_desc()->tile_shape(); @@ -43,93 +24,11 @@ bool inSystemMemory(const ::tt::target::TensorRef *tensorRef) { targetMemorySpace == ::tt::target::MemorySpace::SystemMMIO; } -void updateTensorPool(ProgramTensorPool &tensorPool, - const ::ttnn::Tensor &tensor, uint32_t outputGlobalId) { - if (tensorPool.isUserOutput(outputGlobalId)) { - tensorPool.copyTensorToUserOutput(outputGlobalId, tensor); - } else { - tensorPool.insert_or_assign(outputGlobalId, tensor); - } -} - ::ttnn::DataType getDataType(const ::tt::target::TensorRef *tensorRef) { return ::tt::runtime::ttnn::utils::toTTNNDataType( tensorRef->desc()->layout()->memory_desc()->data_type()); } -::ttnn::Layout -inferLayoutFromTileShape(const ::tt::target::TensorRef *tensorRef) { - const ::tt::target::Dim2d *tileShape = - tensorRef->desc()->layout()->memory_desc()->tile_shape(); - LOG_ASSERT(::tt::runtime::ttnn::utils::isValidTileShape(tileShape)); - if (tileShape->x() == 1 and tileShape->y() == 1) { - return ::ttnn::Layout::ROW_MAJOR; - } - return ::ttnn::Layout::TILE; -} - -CoreRangeSet -toCoreRangeSet(const ::flatbuffers::Vector - *coreRangeSet) { - std::set coreRanges; - for (::tt::target::Dim2dRange const *coreRange : *coreRangeSet) { - CoreCoord start(coreRange->loc().x(), coreRange->loc().y()); - // End is inclusive - CoreCoord end(coreRange->loc().x() + coreRange->size().x() - 1, - coreRange->loc().y() + coreRange->size().y() - 1); - - coreRanges.emplace(start, end); - } - return CoreRangeSet(coreRanges); -} - -// This method will soon be deprecated, prefer to use the method below -// -::tt::tt_metal::MemoryConfig -createMemoryConfig(const ::tt::target::TensorRef *tensorRef) { - const ::tt::target::LayoutDesc *layout = tensorRef->desc()->layout(); - const ::tt::target::TensorMemoryLayout targetMemoryLayout = - layout->memory_desc()->memory_layout(); - const ::tt::target::MemorySpace targetMemorySpace = - layout->memory_desc()->memory_space(); - const ::flatbuffers::Vector - *targetCoreRangeSet = layout->core_range_set(); - const ::flatbuffers::Vector *targetShardShape = - layout->memory_desc()->shape(); - const ::tt::target::Dim2d *tileShape = layout->memory_desc()->tile_shape(); - - LOG_ASSERT(targetCoreRangeSet->size() == 1, - "Currently only single core range/grid is supported"); - - LOG_ASSERT(targetShardShape->size() == 2, - "Only 2D shard shape is supported in TTNN backend"); - - LOG_ASSERT(::tt::runtime::ttnn::utils::isValidTileShape(tileShape), - "Invalid tile shape"); - - CoreRangeSet ttnnCoreRangeSet = toCoreRangeSet(targetCoreRangeSet); - std::array ttnnShardShape; - std::copy(targetShardShape->begin(), targetShardShape->end(), - ttnnShardShape.begin()); - - ttnnShardShape[0] *= tileShape->y(); - ttnnShardShape[1] *= tileShape->x(); - - ::tt::tt_metal::ShardSpec shardSpec( - ttnnCoreRangeSet, ttnnShardShape, - ::tt::tt_metal::ShardOrientation::ROW_MAJOR, false); - - ::tt::tt_metal::TensorMemoryLayout ttnnMemLayout = - ::tt::runtime::ttnn::utils::toTTNNTensorMemoryLayout(targetMemoryLayout); - - ::tt::tt_metal::BufferType ttnnBufferType = - ::tt::runtime::ttnn::utils::toTTNNBufferType(targetMemorySpace); - - return {ttnnMemLayout, ttnnBufferType, shardSpec}; -} - -// Prefer to use this method over the one above -// ::tt::tt_metal::MemoryConfig createMemoryConfig(const ::tt::target::MemoryConfigDesc *memcfg, const ::tt::target::TensorRef *tensorRef) { @@ -144,7 +43,8 @@ createMemoryConfig(const ::tt::target::MemoryConfigDesc *memcfg, const ::tt::target::LayoutDesc *layout = tensorRef->desc()->layout(); const ::flatbuffers::Vector *targetCoreRangeSet = layout->core_range_set(); - CoreRangeSet ttnnCoreRangeSet = toCoreRangeSet(targetCoreRangeSet); + CoreRangeSet ttnnCoreRangeSet = + ::tt::runtime::ttnn::utils::toCoreRangeSet(targetCoreRangeSet); const ::flatbuffers::Vector *shardShape = memcfg->shard_spec()->shard_shape(); const ::tt::target::Dim2d *tileShape = layout->memory_desc()->tile_shape(); @@ -169,8 +69,11 @@ createMemoryConfig(const ::tt::target::MemoryConfigDesc *memcfg, ttnnCoreRangeSet, ttnnShardShape, ::tt::tt_metal::ShardOrientation::ROW_MAJOR, false); - ::ttnn::MemoryConfig memoryConfig = {tensorMemoryLayout, bufferType, - shardSpec}; + ::ttnn::MemoryConfig memoryConfig = { + tensorMemoryLayout, bufferType, + tensorMemoryLayout == tt_metal::TensorMemoryLayout::INTERLEAVED + ? std::nullopt + : std::make_optional(shardSpec)}; return memoryConfig; } diff --git a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.h b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.h index b922e120a6..269e0328f9 100644 --- a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.h +++ b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.h @@ -13,32 +13,15 @@ namespace tt::runtime::ttnn::operations::utils { -bool isOnHost(const ::ttnn::Tensor &tensor); - -bool isOnDevice(const ::ttnn::Tensor &tensor); - bool isTilized(const ::tt::target::TensorRef *tensorRef); bool inSystemMemory(const ::tt::target::TensorRef *tensorRef); -void updateTensorPool(ProgramTensorPool &tensorPool, - const ::ttnn::Tensor &tensor, uint32_t outputGlobalId); - ::tt::target::MemorySpace getMemorySpace(const ::tt::target::TensorRef *tensorRef); ::ttnn::DataType getDataType(const ::tt::target::TensorRef *tensorRef); -::ttnn::Layout -inferLayoutFromTileShape(const ::tt::target::TensorRef *tensorRef); - -CoreRangeSet -toCoreRangeSet(const ::flatbuffers::Vector - *coreRangeSet); - -::tt::tt_metal::MemoryConfig -createMemoryConfig(const ::tt::target::TensorRef *tensorRef); - ::tt::tt_metal::MemoryConfig createMemoryConfig(const ::tt::target::MemoryConfigDesc *memcfg, const ::tt::target::TensorRef *tensorRef); diff --git a/runtime/lib/ttnn/operations/kv_cache/fill_cache.cpp b/runtime/lib/ttnn/operations/kv_cache/fill_cache.cpp new file mode 100644 index 0000000000..89022f64a1 --- /dev/null +++ b/runtime/lib/ttnn/operations/kv_cache/fill_cache.cpp @@ -0,0 +1,16 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "fill_cache.h" + +namespace tt::runtime::ttnn::operations::kv_cache { +void run(const ::tt::target::ttnn::FillCacheOp *op, ProgramContext &context) { + + ProgramTensorPool &tensorPool = context.getTensorPool(); + const ::ttnn::Tensor &cache = tensorPool.at(op->cache()->global_id()); + const ::ttnn::Tensor &input = tensorPool.at(op->input()->global_id()); + + ::ttnn::fill_cache(cache, input, op->batch_offset()); +} +} // namespace tt::runtime::ttnn::operations::kv_cache diff --git a/runtime/lib/ttnn/operations/kv_cache/fill_cache.h b/runtime/lib/ttnn/operations/kv_cache/fill_cache.h new file mode 100644 index 0000000000..4187cb604b --- /dev/null +++ b/runtime/lib/ttnn/operations/kv_cache/fill_cache.h @@ -0,0 +1,15 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef RUNTIME_LIB_TTNN_OPERATIONS_FILL_CACHE_H +#define RUNTIME_LIB_TTNN_OPERATIONS_FILL_CACHE_H + +#include "tt/runtime/ttnn/types.h" +#include "ttmlir/Target/TTNN/program_generated.h" + +namespace tt::runtime::ttnn::operations::kv_cache { +void run(const ::tt::target::ttnn::FillCacheOp *op, ProgramContext &context); +} // namespace tt::runtime::ttnn::operations::kv_cache + +#endif diff --git a/runtime/lib/ttnn/operations/kv_cache/update_cache.cpp b/runtime/lib/ttnn/operations/kv_cache/update_cache.cpp new file mode 100644 index 0000000000..fae1da40c6 --- /dev/null +++ b/runtime/lib/ttnn/operations/kv_cache/update_cache.cpp @@ -0,0 +1,35 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "update_cache.h" + +#include "tt/runtime/detail/logger.h" +#include "tt/runtime/detail/workarounds.h" + +namespace tt::runtime::ttnn::operations::kv_cache { +void run(const ::tt::target::ttnn::UpdateCacheOp *op, ProgramContext &context) { + + ProgramTensorPool &tensorPool = context.getTensorPool(); + + const ::ttnn::Tensor &cache = tensorPool.at(op->cache()->global_id()); + const ::ttnn::Tensor &input = tensorPool.at(op->input()->global_id()); + const ::ttnn::Tensor &updateIndex = + tensorPool.at(op->update_index()->global_id()); + if (workaround::Env::get().readUpdateIndexFromDeviceForKVCache) { + + const ::ttnn::Tensor indexOnHost = ::ttnn::from_device(updateIndex); + const auto storage = indexOnHost.get_storage(); + const auto ownedStorage = std::get(storage); + const auto buffer = ownedStorage.get_buffer(); + const auto buf = std::get>(buffer); + uint32_t upIdx = *buf.begin(); + + ::ttnn::update_cache(cache, input, upIdx, op->batch_offset(), std::nullopt); + } else { + LOG_FATAL("Currently, the only way to execute ttnn::update_cache is to use " + "the workaround enabled by the flag " + "\"readUpdateIndexFromDeviceForKVCache\""); + } +} +} // namespace tt::runtime::ttnn::operations::kv_cache diff --git a/runtime/lib/ttnn/operations/kv_cache/update_cache.h b/runtime/lib/ttnn/operations/kv_cache/update_cache.h new file mode 100644 index 0000000000..1c4115f1eb --- /dev/null +++ b/runtime/lib/ttnn/operations/kv_cache/update_cache.h @@ -0,0 +1,15 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef RUNTIME_LIB_TTNN_OPERATIONS_UPDATE_CACHE_H +#define RUNTIME_LIB_TTNN_OPERATIONS_UPDATE_CACHE_H + +#include "tt/runtime/ttnn/types.h" +#include "ttmlir/Target/TTNN/program_generated.h" + +namespace tt::runtime::ttnn::operations::kv_cache { +void run(const ::tt::target::ttnn::UpdateCacheOp *op, ProgramContext &context); +} // namespace tt::runtime::ttnn::operations::kv_cache + +#endif diff --git a/runtime/lib/ttnn/operations/layout/from_device.cpp b/runtime/lib/ttnn/operations/layout/from_device.cpp index b6820b6ec7..e26e3be2a3 100644 --- a/runtime/lib/ttnn/operations/layout/from_device.cpp +++ b/runtime/lib/ttnn/operations/layout/from_device.cpp @@ -12,10 +12,11 @@ void run(const ::tt::target::ttnn::FromDeviceOp *op, ProgramContext &context) { ProgramTensorPool &tensorPool = context.getTensorPool(); const ::ttnn::Tensor &inputTensor = tensorPool.at(op->in()->global_id()); DEBUG_ASSERT(inputTensor.is_allocated()); - LOG_ASSERT(utils::isOnDevice(inputTensor), - "Calling ttnn::from_device on a host tensor"); + DEBUG_ASSERT( + ::tt::runtime::ttnn::utils::isOnDevice(inputTensor.storage_type()), + "Calling ttnn::from_device on a host tensor"); ::ttnn::Tensor out = ::ttnn::from_device(inputTensor); - utils::updateTensorPool(tensorPool, out, op->out()->global_id()); + tensorPool.insert_or_assign(op->out()->global_id(), out); } } // namespace tt::runtime::ttnn::operations::layout diff --git a/runtime/lib/ttnn/operations/layout/to_device.cpp b/runtime/lib/ttnn/operations/layout/to_device.cpp index 34af89f504..414afc9f05 100644 --- a/runtime/lib/ttnn/operations/layout/to_device.cpp +++ b/runtime/lib/ttnn/operations/layout/to_device.cpp @@ -14,7 +14,7 @@ void run(const ::tt::target::ttnn::ToDeviceOp *op, ProgramContext &context) { ProgramTensorPool &tensorPool = context.getTensorPool(); const ::ttnn::Tensor &inputTensor = tensorPool.at(op->in()->global_id()); DEBUG_ASSERT(inputTensor.is_allocated()); - DEBUG_ASSERT(utils::isOnHost(inputTensor), + DEBUG_ASSERT(::tt::runtime::ttnn::utils::isOnHost(inputTensor.storage_type()), "Calling ttnn::to_device on a device tensor"); std::optional<::ttnn::MemoryConfig> memoryConfig = std::nullopt; diff --git a/runtime/lib/ttnn/operations/layout/to_layout.cpp b/runtime/lib/ttnn/operations/layout/to_layout.cpp index 5e78a67187..bf80ef292e 100644 --- a/runtime/lib/ttnn/operations/layout/to_layout.cpp +++ b/runtime/lib/ttnn/operations/layout/to_layout.cpp @@ -57,7 +57,7 @@ void run(const ::tt::target::ttnn::ToLayoutOp *op, ProgramContext &context) { out = ::ttnn::to_layout(inputTensor, layout, dtype, memoryConfig, static_cast<::ttnn::Device *>(nullptr)); } - utils::updateTensorPool(tensorPool, out, op->out()->global_id()); + tensorPool.insert_or_assign(op->out()->global_id(), out); } } // namespace tt::runtime::ttnn::operations::layout diff --git a/runtime/lib/ttnn/operations/layout/typecast.cpp b/runtime/lib/ttnn/operations/layout/typecast.cpp index 5529c6112c..e59a64a401 100644 --- a/runtime/lib/ttnn/operations/layout/typecast.cpp +++ b/runtime/lib/ttnn/operations/layout/typecast.cpp @@ -17,7 +17,7 @@ void run(const ::tt::target::ttnn::TypecastOp *op, ProgramContext &context) { ::tt::runtime::ttnn::utils::toTTNNDataType(op->dtype()); ::ttnn::Tensor out = ::ttnn::typecast(inputTensor, targetDataType); - utils::updateTensorPool(tensorPool, out, op->out()->global_id()); + tensorPool.insert_or_assign(op->out()->global_id(), out); } } // namespace tt::runtime::ttnn::operations::layout diff --git a/runtime/lib/ttnn/operations/matmul/matmul.cpp b/runtime/lib/ttnn/operations/matmul/matmul.cpp index abe71f9707..896797d59c 100644 --- a/runtime/lib/ttnn/operations/matmul/matmul.cpp +++ b/runtime/lib/ttnn/operations/matmul/matmul.cpp @@ -6,10 +6,11 @@ #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/utils.h" +#include "tt/runtime/ttnn/utils.h" #include -// ANCHOR: adding_an_op_matmul_runtime_operations namespace tt::runtime::ttnn::operations::matmul { +// ANCHOR: adding_an_op_matmul_runtime_operations void run(const ::tt::target::ttnn::MatmulOp *op, ProgramContext &context) { ProgramTensorPool &tensorPool = context.getTensorPool(); const ::ttnn::Tensor &lhs = tensorPool.at(op->in0()->global_id()); @@ -18,11 +19,7 @@ void run(const ::tt::target::ttnn::MatmulOp *op, ProgramContext &context) { DEBUG_ASSERT(rhs.is_allocated()); ::ttnn::DataType outputDataType = utils::getDataType(op->out()); ::tt::tt_metal::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); - - std::optional< - ::ttnn::operations::matmul::MatmulMultiCoreReuseMultiCast1DProgramConfig> - programConfig = std::nullopt; + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); const std::optional memoryConfig = std::make_optional(outputMemoryConfig); @@ -37,5 +34,35 @@ void run(const ::tt::target::ttnn::MatmulOp *op, ProgramContext &context) { tensorPool.insert_or_assign(op->out()->global_id(), out); } -} // namespace tt::runtime::ttnn::operations::matmul // ANCHOR_END: adding_an_op_matmul_runtime_operations + +void run(const ::tt::target::ttnn::LinearOp *op, ProgramContext &context) { + ProgramTensorPool &tensorPool = context.getTensorPool(); + const ::ttnn::Tensor &lhs = tensorPool.at(op->in0()->global_id()); + const ::ttnn::Tensor &rhs = tensorPool.at(op->in1()->global_id()); + std::optional<::ttnn::Tensor> bias = + op->bias() ? std::make_optional(tensorPool.at(op->bias()->global_id())) + : std::nullopt; + + DEBUG_ASSERT(lhs.is_allocated()); + DEBUG_ASSERT(rhs.is_allocated()); + DEBUG_ASSERT(!bias || bias->is_allocated()); + + ::ttnn::DataType outputDataType = utils::getDataType(op->out()); + ::tt::tt_metal::MemoryConfig outputMemoryConfig = + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); + + const std::optional memoryConfig = + std::make_optional(outputMemoryConfig); + + const std::optional dtype = + std::make_optional(outputDataType); + + ::ttnn::Tensor out = ::ttnn::linear( + lhs, rhs, bias, /*transposeA*/ false, /*transposeB*/ false, memoryConfig, + dtype, /*programConfig*/ std::nullopt, /*activation*/ std::nullopt, + /*computeKernelConfig*/ std::nullopt, /*coreGrid*/ std::nullopt); + + tensorPool.insert_or_assign(op->out()->global_id(), out); +} +} // namespace tt::runtime::ttnn::operations::matmul diff --git a/runtime/lib/ttnn/operations/matmul/matmul.h b/runtime/lib/ttnn/operations/matmul/matmul.h index 5957a54a3c..7b0583786b 100644 --- a/runtime/lib/ttnn/operations/matmul/matmul.h +++ b/runtime/lib/ttnn/operations/matmul/matmul.h @@ -10,6 +10,7 @@ namespace tt::runtime::ttnn::operations::matmul { void run(const ::tt::target::ttnn::MatmulOp *op, ProgramContext &context); +void run(const ::tt::target::ttnn::LinearOp *op, ProgramContext &context); } // namespace tt::runtime::ttnn::operations::matmul #endif diff --git a/runtime/lib/ttnn/operations/normalization/softmax.cpp b/runtime/lib/ttnn/operations/normalization/softmax.cpp index a83358567c..432f920956 100644 --- a/runtime/lib/ttnn/operations/normalization/softmax.cpp +++ b/runtime/lib/ttnn/operations/normalization/softmax.cpp @@ -6,6 +6,7 @@ #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/utils.h" +#include "tt/runtime/ttnn/utils.h" namespace tt::runtime::ttnn::operations::normalization { void run(const ::tt::target::ttnn::SoftmaxOp *op, ProgramContext &context) { @@ -14,7 +15,7 @@ void run(const ::tt::target::ttnn::SoftmaxOp *op, ProgramContext &context) { DEBUG_ASSERT(in.is_allocated()); int32_t dimension = op->dimension(); ::tt::tt_metal::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); ::ttnn::Tensor out = ::ttnn::softmax(in, dimension, outputMemoryConfig); tensorPool.insert_or_assign(op->out()->global_id(), out); } diff --git a/runtime/lib/ttnn/operations/pool/maxpool2d.cpp b/runtime/lib/ttnn/operations/pool/maxpool2d.cpp index dfd8b9375e..a20bdc51b4 100644 --- a/runtime/lib/ttnn/operations/pool/maxpool2d.cpp +++ b/runtime/lib/ttnn/operations/pool/maxpool2d.cpp @@ -31,12 +31,14 @@ preshardForMaxPool2d(const ::tt::target::ttnn::MaxPool2dOp *op, op->dilation_width() * (op->kernel_width() - 1) - 1) / op->stride_width(); - auto parallel_config = - ::ttnn::operations::conv::conv2d::determine_parallel_config( - ::ttnn::TensorMemoryLayout::HEIGHT_SHARDED, op->batch_size(), - op->channels(), output_height, output_width, op->channels(), - device.compute_with_storage_grid_size(), ShardOrientation::ROW_MAJOR); - auto sharded_memory_config = ::ttnn::operations::conv::conv2d:: + constexpr bool en_ch_padding = false; + + auto parallel_config = ::ttnn::operations::conv::determine_parallel_config( + ::ttnn::TensorMemoryLayout::HEIGHT_SHARDED, op->batch_size(), + op->channels(), output_height, output_width, op->channels(), + device.compute_with_storage_grid_size(), ShardOrientation::ROW_MAJOR, + en_ch_padding); + auto sharded_memory_config = ::ttnn::operations::conv:: create_sharded_memory_config_from_parallel_config(inputShape, parallel_config, 1); return ::ttnn::to_memory_config(input, sharded_memory_config, std::nullopt); @@ -44,8 +46,10 @@ preshardForMaxPool2d(const ::tt::target::ttnn::MaxPool2dOp *op, void run(const ::tt::target::ttnn::MaxPool2dOp *op, ProgramContext &context) { ProgramTensorPool &tensorPool = context.getTensorPool(); - const ::ttnn::operations::pool::MaxPool2DOp operation = - ::ttnn::operations::pool::MaxPool2DOp(); + const ::ttnn::operations::pool::Pool2DOp< + ::ttnn::operations::pool::Pool2DType::MAX_POOL2D> + operation = ::ttnn::operations::pool::Pool2DOp< + ::ttnn::operations::pool::Pool2DType::MAX_POOL2D>(); ::ttnn::Tensor input = tensorPool.at(op->in()->global_id()); DEBUG_ASSERT(input.is_allocated()); @@ -58,7 +62,8 @@ void run(const ::tt::target::ttnn::MaxPool2dOp *op, ProgramContext &context) { }, targetDevice); } - ::ttnn::MemoryConfig outMemConfig = utils::createMemoryConfig(op->out()); + ::ttnn::MemoryConfig outMemConfig = + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); ::ttnn::Tensor out = operation.invoke( 0, input, op->batch_size(), op->input_height(), op->input_width(), op->channels(), {op->kernel_height(), op->kernel_width()}, diff --git a/runtime/lib/ttnn/operations/reduction/reduction.cpp b/runtime/lib/ttnn/operations/reduction/reduction.cpp index 3af46efc9c..a74373ee9f 100644 --- a/runtime/lib/ttnn/operations/reduction/reduction.cpp +++ b/runtime/lib/ttnn/operations/reduction/reduction.cpp @@ -6,6 +6,7 @@ #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/utils.h" +#include "tt/runtime/ttnn/utils.h" namespace tt::runtime::ttnn::operations::reduction { static void runReductionOp( @@ -17,7 +18,7 @@ static void runReductionOp( const std::optional<::ttnn::DeviceComputeKernelConfig> &, float)> &ttnnOp) { ::tt::tt_metal::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); const ::ttnn::Tensor &in = tensorPool.at(op->in()->global_id()); DEBUG_ASSERT(in.is_allocated()); diff --git a/runtime/lib/ttnn/program.cpp b/runtime/lib/ttnn/program.cpp index 8cfa013891..f38bfe83ce 100644 --- a/runtime/lib/ttnn/program.cpp +++ b/runtime/lib/ttnn/program.cpp @@ -4,6 +4,7 @@ #include "operations/ccl/all_gather.h" #include "operations/context/get_device.h" #include "operations/conv/conv2d.h" +#include "operations/creation/arange.h" #include "operations/creation/empty.h" #include "operations/creation/full.h" #include "operations/data_movement/concat.h" @@ -17,6 +18,8 @@ #include "operations/eltwise/unary/unary.h" #include "operations/eltwise/unary/unary_composite.h" #include "operations/embedding/embedding.h" +#include "operations/kv_cache/fill_cache.h" +#include "operations/kv_cache/update_cache.h" #include "operations/layout/from_device.h" #include "operations/layout/to_device.h" #include "operations/layout/to_layout.h" @@ -29,56 +32,60 @@ #include "tt/runtime/detail/debug.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/ttnn/types.h" +#include "tt/runtime/ttnn/utils.h" #include "tt/runtime/utils.h" #include "ttmlir/Target/TTNN/program_generated.h" +#ifdef TT_RUNTIME_ENABLE_PERF_TRACE +#include "tracy/Tracy.hpp" +#endif + namespace tt::runtime::ttnn { using LogType = ::tt::runtime::logger::LogType; +void tracyLogOpLocation(const ::tt::target::ttnn::Operation *op) { +#ifdef TT_RUNTIME_ENABLE_PERF_TRACE + TracyMessage(op->loc_info()->c_str(), op->loc_info()->size()); +#endif +} + static ::tt::target::ttnn::TTNNBinary const *getBinary(Flatbuffer binary) { bool isTTNN = ::tt::target::ttnn::SizePrefixedTTNNBinaryBufferHasIdentifier( binary.handle.get()); - if (not isTTNN) { - throw std::runtime_error("Unsupported binary format"); - } + LOG_ASSERT(isTTNN, "Unsupported binary format"); return ::tt::target::ttnn::GetSizePrefixedTTNNBinary(binary.handle.get()); } class ProgramExecutor { public: - ProgramExecutor(Binary &executableHandle, const TensorMap &liveTensors, - const std::unordered_set &programInputs, - const std::unordered_set &programOutputs, - ::ttnn::MeshDevice *meshDevice) + ProgramExecutor( + const Binary &executableHandle, + const std::unordered_map &liveTensors, + const std::vector &programInputs, + const std::vector &programOutputs, + ::ttnn::MeshDevice *meshDevice) : executableHandle(executableHandle), context(ProgramContext(liveTensors, programInputs, programOutputs, meshDevice)) {} void runCallback(Binary &executableHandle, const ::tt::target::ttnn::Operation *opContext, - ProgramContext *programContext) { - if (auto callback = debug::Hooks::get().getOperatorCallback(); callback) { - std::shared_ptr programContextPtr = - ::tt::runtime::utils::unsafe_borrow_shared(programContext); - std::shared_ptr opContextPtr = - ::tt::runtime::utils::unsafe_borrow_shared( - const_cast<::tt::target::ttnn::Operation *>(opContext)); - (*callback)(executableHandle, - CallbackContext(programContextPtr, DeviceRuntime::TTNN), - OpContext(opContextPtr, DeviceRuntime::TTNN)); - } - } + ProgramContext *programContext); void execute(const ::tt::target::ttnn::Program *program) { for (const ::tt::target::ttnn::Operation *op : *program->operations()) { LOG_DEBUG(LogType::LogRuntimeTTNN, "Executing operation: ", op->debug_info()->c_str()); + tracyLogOpLocation(op); runOperation(op); runCallback(executableHandle, op, &context); } } ProgramContext &getContext() { return context; } + std::vector gatherOutputTensors() { + return context.getTensorPool().gatherOutputTensors(); + } private: Binary executableHandle; @@ -87,6 +94,21 @@ class ProgramExecutor { void runEltwiseOperation(const ::tt::target::ttnn::EltwiseOp *op); }; +void ProgramExecutor::runCallback( + Binary &executableHandle, const ::tt::target::ttnn::Operation *opContext, + ProgramContext *programContext) { + if (auto callback = debug::Hooks::get().getOperatorCallback(); callback) { + std::shared_ptr programContextPtr = + ::tt::runtime::utils::unsafe_borrow_shared(programContext); + std::shared_ptr opContextPtr = + ::tt::runtime::utils::unsafe_borrow_shared( + const_cast<::tt::target::ttnn::Operation *>(opContext)); + (*callback)(executableHandle, + CallbackContext(programContextPtr, DeviceRuntime::TTNN), + OpContext(opContextPtr, DeviceRuntime::TTNN)); + } +} + void ProgramExecutor::runEltwiseOperation( const ::tt::target::ttnn::EltwiseOp *op) { auto runUnaryOp = [&]() { @@ -148,6 +170,9 @@ void ProgramExecutor::runOperation(const ::tt::target::ttnn::Operation *op) { case ::tt::target::ttnn::OpType::EltwiseOp: { return runEltwiseOperation(op->type_as_EltwiseOp()); } + case ::tt::target::ttnn::OpType::LinearOp: { + return operations::matmul::run(op->type_as_LinearOp(), context); + } // ANCHOR: adding_an_op_matmul_runtime_program case ::tt::target::ttnn::OpType::MatmulOp: { return operations::matmul::run(op->type_as_MatmulOp(), context); @@ -186,6 +211,15 @@ void ProgramExecutor::runOperation(const ::tt::target::ttnn::Operation *op) { case ::tt::target::ttnn::OpType::AllGatherOp: { return operations::ccl::run(op->type_as_AllGatherOp(), context); } + case ::tt::target::ttnn::OpType::ArangeOp: { + return operations::creation::run(op->type_as_ArangeOp(), context); + } + case ::tt::target::ttnn::OpType::UpdateCacheOp: { + return operations::kv_cache::run(op->type_as_UpdateCacheOp(), context); + } + case ::tt::target::ttnn::OpType::FillCacheOp: { + return operations::kv_cache::run(op->type_as_FillCacheOp(), context); + } default: { LOG_FATAL("Unsupported operation type"); } @@ -193,6 +227,26 @@ void ProgramExecutor::runOperation(const ::tt::target::ttnn::Operation *op) { } // Nop is single input, output tensor where input is returned as output. +static bool isNopProgram(const ::tt::target::ttnn::Program *program) { + return program->inputs()->size() == 1 && program->outputs()->size() == 1 && + program->inputs()->Get(0)->global_id() == + program->outputs()->Get(0)->global_id(); +} + +static ::ttnn::Tensor +handleNopProgram(::tt::target::ttnn::Program const *program, + std::vector<::ttnn::Tensor *> const &inputs) { + const ::ttnn::Tensor &input = *inputs[0]; + ::ttnn::Tensor output = + ::ttnn::zeros(input.get_shape(), input.get_dtype(), input.get_layout()); + const void *src = ::tt::tt_metal::get_raw_host_data_ptr(input); + void *dst = ::tt::tt_metal::get_raw_host_data_ptr(output); + std::memcpy(dst, src, input.volume() * input.element_size()); + return output; +} + +namespace legacy { + static bool handleNopProgram(::tt::target::ttnn::Program const *program, std::vector<::ttnn::Tensor *> const &inputs, std::vector<::ttnn::Tensor *> const &outputs) { @@ -221,8 +275,8 @@ void runProgram(::ttnn::MeshDevice &meshDevice, Binary &executableHandle, if (handleNopProgram(program, inputs, outputs)) { return; } - TensorMap liveTensors; - std::unordered_set programInputs; + std::unordered_map liveTensors; + std::vector programInputs; int inputIndex = 0; LOG_ASSERT(program->inputs()->size() == inputs.size(), "Program input size mismatch: ", program->inputs()->size(), @@ -231,21 +285,69 @@ void runProgram(::ttnn::MeshDevice &meshDevice, Binary &executableHandle, auto [iter, inserted] = liveTensors.try_emplace(input->global_id(), inputs[inputIndex++]); LOG_ASSERT(inserted, "Duplicate input tensor"); - programInputs.emplace(input->global_id()); + programInputs.push_back(input->global_id()); } int outputIndex = 0; - std::unordered_set programOutputs; + std::vector programOutputs; LOG_ASSERT(program->outputs()->size() == outputs.size()); for (::tt::target::TensorRef const *output : *program->outputs()) { auto [iter, inserted] = liveTensors.try_emplace(output->global_id(), outputs[outputIndex++]); LOG_ASSERT(inserted, "Duplicate output tensor"); - programOutputs.emplace(output->global_id()); + programOutputs.push_back(output->global_id()); + } + ProgramExecutor executor(executableHandle, liveTensors, programInputs, + programOutputs, &meshDevice); + executor.execute(program); + outputIndex = 0; + for (uint32_t outputId : programOutputs) { + const ::ttnn::Tensor &src = + executor.getContext().getTensorPool().at(outputId); + const ::ttnn::Tensor &dst = *(outputs[outputIndex++]); + size_t srcSize = src.volume() * src.element_size(); + size_t dstSize = dst.volume() * dst.element_size(); + LOG_ASSERT(srcSize == dstSize, "Output tensor size mismatch"); + const void *srcPtr = ::tt::tt_metal::get_raw_host_data_ptr(src); + void *dstPtr = ::tt::tt_metal::get_raw_host_data_ptr(dst); + std::memcpy(dstPtr, srcPtr, dstSize); + } +} +} // namespace legacy + +std::vector runProgram(::ttnn::MeshDevice &meshDevice, + Binary executableHandle, + std::uint32_t programIndex, + std::vector<::ttnn::Tensor *> const &inputs) { + ::tt::target::ttnn::TTNNBinary const &fbb = *getBinary(executableHandle); + ::tt::target::ttnn::Program const *program = + fbb.programs()->Get(programIndex); + if (isNopProgram(program)) { + Tensor out = + utils::createRuntimeTensorFromTTNN(handleNopProgram(program, inputs)); + return {out}; + } + std::unordered_map liveTensors; + std::vector programInputs; + int inputIndex = 0; + LOG_ASSERT(program->inputs()->size() == inputs.size(), + "Program input size mismatch: ", program->inputs()->size(), + " != ", inputs.size()); + for (::tt::target::TensorRef const *input : *program->inputs()) { + auto [iter, inserted] = + liveTensors.try_emplace(input->global_id(), inputs[inputIndex++]); + LOG_ASSERT(inserted, "Duplicate input tensor"); + programInputs.push_back(input->global_id()); + } + std::vector programOutputs; + for (::tt::target::TensorRef const *output : *program->outputs()) { + programOutputs.push_back(output->global_id()); } ProgramExecutor executor(executableHandle, liveTensors, programInputs, programOutputs, &meshDevice); executor.execute(program); + std::vector outputTensors = executor.gatherOutputTensors(); + return outputTensors; } } // namespace tt::runtime::ttnn diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp index 86fd2d25c6..0578557851 100644 --- a/runtime/lib/ttnn/runtime.cpp +++ b/runtime/lib/ttnn/runtime.cpp @@ -1,7 +1,6 @@ // SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC // // SPDX-License-Identifier: Apache-2.0 -#include "tt/runtime/runtime.h" #include "tt/runtime/detail/debug.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" @@ -21,16 +20,28 @@ using ::tt::tt_metal::DistributedTensorConfig; using ::tt::tt_metal::OwnedStorage; using ::tt::tt_metal::raise_unsupported_storage; +template +static OwnedStorage createOwnedStorage(ElementType *ptr, + std::uint32_t numElements) { + ::tt::tt_metal::owned_buffer::Buffer buffer; + if (ptr != nullptr) { + auto data = std::vector(ptr, ptr + numElements); + buffer = ::tt::tt_metal::owned_buffer::create(std::move(data)); + } else { + buffer = ::tt::tt_metal::owned_buffer::create(numElements); + } + return OwnedStorage(std::move(buffer)); +} + template static StorageType createStorage(ElementType *ptr, std::uint32_t numElements) { if constexpr (std::is_same_v) { + LOG_ASSERT(ptr != nullptr, "Cannot create borrowed storage from nullptr"); return BorrowedStorage( ::tt::tt_metal::borrowed_buffer::Buffer(ptr, numElements), [] {}, [] {}); } else if constexpr (std::is_same_v) { - auto data = std::vector(ptr, ptr + numElements); - auto buffer = ::tt::tt_metal::owned_buffer::create(std::move(data)); - return OwnedStorage(std::move(buffer)); + return createOwnedStorage(ptr, numElements); } else { raise_unsupported_storage(); } @@ -76,6 +87,21 @@ static Tensor createNullTensor() { return Tensor(nullptr, nullptr, DeviceRuntime::TTNN); } +static DeviceVariant getTargetDevice(::ttnn::MeshDevice &meshDevice) { + if (meshDevice.num_devices() == 1) { + return std::ref(*(meshDevice.get_device_index(0))); + } + return std::ref(meshDevice); +} + +static ::tt::target::ttnn::TTNNBinary const *getBinary(Flatbuffer binary) { + bool isTTNN = ::tt::target::ttnn::SizePrefixedTTNNBinaryBufferHasIdentifier( + binary.handle.get()); + LOG_ASSERT(isTTNN, "Unsupported binary format"); + return ::tt::target::ttnn::GetSizePrefixedTTNNBinary(binary.handle.get()); +} + +// Create a borrowed tensor from user-owned data Tensor createTensor(std::shared_ptr data, std::vector const &shape, std::vector const &stride, @@ -89,10 +115,11 @@ Tensor createTensor(std::shared_ptr data, createStorage(data.get(), numElements, dataType), ::ttnn::Shape(small_vector_shape), utils::toTTNNDataType(dataType), ::ttnn::Layout::ROW_MAJOR); - return Tensor(std::static_pointer_cast(tensor), data, + return Tensor(std::static_pointer_cast(tensor), nullptr, DeviceRuntime::TTNN); } +// Create a owned multi-device host tensor from user-owned data Tensor createTensor(std::vector> &data, std::vector const &shape, @@ -100,8 +127,8 @@ createTensor(std::vector> &data, ::tt::target::DataType dataType, std::unordered_map const &strategy) { std::vector<::ttnn::Tensor> tensorShards; - tensorShards.resize(data.size()); - std::transform(data.begin(), data.end(), tensorShards.begin(), + tensorShards.reserve(data.size()); + std::transform(data.begin(), data.end(), std::back_inserter(tensorShards), [&](std::shared_ptr &dataShard) -> ::ttnn::Tensor { return createOwnedTensor(dataShard, shape, stride, itemsize, dataType); @@ -112,13 +139,35 @@ createTensor(std::vector> &data, ::ttnn::distributed::api::create_multi_device_tensor( tensorShards, ::tt::tt_metal::StorageType::MULTI_DEVICE_HOST, distributionStrategy)); - std::shared_ptr>> borrowedData = - std::make_shared>>(data); - return Tensor(std::static_pointer_cast(tensor), - std::static_pointer_cast(borrowedData), + return Tensor(std::static_pointer_cast(tensor), nullptr, DeviceRuntime::TTNN); } +// Create an owned empty tensor on host/device +Tensor createTensor(Device device, Layout layout, + std::vector const &shape, + std::vector const &stride, + std::uint32_t itemsize) { + const LayoutDesc &layoutDesc = layout.as(DeviceRuntime::TTNN); + if (layoutDesc.isOnHost()) { + ::ttnn::Tensor tensor = + createOwnedTensor(nullptr, shape, stride, itemsize, + utils::fromTTNNDataType(layoutDesc.dataType)); + Tensor out = utils::createRuntimeTensorFromTTNN(tensor); + return toLayout(out, device, layout); + } + DeviceVariant targetDevice = + getTargetDevice(device.as<::ttnn::MeshDevice>(DeviceRuntime::TTNN)); + ::ttnn::Tensor tensor = std::visit( + [&](auto &&device) -> ::ttnn::Tensor { + return ::ttnn::operations::core::allocate_tensor_on_device( + ::ttnn::Shape(shape), layoutDesc.dataType, layoutDesc.layout, + &(device.get()), layoutDesc.memoryConfig); + }, + targetDevice); + return utils::createRuntimeTensorFromTTNN(tensor); +} + tt::target::DataType getTensorDataType(Tensor tensor) { const ::ttnn::Tensor &nnTensor = tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN); @@ -166,34 +215,120 @@ void deallocateBuffers(Device deviceHandle) { } } -Event submit(Device deviceHandle, Binary executableHandle, - std::uint32_t programIndex, - std::vector const &inputHandles, - std::vector const &outputHandles) { - ::ttnn::MeshDevice &meshDevice = - deviceHandle.as<::ttnn::MeshDevice>(DeviceRuntime::TTNN); - std::vector<::ttnn::Tensor *> inputs; - inputs.reserve(inputHandles.size()); - for (auto &input : inputHandles) { - LOG_ASSERT(input.matchesRuntime(DeviceRuntime::TTNN)); - inputs.push_back(static_cast<::ttnn::Tensor *>(input.handle.get())); +void wait(Event event) { + // Nothing to do for ttnn runtime + LOG_ASSERT(event.matchesRuntime(DeviceRuntime::TTNN)); +} + +void wait(Tensor tensor) { + LOG_ASSERT(tensor.matchesRuntime(DeviceRuntime::TTNN), + "Expected ttnn tensor"); + ::tt::runtime::ttnn::wait(tensor.event); +} + +void wait(std::vector const &tensors) { + for (const Tensor &tensor : tensors) { + ::tt::runtime::ttnn::wait(tensor); } +} - std::vector<::ttnn::Tensor *> outputs; - outputs.reserve(outputHandles.size()); - for (auto &output : outputHandles) { - LOG_ASSERT(output.matchesRuntime(DeviceRuntime::TTNN)); - outputs.push_back(static_cast<::ttnn::Tensor *>(output.handle.get())); +Tensor toHost(Tensor tensor, bool untilize) { + const ::ttnn::Tensor &deviceTensor = + tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN); + std::shared_ptr<::ttnn::Tensor> hostTensor = + std::make_shared<::ttnn::Tensor>(::ttnn::from_device(deviceTensor)); + + if (untilize) { + hostTensor = std::make_shared<::ttnn::Tensor>(::ttnn::to_layout( + *hostTensor, ::ttnn::Layout::ROW_MAJOR, std::nullopt, std::nullopt, + static_cast<::ttnn::Device *>(nullptr))); } - tt::runtime::ttnn::runProgram(meshDevice, executableHandle, programIndex, - inputs, outputs); - return Event(nullptr, DeviceRuntime::TTNN); + return Tensor(std::static_pointer_cast(hostTensor), nullptr, + DeviceRuntime::TTNN); } -void wait(Event event) { - // Not implemented - LOG_ASSERT(event.matchesRuntime(DeviceRuntime::TTNN)); +Tensor toLayout(Tensor tensor, Device device, Layout layout) { + const ::ttnn::Tensor &ttnnTensor = + tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN); + const ::ttnn::Layout &inputLayout = ttnnTensor.get_layout(); + const ::ttnn::DataType &inputDataType = ttnnTensor.get_dtype(); + LayoutDesc inputLayoutDesc(::ttnn::BufferType::SYSTEM_MEMORY, inputLayout, + inputDataType, std::nullopt); + + const LayoutDesc &outputLayoutDesc = + layout.as(DeviceRuntime::TTNN); + + ::ttnn::MeshDevice &meshDevice = + device.as<::ttnn::MeshDevice>(DeviceRuntime::TTNN); + DeviceVariant targetDevice = getTargetDevice(meshDevice); + LayoutConverter converter(inputLayoutDesc, outputLayoutDesc); + std::shared_ptr<::ttnn::Tensor> out = std::make_shared<::ttnn::Tensor>( + converter.convertTensorLayout(ttnnTensor, targetDevice)); + + return Tensor(std::static_pointer_cast(out), nullptr, + DeviceRuntime::TTNN); +} + +Layout getLayout(Binary executableHandle, std::uint32_t programIndex, + std::uint32_t inputIndex) { + const ::tt::target::ttnn::TTNNBinary &fbb = *getBinary(executableHandle); + LOG_ASSERT(programIndex < fbb.programs()->size(), "Invalid program index"); + const ::tt::target::ttnn::Program *program = + fbb.programs()->Get(programIndex); + LOG_ASSERT(inputIndex < program->inputs()->size(), "Invalid input index"); + const ::tt::target::TensorRef *input = program->inputs()->Get(inputIndex); + + ::ttnn::BufferType inputBufferType = utils::toTTNNBufferType( + input->desc()->layout()->memory_desc()->memory_space()); + ::ttnn::Layout inputLayout = utils::inferLayoutFromTileShape(input); + ::ttnn::DataType inputDataType = utils::toTTNNDataType( + input->desc()->layout()->memory_desc()->data_type()); + std::optional<::ttnn::MemoryConfig> inputMemoryConfig = std::nullopt; + if (inputBufferType != ::ttnn::BufferType::SYSTEM_MEMORY) { + inputMemoryConfig = utils::createMemoryConfig(input); + } + + std::shared_ptr layoutDesc = std::make_shared( + inputBufferType, inputLayout, inputDataType, inputMemoryConfig); + + return Layout(std::static_pointer_cast(layoutDesc), + DeviceRuntime::TTNN); +} + +void memcpy(void *dst, Tensor src) { + const ::ttnn::Tensor &srcTensor = src.as<::ttnn::Tensor>(DeviceRuntime::TTNN); + if (utils::isOnHost(srcTensor.storage_type())) { + const void *srcPtr = ::tt::tt_metal::get_raw_host_data_ptr(srcTensor); + size_t size = srcTensor.volume() * srcTensor.element_size(); + std::memcpy(dst, srcPtr, size); + } else { + ::tt::tt_metal::memcpy(dst, srcTensor); + } +} + +void memcpy(Tensor dst, Tensor src) { + ::ttnn::Tensor &dstTensor = dst.as<::ttnn::Tensor>(DeviceRuntime::TTNN); + const ::ttnn::Tensor &srcTensor = src.as<::ttnn::Tensor>(DeviceRuntime::TTNN); + LOG_ASSERT(srcTensor.volume() * srcTensor.element_size() == + dstTensor.volume() * dstTensor.element_size(), + "Input output tensor size mismatch in memcpy: ", + srcTensor.volume(), " * ", srcTensor.element_size(), + " != ", dstTensor.volume(), " * ", dstTensor.element_size()); + if (utils::isOnHost(srcTensor.storage_type()) and + utils::isOnHost(dstTensor.storage_type())) { + void *dstPtr = ::tt::tt_metal::get_raw_host_data_ptr(dstTensor); + const void *srcPtr = ::tt::tt_metal::get_raw_host_data_ptr(srcTensor); + size_t size = srcTensor.volume() * srcTensor.element_size(); + std::memcpy(dstPtr, srcPtr, size); + } else { + ::tt::tt_metal::memcpy(dstTensor, srcTensor); + } +} + +void deallocateTensor(Tensor &tensor, bool force) { + ::ttnn::Tensor &ttnnTensor = tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN); + ::ttnn::deallocate(ttnnTensor, force); } std::string getOpDebugString(OpContext opContextHandle) { @@ -202,6 +337,12 @@ std::string getOpDebugString(OpContext opContextHandle) { return std::string(opContext.debug_info()->c_str()); } +std::string getOpLocInfo(OpContext opContextHandle) { + auto const &opContext = + opContextHandle.as<::tt::target::ttnn::Operation>(DeviceRuntime::TTNN); + return std::string(opContext.loc_info()->c_str()); +} + Tensor getOpOutputTensor(OpContext opContextHandle, CallbackContext programContextHandle) { auto const &programContext = @@ -299,7 +440,7 @@ Tensor getOpOutputTensor(OpContext opContextHandle, return createNullTensor(); } default: { - throw std::runtime_error("Unsupported operation type"); + LOG_FATAL("Unsupported operation type"); } } @@ -326,12 +467,13 @@ Tensor getOpOutputTensor(OpContext opContextHandle, outCopy.shape().value, ::ttnn::DataType::FLOAT32, ::ttnn::Layout::ROW_MAJOR); - return Tensor(std::static_pointer_cast(tensor), data, + return Tensor(std::static_pointer_cast(tensor), nullptr, DeviceRuntime::TTNN); } std::vector getTensorData(Tensor tensor) { - ::ttnn::Tensor *nnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get()); + const ::ttnn::Tensor *nnTensor = + static_cast<::ttnn::Tensor *>(tensor.handle.get()); if (nnTensor == nullptr) { return {}; } @@ -341,4 +483,62 @@ std::vector getTensorData(Tensor tensor) { static_cast(dataPtr) + nnTensor->volume()); } +namespace legacy { + +Event submit(Device deviceHandle, Binary executableHandle, + std::uint32_t programIndex, + std::vector const &inputHandles, + std::vector const &outputHandles) { + ::ttnn::MeshDevice &meshDevice = + deviceHandle.as<::ttnn::MeshDevice>(DeviceRuntime::TTNN); + std::vector<::ttnn::Tensor *> inputs; + inputs.reserve(inputHandles.size()); + for (auto &input : inputHandles) { + LOG_ASSERT(input.matchesRuntime(DeviceRuntime::TTNN)); + inputs.push_back(static_cast<::ttnn::Tensor *>(input.handle.get())); + } + + std::vector<::ttnn::Tensor *> outputs; + outputs.reserve(outputHandles.size()); + for (auto &output : outputHandles) { + LOG_ASSERT(output.matchesRuntime(DeviceRuntime::TTNN)); + outputs.push_back(static_cast<::ttnn::Tensor *>(output.handle.get())); + } + + tt::runtime::ttnn::legacy::runProgram(meshDevice, executableHandle, + programIndex, inputs, outputs); + return Event(nullptr, DeviceRuntime::TTNN); +} +} // namespace legacy + +std::vector submit(Device deviceHandle, Binary executableHandle, + std::uint32_t programIndex, + std::vector const &inputHandles) { + ::ttnn::MeshDevice &meshDevice = + deviceHandle.as<::ttnn::MeshDevice>(DeviceRuntime::TTNN); + + // Convert input tensors to the layout expected by the program + std::vector inputsWithLayout; + inputsWithLayout.reserve(inputHandles.size()); + std::transform( + inputHandles.begin(), inputHandles.end(), + std::back_inserter(inputsWithLayout), [&](const Tensor &input) -> Tensor { + Layout inputLayout = ::tt::runtime::ttnn::getLayout( + executableHandle, programIndex, inputsWithLayout.size()); + return ::tt::runtime::ttnn::toLayout(input, deviceHandle, inputLayout); + }); + + std::vector<::ttnn::Tensor *> ttnnInputs; + ttnnInputs.reserve(inputsWithLayout.size()); + std::transform(inputsWithLayout.begin(), inputsWithLayout.end(), + std::back_inserter(ttnnInputs), + [](Tensor &input) -> ::ttnn::Tensor * { + return &input.as<::ttnn::Tensor>(DeviceRuntime::TTNN); + }); + + std::vector outputs = ::tt::runtime::ttnn::runProgram( + meshDevice, executableHandle, programIndex, ttnnInputs); + return outputs; +} + } // namespace tt::runtime::ttnn diff --git a/runtime/test/CMakeLists.txt b/runtime/test/CMakeLists.txt index 8a0d12ee33..f55a6c7615 100644 --- a/runtime/test/CMakeLists.txt +++ b/runtime/test/CMakeLists.txt @@ -1,7 +1,31 @@ +if (NOT TTMLIR_ENABLE_RUNTIME_TESTS) + add_library(TTRuntimeTTNNTestHelpers INTERFACE) + return() +endif() + if (NOT TTMLIR_ENABLE_RUNTIME OR (NOT TT_RUNTIME_ENABLE_TTNN AND NOT TT_RUNTIME_ENABLE_TTMETAL)) message(FATAL_ERROR "Runtime tests require -DTTMLIR_ENABLE_RUNTIME=ON and at least one backend runtime to be enabled") endif() +if (NOT TT_RUNTIME_ENABLE_TTNN) + add_library(TTRuntimeTTNNTestHelpers INTERFACE) +else() + add_library(TTRuntimeTTNNTestHelpers + STATIC + ${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/test/utils.cpp + ) + set_property(TARGET TTRuntimeTTNNTestHelpers PROPERTY CXX_STANDARD 20) + target_compile_options(TTRuntimeTTNNTestHelpers PUBLIC -mavx -mavx2 -fsized-deallocation) + target_include_directories(TTRuntimeTTNNTestHelpers PUBLIC + ${PROJECT_SOURCE_DIR}/runtime/include + ${PROJECT_SOURCE_DIR}/runtime/lib/ttnn/include + ${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common + ) + target_include_directories(TTRuntimeTTNNTestHelpers SYSTEM PUBLIC "$") + add_dependencies(TTRuntimeTTNNTestHelpers TTRuntime tt-metal FBS_GENERATION) + target_link_libraries(TTRuntimeTTNNTestHelpers PUBLIC TTRuntime TTNN_LIBRARY) +endif() + enable_testing() include(FetchContent) FetchContent_Declare( @@ -37,6 +61,7 @@ target_include_directories(TTRuntimeTEST INTERFACE target_link_libraries(TTRuntimeTEST INTERFACE TTMETAL_LIBRARY + DEVICE_LIBRARY TTBinary TTRuntime TTRuntimeTTNN diff --git a/runtime/test/include/tt/runtime/ttnn/test/utils.cpp b/runtime/test/include/tt/runtime/ttnn/test/utils.cpp new file mode 100644 index 0000000000..e0cc969b7c --- /dev/null +++ b/runtime/test/include/tt/runtime/ttnn/test/utils.cpp @@ -0,0 +1,50 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "tt/runtime/test/utils.h" +#include "tt/runtime/detail/logger.h" +#include "tt/runtime/runtime.h" +#include "tt/runtime/ttnn/types.h" +#include "tt/runtime/ttnn/utils.h" +#include "tt/runtime/types.h" + +namespace tt::runtime::ttnn::test { +using ::tt::runtime::DeviceRuntime; +Layout getDramInterleavedTileLayout(::tt::target::DataType dataType) { + LOG_ASSERT(getCurrentRuntime() == DeviceRuntime::TTNN); + ::ttnn::DataType ttnnDataType = + ::tt::runtime::ttnn::utils::toTTNNDataType(dataType); + ::tt::runtime::ttnn::LayoutDesc layoutDesc(::ttnn::BufferType::DRAM, + ::ttnn::Layout::TILE, ttnnDataType, + std::nullopt); + return Layout( + std::static_pointer_cast( + std::make_shared<::tt::runtime::ttnn::LayoutDesc>(layoutDesc)), + ::tt::runtime::DeviceRuntime::TTNN); +} +Layout getDramInterleavedRowMajorLayout(::tt::target::DataType dataType) { + LOG_ASSERT(getCurrentRuntime() == DeviceRuntime::TTNN); + ::ttnn::DataType ttnnDataType = + ::tt::runtime::ttnn::utils::toTTNNDataType(dataType); + ::tt::runtime::ttnn::LayoutDesc layoutDesc(::ttnn::BufferType::DRAM, + ::ttnn::Layout::ROW_MAJOR, + ttnnDataType, std::nullopt); + return Layout( + std::static_pointer_cast( + std::make_shared<::tt::runtime::ttnn::LayoutDesc>(layoutDesc)), + ::tt::runtime::DeviceRuntime::TTNN); +} +::tt::runtime::Layout getHostRowMajorLayout(::tt::target::DataType dataType) { + LOG_ASSERT(getCurrentRuntime() == DeviceRuntime::TTNN); + ::ttnn::DataType ttnnDataType = + ::tt::runtime::ttnn::utils::toTTNNDataType(dataType); + ::tt::runtime::ttnn::LayoutDesc layoutDesc(::ttnn::BufferType::SYSTEM_MEMORY, + ::ttnn::Layout::ROW_MAJOR, + ttnnDataType, std::nullopt); + return Layout( + std::static_pointer_cast( + std::make_shared<::tt::runtime::ttnn::LayoutDesc>(layoutDesc)), + ::tt::runtime::DeviceRuntime::TTNN); +} +} // namespace tt::runtime::ttnn::test diff --git a/runtime/test/python/ttnn/conftest.py b/runtime/test/python/ttnn/conftest.py new file mode 100644 index 0000000000..854cb42a39 --- /dev/null +++ b/runtime/test/python/ttnn/conftest.py @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 +try: + import ttrt +except (ImportError, ModuleNotFoundError): + raise ImportError( + "Error: runtime python tests require ttrt to built and installed. Please run `cmake --build build -- ttrt`" + ) +import ttrt.runtime +from ttrt.common.api import API +from utils import Helper +import pytest + + +@pytest.fixture(autouse=True, scope="module") +def initialize(): + API.initialize_apis() + ttrt.runtime.set_current_runtime(ttrt.runtime.DeviceRuntime.TTNN) + + +@pytest.fixture(scope="module") +def helper(): + helper = Helper() + yield helper diff --git a/runtime/test/python/ttnn/test_runtime_api.py b/runtime/test/python/ttnn/test_runtime_api.py new file mode 100644 index 0000000000..5454cbcd9a --- /dev/null +++ b/runtime/test/python/ttnn/test_runtime_api.py @@ -0,0 +1,196 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import ttrt +import ttrt.runtime +import torch +from ttrt.common.util import * +from utils import TT_MLIR_HOME, Helper, DeviceContext, assert_pcc + + +@pytest.mark.parametrize("shape", [(64, 128)]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +def test_to_layout(helper: Helper, shape, dtype, request): + helper.initialize(request.node.name) + helper.check_constraints() + torch_input_tensor = torch.randn(shape, dtype=dtype) + torch_result_tensor = torch.zeros(shape, dtype=dtype) + runtime_dtype = Binary.Program.to_data_type(dtype) + runtime_input_tensor = ttrt.runtime.create_tensor( + torch_input_tensor.data_ptr(), + list(torch_input_tensor.shape), + list(torch_input_tensor.stride()), + torch_input_tensor.element_size(), + runtime_dtype, + ) + runtime_output_tensor = ttrt.runtime.create_tensor( + torch_result_tensor.data_ptr(), + list(torch_result_tensor.shape), + list(torch_result_tensor.stride()), + torch_result_tensor.element_size(), + runtime_dtype, + ) + device_layout = ttrt.runtime.testing.get_dram_interleaved_tile_layout(runtime_dtype) + host_layout = ttrt.runtime.testing.get_host_row_major_layout(runtime_dtype) + with DeviceContext([helper.query.device_ids[0]]) as device: + device_tensor = ttrt.runtime.to_layout( + runtime_input_tensor, device, device_layout + ) + host_tensor = ttrt.runtime.to_layout(device_tensor, device, host_layout) + ttrt.runtime.deallocate_tensor(device_tensor, force=True) + ttrt.runtime.memcpy(runtime_output_tensor, host_tensor) + ttrt.runtime.deallocate_tensor(host_tensor, force=True) + + assert_pcc(torch_input_tensor, torch_result_tensor, threshold=0.99) + helper.teardown() + + +@pytest.mark.parametrize("shape", [(64, 128)]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +def test_memcpy_to_pointer(helper: Helper, shape, dtype, request): + helper.initialize(request.node.name) + helper.check_constraints() + runtime_dtype = Binary.Program.to_data_type(dtype) + torch_result_tensor = torch.zeros(shape, dtype=dtype) + + # Device to host + torch_input_tensor = torch.randn(shape, dtype=dtype) + runtime_input_tensor = ttrt.runtime.create_tensor( + torch_input_tensor.data_ptr(), + list(torch_input_tensor.shape), + list(torch_input_tensor.stride()), + torch_input_tensor.element_size(), + runtime_dtype, + ) + device_layout = ttrt.runtime.testing.get_dram_interleaved_row_major_layout( + runtime_dtype + ) + with DeviceContext([helper.query.device_ids[0]]) as device: + device_tensor = ttrt.runtime.to_layout( + runtime_input_tensor, device, device_layout + ) + ttrt.runtime.memcpy(torch_result_tensor.data_ptr(), device_tensor) + ttrt.runtime.deallocate_tensor(device_tensor, force=True) + + assert_pcc(torch_input_tensor, torch_result_tensor, threshold=0.99) + + # Host to host + torch_input_tensor2 = torch.randn(shape, dtype=dtype) + host_tensor = ttrt.runtime.create_tensor( + torch_input_tensor2.data_ptr(), + list(torch_input_tensor2.shape), + list(torch_input_tensor2.stride()), + torch_input_tensor2.element_size(), + runtime_dtype, + ) + ttrt.runtime.memcpy(torch_result_tensor.data_ptr(), host_tensor) + assert_pcc(torch_input_tensor2, torch_result_tensor, threshold=0.99) + helper.teardown() + + +@pytest.mark.parametrize("shape", [(64, 128)]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +def test_create_tensor_memcpy(helper: Helper, shape, dtype, request): + helper.initialize(request.node.name) + helper.check_constraints() + torch_input_tensor = torch.randn(shape, dtype=dtype) + torch_result_tensor = torch.zeros(shape, dtype=dtype) + runtime_dtype = Binary.Program.to_data_type(dtype) + runtime_input_tensor = ttrt.runtime.create_tensor( + torch_input_tensor.data_ptr(), + list(torch_input_tensor.shape), + list(torch_input_tensor.stride()), + torch_input_tensor.element_size(), + runtime_dtype, + ) + runtime_output_tensor = ttrt.runtime.create_tensor( + torch_result_tensor.data_ptr(), + list(torch_result_tensor.shape), + list(torch_result_tensor.stride()), + torch_result_tensor.element_size(), + runtime_dtype, + ) + device_layout = ttrt.runtime.testing.get_dram_interleaved_row_major_layout( + runtime_dtype + ) + with DeviceContext([helper.query.device_ids[0]]) as device: + device_tensor = ttrt.runtime.create_empty_tensor( + device, + device_layout, + list(torch_input_tensor.shape), + list(torch_input_tensor.stride()), + torch_input_tensor.element_size(), + ) + # Copy from host to device container + ttrt.runtime.memcpy(device_tensor, runtime_input_tensor) + # Copy from device to host + ttrt.runtime.memcpy(runtime_output_tensor, device_tensor) + ttrt.runtime.deallocate_tensor(device_tensor, force=True) + assert_pcc(torch_input_tensor, torch_result_tensor, threshold=0.99) + helper.teardown() + + +def test_runtime_stitching_eltwise_binary_op_chain(helper: Helper, request): + binary_path = f"{TT_MLIR_HOME}/build/test/ttmlir/Runtime/TTNN/runtime_stitching/Output/eltwise_binary_op_chain.mlir.tmp.ttnn" + helper.initialize(request.node.name, binary_path) + helper.check_constraints() + first_program: Binary.Program = helper.binary.get_program(0) + assert first_program.num_inputs() == 2 + inputs_torch = [] + inputs_runtime = [] + input_layouts = [] + for i in first_program.program["inputs"]: + torch_tensor = torch.randn( + i["desc"]["shape"], + dtype=Binary.Program.from_data_type( + i["desc"]["layout"]["memory_desc"]["data_type"] + ), + ) + runtime_dtype = Binary.Program.to_data_type(torch_tensor.dtype) + inputs_torch.append(torch_tensor) + runtime_tensor = ttrt.runtime.create_tensor( + torch_tensor.data_ptr(), + list(torch_tensor.shape), + list(torch_tensor.stride()), + torch_tensor.element_size(), + runtime_dtype, + ) + inputs_runtime.append(runtime_tensor) + input_layouts.append( + ttrt.runtime.testing.get_dram_interleaved_row_major_layout(runtime_dtype) + ) + + activations, weights = inputs_runtime + activations_layout, weights_layout = input_layouts + with DeviceContext([helper.query.device_ids[0]]) as device: + activations = ttrt.runtime.to_layout(activations, device, activations_layout) + weights = ttrt.runtime.to_layout(weights, device, weights_layout) + program_indices = list(range(helper.binary.get_num_programs())) + for program_index in program_indices: + program = helper.binary.get_program(program_index) + assert program.num_inputs() == 2 and program.num_outputs() == 1 + outputs = ttrt.runtime.submit( + device, helper.binary.fbb, program_index, [activations, weights] + ) + activations = ttrt.runtime.to_layout(outputs[0], device, activations_layout) + ttrt.runtime.deallocate_tensor(outputs[0], force=True) + activations = ttrt.runtime.to_host(activations, untilize=True) + ttrt.runtime.deallocate_tensor(weights, force=True) + + last_program: Binary.Program = helper.binary.get_program(program_indices[-1]) + torch_result_tensor = torch.randn( + last_program.program["outputs"][0]["desc"]["shape"], + dtype=Binary.Program.from_data_type( + last_program.program["outputs"][0]["desc"]["layout"]["memory_desc"][ + "data_type" + ] + ), + ) + ttrt.runtime.memcpy(torch_result_tensor.data_ptr(), activations) + golden = ( + (inputs_torch[0] + inputs_torch[1]).mul(inputs_torch[1]).sub(inputs_torch[1]) + ) + assert_pcc(golden, torch_result_tensor, threshold=0.99) + helper.teardown() diff --git a/runtime/test/python/ttnn/utils.py b/runtime/test/python/ttnn/utils.py new file mode 100644 index 0000000000..6596811fff --- /dev/null +++ b/runtime/test/python/ttnn/utils.py @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +import os +import ttrt +import ttrt.runtime +import torch +from ttrt.common.query import Query +from ttrt.common.util import * + +TT_MLIR_HOME = os.environ.get("TT_MLIR_HOME", "") + + +class Helper: + def __init__(self, logger=None): + self.artifacts_dir = f"{os.getcwd()}/ttrt-artifacts" + self.logger = logger if logger is not None else Logger() + self.logging = self.logger.get_logger() + self.file_manager = FileManager(self.logger) + self.artifacts = Artifacts( + self.logger, self.file_manager, artifacts_folder_path=self.artifacts_dir + ) + self.query = Query({"--quiet": True}, self.logger, self.artifacts) + self.query() + self.test_name = None + self.binary_path = None + self.binary = None + + def initialize(self, test_name, binary_path=None): + self.test_name = test_name + if binary_path: + self.binary_path = binary_path + self.binary = Binary(self.logger, self.file_manager, binary_path) + + def teardown(self): + self.test_name = None + self.binary_path = None + self.binary = None + + def check_constraints(self): + if not self.binary: + return + self.binary.check_version() + self.binary.check_system_desc(self.query) + + +class DeviceContext: + def __init__(self, device_ids): + self.device = ttrt.runtime.open_device(device_ids) + + def __enter__(self): + return self.device + + def __exit__(self, exc_type, exc_value, traceback): + ttrt.runtime.close_device(self.device) + + +def assert_tensors_match(tensor1, tensor2): + assert torch.allclose(tensor1, tensor2) + + +def assert_pcc(x, y, threshold=0.99): + combined = torch.stack([x.flatten(), y.flatten()]) + pcc = torch.corrcoef(combined)[0, 1].item() + assert pcc >= threshold, f"Expected pcc {pcc} >= {threshold}" diff --git a/runtime/test/ttnn/test_subtract.cpp b/runtime/test/ttnn/test_subtract.cpp index 00aebe20fb..995b95665c 100644 --- a/runtime/test/ttnn/test_subtract.cpp +++ b/runtime/test/ttnn/test_subtract.cpp @@ -21,12 +21,13 @@ TEST(TTNNSubtract, Equal) { const char *fbPath = std::getenv("TTMLIR_SUBTRACT_FB_PATH"); assert(fbPath && "Path to subtract flatbuffer must be provided"); ::tt::runtime::Binary fbb = ::tt::runtime::Binary::loadFromPath(fbPath); - EXPECT_EQ(fbb.getFileIdentifier(), "TTNN"); + ASSERT_EQ(fbb.getFileIdentifier(), "TTNN"); ::tt::runtime::setCompatibleRuntime(fbb); std::vector<::tt::runtime::TensorDesc> inputDescs = fbb.getProgramInputs(0); + assert(inputDescs.size() == 2); std::vector<::tt::runtime::TensorDesc> outputDescs = fbb.getProgramOutputs(0); - std::vector<::tt::runtime::Tensor> inputTensors, outputTensors; - + assert(outputDescs.size() == 1); + std::vector<::tt::runtime::Tensor> inputTensors; std::uint32_t tensorSize = inputDescs[0].itemsize; for (const int dim : inputDescs[0].shape) { tensorSize *= dim; @@ -38,26 +39,27 @@ TEST(TTNNSubtract, Equal) { std::memset(data.get(), 1, tensorSize); inputTensors.emplace_back(::tt::runtime::createTensor(data, desc)); } - for (const auto &desc : outputDescs) { - std::shared_ptr data = - ::tt::runtime::utils::malloc_shared(tensorSize); - // Set to wrong value on purpose here - std::memset(data.get(), 1, tensorSize); - outputTensors.emplace_back(::tt::runtime::createTensor(data, desc)); - } + + std::shared_ptr outputDataPtr = + ::tt::runtime::utils::malloc_shared(tensorSize); + // Set to wrong value on purpose here + std::memset(outputDataPtr.get(), 1, tensorSize); + ::tt::runtime::Tensor outputTensor = + ::tt::runtime::createTensor(outputDataPtr, outputDescs[0]); size_t numDevices = ::tt::runtime::getNumAvailableDevices(); std::vector deviceIds(numDevices); std::iota(deviceIds.begin(), deviceIds.end(), 0); - auto device = ::tt::runtime::openDevice(deviceIds); - auto ev = ::tt::runtime::submit(device, fbb, 0, inputTensors, outputTensors); + auto device = ::tt::runtime::openDevice({deviceIds[0]}); + std::vector<::tt::runtime::Tensor> output = + ::tt::runtime::submit(device, fbb, 0, inputTensors); ::tt::runtime::closeDevice(device); - + assert(output.size() == 1); std::shared_ptr expected = ::tt::runtime::utils::malloc_shared(tensorSize); std::memset(expected.get(), 0, tensorSize); - for (const auto &outputTensor : outputTensors) { - EXPECT_EQ(std::memcmp(outputTensor.data.get(), expected.get(), tensorSize), - 0); - } + ::tt::runtime::Tensor submitOutput = output[0]; + ASSERT_NE(std::memcmp(outputDataPtr.get(), expected.get(), tensorSize), 0); + ::tt::runtime::memcpy(outputTensor, submitOutput); + ASSERT_EQ(std::memcmp(outputDataPtr.get(), expected.get(), tensorSize), 0); } diff --git a/runtime/tools/python/CMakeLists.txt b/runtime/tools/python/CMakeLists.txt index a4c7a51916..966ee9681a 100644 --- a/runtime/tools/python/CMakeLists.txt +++ b/runtime/tools/python/CMakeLists.txt @@ -9,9 +9,11 @@ add_custom_target(ttrt COMMAND TTMLIR_ENABLE_RUNTIME=${TTMLIR_ENABLE_RUNTIME} TT_RUNTIME_ENABLE_TTNN=${TT_RUNTIME_ENABLE_TTNN} TT_RUNTIME_ENABLE_TTMETAL=${TT_RUNTIME_ENABLE_TTMETAL} + TTMLIR_ENABLE_RUNTIME_TESTS=${TTMLIR_ENABLE_RUNTIME_TESTS} TT_RUNTIME_ENABLE_PERF_TRACE=${TT_RUNTIME_ENABLE_PERF_TRACE} TT_RUNTIME_DEBUG=${TT_RUNTIME_DEBUG} TT_RUNTIME_WORKAROUNDS=${TT_RUNTIME_WORKAROUNDS} + TTMLIR_BINARY_DIR=${TTMLIR_BINARY_DIR} TTMLIR_VERSION_MAJOR=${TTMLIR_VERSION_MAJOR} TTMLIR_VERSION_MINOR=${TTMLIR_VERSION_MINOR} TTMLIR_VERSION_PATCH=${TTMLIR_VERSION_PATCH} diff --git a/runtime/tools/python/requirements.txt b/runtime/tools/python/requirements.txt index 8bfab8347d..b427d78493 100644 --- a/runtime/tools/python/requirements.txt +++ b/runtime/tools/python/requirements.txt @@ -1 +1 @@ -torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu +torch==2.5.0 --index-url https://download.pytorch.org/whl/cpu diff --git a/runtime/tools/python/setup.py b/runtime/tools/python/setup.py index ddbe3da9fe..d754250e01 100644 --- a/runtime/tools/python/setup.py +++ b/runtime/tools/python/setup.py @@ -18,6 +18,11 @@ "SOURCE_ROOT", os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", ".."), ) +# Use 'src_dir/build' as default location if TTMLIR_BINARY_DIR env variable is not available. +ttmlir_build_dir = os.environ.get( + "TTMLIR_BINARY_DIR", + os.path.join(src_dir, "build"), +) toolchain = os.environ.get("TTMLIR_TOOLCHAIN_DIR", "/opt/ttmlir-toolchain") metaldir = f"{src_dir}/third_party/tt-metal/src/tt-metal-build" ttmetalhome = os.environ.get("TT_METAL_HOME", "") @@ -26,6 +31,7 @@ enable_runtime = os.environ.get("TTMLIR_ENABLE_RUNTIME", "OFF") == "ON" enable_ttnn = os.environ.get("TT_RUNTIME_ENABLE_TTNN", "OFF") == "ON" enable_ttmetal = os.environ.get("TT_RUNTIME_ENABLE_TTMETAL", "OFF") == "ON" +enable_runtime_tests = os.environ.get("TTMLIR_ENABLE_RUNTIME_TESTS", "OFF") == "ON" enable_perf = os.environ.get("TT_RUNTIME_ENABLE_PERF_TRACE", "OFF") == "ON" debug_runtime = os.environ.get("TT_RUNTIME_DEBUG", "OFF") == "ON" configure_workarounds_runtime = os.environ.get("TT_RUNTIME_WORKAROUNDS", "OFF") == "ON" @@ -37,12 +43,12 @@ include_dirs=[ f"{toolchain}/include", f"{src_dir}/runtime/include", - f"{src_dir}/build/include", - f"{src_dir}/build/include/ttmlir/Target/Common", + f"{ttmlir_build_dir}/include", + f"{ttmlir_build_dir}/include/ttmlir/Target/Common", ], libraries=["TTBinary", "flatbuffers"], library_dirs=[ - f"{src_dir}/build/runtime/lib", + f"{ttmlir_build_dir}/runtime/lib", f"{toolchain}/lib", ], define_macros=[("VERSION_INFO", __version__)], @@ -59,14 +65,22 @@ linklibs = ["TTBinary"] if enable_ttnn: runlibs += ["_ttnn.so"] - linklibs += ["TTRuntimeTTNN", "TTRuntimeTTNNOps", ":_ttnn.so"] + linklibs += [ + "TTRuntimeTTNN", + "TTRuntimeTTNNOps", + "TTRuntimeTTNNHelpers", + ":_ttnn.so", + ] + +if enable_ttnn and enable_runtime_tests: + linklibs += ["TTRuntimeTTNNTestHelpers"] if enable_ttmetal: runlibs += ["libtt_metal.so"] linklibs += ["TTRuntimeTTMetal", "tt_metal"] if enable_ttnn or enable_ttmetal: - runlibs += ["libdevice.so", "libnng.so.1", "libuv.so.1"] + runlibs += ["libdevice.so"] linklibs += ["TTRuntimeSysDesc", "TTRuntimeDebug", "TTRuntimeWorkarounds"] if enable_perf: @@ -80,13 +94,13 @@ for dylib in runlibs: shutil.copy( f"{metaldir}/lib/{dylib}", - f"{src_dir}/build/runtime/tools/python/ttrt/runtime", + f"{ttmlir_build_dir}/runtime/tools/python/ttrt/runtime", ) command = [ "patchelf", "--set-rpath", "$ORIGIN", - f"{src_dir}/build/runtime/tools/python/ttrt/runtime/{dylib}", + f"{ttmlir_build_dir}/runtime/tools/python/ttrt/runtime/{dylib}", ] try: @@ -103,7 +117,7 @@ for dylib in perflibs: shutil.copy( f"{metaldir}/tools/profiler/bin/{dylib}", - f"{src_dir}/build/runtime/tools/python/ttrt/runtime", + f"{ttmlir_build_dir}/runtime/tools/python/ttrt/runtime", ) shutil.copy( f"{metaldir}/tools/profiler/bin/{dylib}", @@ -169,7 +183,7 @@ def tt_metal_ignore_folders(folder, contents): # copy metal dir folder shutil.copytree( f"{ttmetalhome}/tt_metal", - f"{src_dir}/build/runtime/tools/python/ttrt/runtime/tt_metal", + f"{ttmlir_build_dir}/runtime/tools/python/ttrt/runtime/tt_metal", dirs_exist_ok=True, ignore=tt_metal_ignore_folders, ) @@ -177,14 +191,14 @@ def tt_metal_ignore_folders(folder, contents): # copy runtime dir folder shutil.copytree( f"{ttmetalhome}/runtime", - f"{src_dir}/build/runtime/tools/python/ttrt/runtime/runtime", + f"{ttmlir_build_dir}/runtime/tools/python/ttrt/runtime/runtime", dirs_exist_ok=True, ) # copy kernels shutil.copytree( f"{ttmetalhome}/ttnn", - f"{src_dir}/build/runtime/tools/python/ttrt/runtime/ttnn", + f"{ttmlir_build_dir}/runtime/tools/python/ttrt/runtime/ttnn", dirs_exist_ok=True, ) @@ -198,16 +212,16 @@ def package_files(directory): return paths extra_files_tt_metal = package_files( - f"{src_dir}/build/runtime/tools/python/ttrt/runtime/tt_metal/" + f"{ttmlir_build_dir}/runtime/tools/python/ttrt/runtime/tt_metal/" ) extra_files_runtime = package_files( - f"{src_dir}/build/runtime/tools/python/ttrt/runtime/runtime/" + f"{ttmlir_build_dir}/runtime/tools/python/ttrt/runtime/runtime/" ) extra_files_ttnn = package_files( - f"{src_dir}/build/runtime/tools/python/ttrt/runtime/ttnn/" + f"{ttmlir_build_dir}/runtime/tools/python/ttrt/runtime/ttnn/" ) extra_files_tests = package_files( - f"{src_dir}/build/runtime/tools/python/ttrt/runtime/tests/" + f"{ttmlir_build_dir}/runtime/tools/python/ttrt/runtime/tests/" ) metallibs += extra_files_tt_metal @@ -222,18 +236,19 @@ def package_files(directory): include_dirs=[ f"{toolchain}/include", f"{src_dir}/runtime/include", - f"{src_dir}/build/include", - f"{src_dir}/build/include/ttmlir/Target/Common", + f"{ttmlir_build_dir}/include", + f"{ttmlir_build_dir}/include/ttmlir/Target/Common", ], libraries=["TTRuntime"] + linklibs + ["flatbuffers"], library_dirs=[ - f"{src_dir}/build/runtime/lib", - f"{src_dir}/build/runtime/lib/common", - f"{src_dir}/build/runtime/lib/ttnn", - f"{src_dir}/build/runtime/lib/ttnn/operations", - f"{src_dir}/build/runtime/lib/ttmetal", + f"{ttmlir_build_dir}/runtime/lib", + f"{ttmlir_build_dir}/runtime/lib/common", + f"{ttmlir_build_dir}/runtime/lib/ttnn", + f"{ttmlir_build_dir}/runtime/lib/ttnn/operations", + f"{ttmlir_build_dir}/runtime/lib/ttmetal", + f"{ttmlir_build_dir}/runtime/test", f"{toolchain}/lib", - f"{src_dir}/build/runtime/tools/python/ttrt/runtime", + f"{ttmlir_build_dir}/runtime/tools/python/ttrt/runtime", f"{metaldir}/lib", ], define_macros=[ @@ -243,6 +258,7 @@ def package_files(directory): "TT_RUNTIME_WORKAROUNDS", "1" if configure_workarounds_runtime else "0", ), + ("TTMLIR_ENABLE_RUNTIME_TESTS", "1" if enable_runtime_tests else "0"), ], ) ) diff --git a/runtime/tools/python/test/test_run.py b/runtime/tools/python/test/test_run.py index 69d5683aaf..37167a6e93 100644 --- a/runtime/tools/python/test/test_run.py +++ b/runtime/tools/python/test/test_run.py @@ -311,57 +311,6 @@ def test_enable_async_ttnn_cmd_run(): sub_process_command(command) -def test_disable_ignore_tile_shape_run(): - API.initialize_apis() - custom_args = {} - custom_args[ - "--result-file" - ] = f"ttrt-results/{inspect.currentframe().f_code.co_name}.json" - custom_args["binary"] = BINARY_FILE_PATH - custom_args["--disable-ignore-tile-shape"] = True - run_instance = API.Run(args=custom_args) - run_instance() - - -def test_disable_ignore_tile_shape_cmd_run(): - command = f"ttrt run {BINARY_FILE_PATH} --disable-ignore-tile-shape --log-file ttrt-results/{inspect.currentframe().f_code.co_name}.log --result-file ttrt-results/{inspect.currentframe().f_code.co_name}.json" - sub_process_command(command) - - -def test_disable_empty_op_row_major_run(): - API.initialize_apis() - custom_args = {} - custom_args[ - "--result-file" - ] = f"ttrt-results/{inspect.currentframe().f_code.co_name}.json" - custom_args["binary"] = BINARY_FILE_PATH - custom_args["--disable-empty-op-row-major"] = True - run_instance = API.Run(args=custom_args) - run_instance() - - -def test_disable_empty_op_row_major_cmd_run(): - command = f"ttrt run {BINARY_FILE_PATH} --disable-empty-op-row-major --log-file ttrt-results/{inspect.currentframe().f_code.co_name}.log --result-file ttrt-results/{inspect.currentframe().f_code.co_name}.json" - sub_process_command(command) - - -def test_disable_full_op_row_major_run(): - API.initialize_apis() - custom_args = {} - custom_args[ - "--result-file" - ] = f"ttrt-results/{inspect.currentframe().f_code.co_name}.json" - custom_args["binary"] = BINARY_FILE_PATH - custom_args["--disable-full-op-row-major"] = True - run_instance = API.Run(args=custom_args) - run_instance() - - -def test_disable_full_op_row_major_cmd_run(): - command = f"ttrt run {BINARY_FILE_PATH} --disable-full-op-row-major --log-file ttrt-results/{inspect.currentframe().f_code.co_name}.log --result-file ttrt-results/{inspect.currentframe().f_code.co_name}.json" - sub_process_command(command) - - def test_disable_maxpool2d_preshard_run(): API.initialize_apis() custom_args = {} diff --git a/runtime/tools/python/ttrt/common/perf.py b/runtime/tools/python/ttrt/common/perf.py index a341c2b4f4..55ee255f91 100644 --- a/runtime/tools/python/ttrt/common/perf.py +++ b/runtime/tools/python/ttrt/common/perf.py @@ -17,6 +17,7 @@ import atexit import traceback from pathlib import Path +import csv from ttrt.common.util import * from ttrt.common.query import Query @@ -456,6 +457,38 @@ def signal_handler(sig, frame): ) process_ops(None, None, False) + + # Add post-processing steps to insert location data into the ops_perf data file + with open(profiler_csv_file_path, "r") as perf_file: + perf_reader = csv.DictReader(perf_file) + headers = list(perf_reader.fieldnames) + ["LOC"] + perf_data = list(perf_reader) + + with open(profiler_csv_file_path, "w+") as perf_file, open( + tracy_ops_data_file_path, "r" + ) as message_file: + message_reader = csv.reader(message_file, delimiter=";") + ops_index = 0 + prev = None + for message in message_reader: + message = message[0] # Don't need timestamp information + if message.startswith("`"): + # This is a TTNN Message + # The location data is now in the previous message + # The order of data is maintained in perf_data so as the messages are received, they update the id last encountered. + # Now that we have a new message, we can update the location data from the previous message + if prev: + # Get the location data from the previous message and add it as new data for the perf_data (as a new col) + if len(perf_data) > ops_index: + perf_data[ops_index]["LOC"] = prev + ops_index += 1 + else: + prev = message + perf_writer = csv.DictWriter(perf_file, fieldnames=headers) + perf_writer.writeheader() + for row in perf_data: + perf_writer.writerow(row) + self.file_manager.copy_file( perf_folder_path, profiler_csv_file_path, diff --git a/runtime/tools/python/ttrt/common/run.py b/runtime/tools/python/ttrt/common/run.py index c2ae10ac9a..19ad61e241 100644 --- a/runtime/tools/python/ttrt/common/run.py +++ b/runtime/tools/python/ttrt/common/run.py @@ -125,39 +125,25 @@ def initialize_api(): help="enable async mode device execution for TTNN runtime", ) Run.register_arg( - name="--disable-ignore-tile-shape", - type=bool, - default=False, - choices=[True, False], - help="disable ignore tile shape workaround", - ) - Run.register_arg( - name="--disable-empty-op-row-major", - type=bool, - default=False, - choices=[True, False], - help="disable empty op force row major workaround", - ) - Run.register_arg( - name="--disable-full-op-row-major", + name="--disable-maxpool2d-preshard", type=bool, default=False, choices=[True, False], - help="disable full op force row major workaround", + help="disable maxpool2d preshard workaround", ) Run.register_arg( - name="--disable-maxpool2d-preshard", + name="--disable-swap-binary-operands", type=bool, default=False, choices=[True, False], - help="disable maxpool2d preshard workaround", + help="disable swap binary operands workaround", ) Run.register_arg( - name="--disable-swap-binary-operands", + name="--disable-read-update-index-for-kv-cache", type=bool, default=False, choices=[True, False], - help="disable swap binary operands workaround", + help="disable read update index for kv cache workaround", ) Run.register_arg( name="--result-file", @@ -370,16 +356,15 @@ def _execute(binaries): ) self.logging.debug(f"setting tt runtime debug env={debug_env}") workaround_env = ttrt.runtime.WorkaroundEnv.get( - not self["--disable-ignore-tile-shape"], - not self["--disable-empty-op-row-major"], - not self["--disable-full-op-row-major"], not self["--disable-maxpool2d-preshard"], not self["--disable-swap-binary-operands"], + not self["--disable-read-update-index-for-kv-cache"], ) self.logging.debug(f"setting tt runtime workaround env={workaround_env}") self.logging.debug(f"setting torch manual seed={self['--seed']}") torch.manual_seed(self["--seed"]) ttrt.runtime.set_compatible_runtime(binaries[0].fbb) + current_runtime = ttrt.runtime.get_current_runtime() self.logging.debug(f"opening devices={self.query.device_ids}") device = ttrt.runtime.open_device(self.query.device_ids) @@ -459,20 +444,43 @@ def _execute(binaries): self.logging.debug( f"starting loop={loop+1}/{self['--loops']} for binary={bin.file_path}" ) + if ( + current_runtime + == ttrt.runtime.DeviceRuntime.TTMetal + ): + event = ttrt.runtime.submit( + device, + bin.fbb, + program_index, + total_inputs[loop], + total_outputs[loop], + ) - event = ttrt.runtime.submit( - device, - bin.fbb, - program_index, - total_inputs[loop], - total_outputs[loop], - ) + elif current_runtime == ttrt.runtime.DeviceRuntime.TTNN: + runtime_outputs = ttrt.runtime.submit( + device, + bin.fbb, + program_index, + total_inputs[loop], + ) + ttrt.runtime.wait(runtime_outputs) + for i, runtime_output_tensor in enumerate( + runtime_outputs + ): + ttrt.runtime.memcpy( + total_outputs[loop][i], + runtime_output_tensor, + ) + ttrt.runtime.deallocate_tensor( + runtime_output_tensor, force=True + ) self.logging.debug( f"finished loop={loop+1}/{self['--loops']} for binary={bin.file_path}" ) - ttrt.runtime.wait(event) + if event is not None: + ttrt.runtime.wait(event) if self["--identity"]: self.logging.debug( diff --git a/runtime/tools/python/ttrt/common/util.py b/runtime/tools/python/ttrt/common/util.py index 370643e7d3..45e0a9db95 100644 --- a/runtime/tools/python/ttrt/common/util.py +++ b/runtime/tools/python/ttrt/common/util.py @@ -586,6 +586,12 @@ def __init__(self, index, program): self.input_tensors = [] self.output_tensors = [] + def num_inputs(self): + return len(self.program["inputs"]) + + def num_outputs(self): + return len(self.program["outputs"]) + def populate_inputs(self, init_fn, golden_inputs=[]): if len(golden_inputs) > 0: assert len(golden_inputs) == len(self.program["inputs"]) diff --git a/runtime/tools/python/ttrt/runtime/__init__.py b/runtime/tools/python/ttrt/runtime/__init__.py index 642b0401f5..0376c07b58 100644 --- a/runtime/tools/python/ttrt/runtime/__init__.py +++ b/runtime/tools/python/ttrt/runtime/__init__.py @@ -12,19 +12,33 @@ DebugEnv, DebugHooks, get_current_runtime, + set_current_runtime, set_compatible_runtime, get_current_system_desc, open_device, close_device, submit, create_tensor, + create_empty_tensor, create_multi_device_tensor, wait, + to_host, + to_layout, + get_layout, get_op_output_tensor, get_op_debug_str, + memcpy, + deallocate_tensor, WorkaroundEnv, ) except ModuleNotFoundError: raise ImportError( "Error: Project was not built with runtime enabled, rebuild with: -DTTMLIR_ENABLE_RUNTIME=ON" ) + +try: + from ._C import testing +except ImportError: + print( + "Warning: not importing testing submodule since project was not built with runtime testing enabled. To enable, rebuild with: -DTTMLIR_ENABLE_RUNTIME_TESTS=ON" + ) diff --git a/runtime/tools/python/ttrt/runtime/module.cpp b/runtime/tools/python/ttrt/runtime/module.cpp index dfc4a68201..47b42eab56 100644 --- a/runtime/tools/python/ttrt/runtime/module.cpp +++ b/runtime/tools/python/ttrt/runtime/module.cpp @@ -8,6 +8,9 @@ #include "tt/runtime/detail/workarounds.h" #include "tt/runtime/runtime.h" #include "tt/runtime/utils.h" +#if defined(TTMLIR_ENABLE_RUNTIME_TESTS) && TTMLIR_ENABLE_RUNTIME_TESTS == 1 +#include "tt/runtime/test/utils.h" +#endif #include #include @@ -22,6 +25,7 @@ PYBIND11_MODULE(_C, m) { .def("deallocate_buffers", &tt::runtime::detail::deallocateBuffers); py::class_(m, "Event"); py::class_(m, "Tensor"); + py::class_(m, "Layout"); py::class_(m, "OpContext"); py::class_(m, "CallbackContext"); py::enum_<::tt::target::DataType>(m, "DataType") @@ -48,6 +52,8 @@ PYBIND11_MODULE(_C, m) { m.def("set_compatible_runtime", &tt::runtime::setCompatibleRuntime, py::arg("binary"), "Set the backend device runtime type to match the binary"); + m.def("set_current_runtime", &tt::runtime::setCurrentRuntime, + py::arg("runtime"), "Set the backend device runtime type"); m.def("get_current_system_desc", &tt::runtime::getCurrentSystemDesc, "Get the current system descriptor"); m.def( @@ -61,6 +67,15 @@ PYBIND11_MODULE(_C, m) { shape, stride, itemsize, dataType); }, "Create a tensor with borrowed memory"); + m.def( + "create_empty_tensor", + [](::tt::runtime::Device device, ::tt::runtime::Layout layout, + std::vector const &shape, + std::vector const &stride, std::uint32_t itemsize) { + return tt::runtime::createTensor(device, layout, shape, stride, + itemsize); + }, + "Create an empty tensor with the specified layout"); m.def( "create_multi_device_tensor", [](std::vector &ptrs, @@ -69,8 +84,8 @@ PYBIND11_MODULE(_C, m) { ::tt::target::DataType dataType, std::unordered_map const &strategy) { std::vector> data; - data.resize(ptrs.size()); - std::transform(ptrs.begin(), ptrs.end(), data.begin(), + data.reserve(ptrs.size()); + std::transform(ptrs.begin(), ptrs.end(), std::back_inserter(data), [](std::uintptr_t ptr) { return ::tt::runtime::utils::unsafe_borrow_shared( reinterpret_cast(ptr)); @@ -85,10 +100,50 @@ PYBIND11_MODULE(_C, m) { py::arg("num_hw_cqs") = size_t{1}, "Open a mesh of devices for execution"); m.def("close_device", &tt::runtime::closeDevice, "Close a mesh device"); - m.def("submit", &tt::runtime::submit, py::arg("device"), - py::arg("executable"), py::arg("program_index"), py::arg("inputs"), - py::arg("outputs"), "Submit a binary for execution"); - m.def("wait", &tt::runtime::wait, py::arg("event")); + m.def("to_host", &tt::runtime::toHost, py::arg("tensor"), + py::arg("untilize") = false, "Copy the tensor to the host"); + m.def("to_layout", &tt::runtime::toLayout, py::arg("tensor"), + py::arg("device"), py::arg("layout"), + "Create a copy of the tensor with the specified layout"); + m.def("get_layout", &tt::runtime::getLayout, py::arg("executable"), + py::arg("program_index"), py::arg("input_index"), + "Get the layout of the input tensor"); + m.def( + "submit", + [](::tt::runtime::Device device, ::tt::runtime::Binary executable, + std::uint32_t programIndex, + const std::vector<::tt::runtime::Tensor> &inputs) + -> std::vector<::tt::runtime::Tensor> { + return ::tt::runtime::submit(device, executable, programIndex, inputs); + }, + py::arg("device"), py::arg("executable"), py::arg("program_index"), + py::arg("inputs"), + "Submit a ttnn binary for execution, returns a vector of output tensors"); + m.def( + "submit", + [](::tt::runtime::Device device, ::tt::runtime::Binary executable, + std::uint32_t programIndex, + const std::vector<::tt::runtime::Tensor> &inputs, + const std::vector<::tt::runtime::Tensor> &outputs) + -> ::tt::runtime::Event { + return ::tt::runtime::submit(device, executable, programIndex, inputs, + outputs); + }, + py::arg("device"), py::arg("executable"), py::arg("program_index"), + py::arg("inputs"), py::arg("outputs"), + "Submit a ttmetal binary for execution. returns event wrapper"); + m.def( + "wait", [](::tt::runtime::Event event) { ::tt::runtime::wait(event); }, + py::arg("event")); + m.def( + "wait", [](::tt::runtime::Tensor tensor) { ::tt::runtime::wait(tensor); }, + py::arg("tensor")); + m.def( + "wait", + [](const std::vector<::tt::runtime::Tensor> &tensors) { + ::tt::runtime::wait(tensors); + }, + py::arg("tensors")); m.def( "get_op_output_tensor", [](tt::runtime::OpContext &opContextHandle, @@ -100,7 +155,25 @@ PYBIND11_MODULE(_C, m) { "Get the input tensor of the op"); m.def("get_op_debug_str", &tt::runtime::getOpDebugString, "Get the debug string of the op"); - + m.def("get_op_loc_info", &tt::runtime::getOpLocInfo, + "Get the location info of the op"); + m.def( + "memcpy", + [](std::uintptr_t dst, ::tt::runtime::Tensor src) { + void *dstPtr = reinterpret_cast(dst); + ::tt::runtime::memcpy(dstPtr, src); + }, + py::arg("dst"), py::arg("src"), + "Copy the data from src tensor to dst pointer"); + m.def( + "memcpy", + [](::tt::runtime::Tensor dst, ::tt::runtime::Tensor src) { + ::tt::runtime::memcpy(dst, src); + }, + py::arg("dst"), py::arg("src"), + "Copy the data from src tensor to dst tensor"); + m.def("deallocate_tensor", &tt::runtime::deallocateTensor, py::arg("tensor"), + py::arg("force") = false, "Deallocate the tensor memory"); py::class_(m, "DebugEnv") .def_static("get", &tt::runtime::debug::Env::get) .def("__str__", [](const tt::runtime::debug::Env &env) { @@ -136,4 +209,17 @@ PYBIND11_MODULE(_C, m) { os << env; return os.str(); }); + +#if defined(TTMLIR_ENABLE_RUNTIME_TESTS) && TTMLIR_ENABLE_RUNTIME_TESTS == 1 + auto testing = m.def_submodule("testing"); + testing.def("get_dram_interleaved_tile_layout", + &tt::runtime::ttnn::test::getDramInterleavedTileLayout, + py::arg("dtype"), "Get dram interleaved tile layout"); + testing.def("get_dram_interleaved_row_major_layout", + &tt::runtime::ttnn::test::getDramInterleavedRowMajorLayout, + py::arg("dtype"), "Get dram interleaved row major layout"); + testing.def("get_host_row_major_layout", + &tt::runtime::ttnn::test::getHostRowMajorLayout, py::arg("dtype"), + "Get host row major layout"); +#endif } diff --git a/test/python/tensor_layout.py b/test/python/tensor_layout.py index 39a9a728be..2dbf249e9f 100644 --- a/test/python/tensor_layout.py +++ b/test/python/tensor_layout.py @@ -34,7 +34,7 @@ def createTensorLayout( shape, F32Type.get(ctx), None, Location.unknown(ctx) ) memoryLayout = getTensorMemoryLayout(memorySpace) - layout = tt.ir.LayoutAttr.get( + layout = tt.ir.MetalLayoutAttr.get( ctx, tensorTy, memorySpace, grid, collapseIntervals, oobVal, memoryLayout ) return RankedTensorType.get(shape, F32Type.get(ctx), layout, Location.unknown(ctx)) @@ -42,7 +42,7 @@ def createTensorLayout( def tilize(tensor, dataType, tileShape=[32, 32]): assert len(tileShape) == 2 - return tt.ir.LayoutAttr.with_element_type_( + return tt.ir.MetalLayoutAttr.with_element_type_( ctx, tensor.encoding, tt.ir.TileType.get(ctx, tileShape[0], tileShape[1], dataType), @@ -52,15 +52,15 @@ def tilize(tensor, dataType, tileShape=[32, 32]): def parallelize(tensor, grid, collapseIntervals=[(0, -1)]): if isinstance(grid, list) or isinstance(grid, tuple): grid = tt.ir.GridAttr.get(ctx, list(grid)) - return tt.ir.LayoutAttr.with_grid_( + return tt.ir.MetalLayoutAttr.with_grid_( ctx, tensor.encoding, tensor.shape, grid, collapseIntervals ) t0 = createTensorLayout([2, 3, 64, 128], [2, 4]) -# CHECK: tensor<2x3x64x128xf32, #tt.layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <2x4>, memref<192x32xf32, #tt.memory_space>, interleaved>> +# CHECK: tensor<2x3x64x128xf32, #tt.metal_layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <2x4>, memref<192x32xf32, #tt.memory_space>, interleaved>> print(t0) -# CHECK: #tt.layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <2x4>, memref<6x1x!tt.tile<32x32, bfp_bf8>, #tt.memory_space>, interleaved> +# CHECK: #tt.metal_layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <2x4>, memref<6x1x!tt.tile<32x32, bfp_bf8>, #tt.memory_space>, interleaved> print(tilize(t0, tt.DataType.BFP_BFloat8).wrapped()) print(parallelize(t0, [3, 2]).wrapped()) @@ -69,24 +69,24 @@ def parallelize(tensor, grid, collapseIntervals=[(0, -1)]): print(parallelize(t1, [3, 2]).wrapped()) t2 = createTensorLayout([128], [4], collapseIntervals=[(0, -1)]) -# CHECK: tensor<128xf32, #tt.layout<(d0) -> (d0), undef, <4>, memref<32xf32, #tt.memory_space>, interleaved>> +# CHECK: tensor<128xf32, #tt.metal_layout<(d0) -> (d0), undef, <4>, memref<32xf32, #tt.memory_space>, interleaved>> print(t2) -# CHECK: #tt.layout<(d0) -> (d0), undef, <2>, memref<64xf32, #tt.memory_space>, interleaved> +# CHECK: #tt.metal_layout<(d0) -> (d0), undef, <2>, memref<64xf32, #tt.memory_space>, interleaved> print(parallelize(t2, [2]).wrapped()) -# CHECK: #tt.layout<(d0) -> (0, d0), undef, <1x2>, memref<1x64xf32, #tt.memory_space>, interleaved> +# CHECK: #tt.metal_layout<(d0) -> (0, d0), undef, <1x2>, memref<1x64xf32, #tt.memory_space>, interleaved> print(parallelize(t2, [1, 2]).wrapped()) t3 = createTensorLayout([128], [1, 4], collapseIntervals=[(0, -1)]) -# CHECK: tensor<128xf32, #tt.layout<(d0) -> (0, d0), undef, <1x4>, memref<1x32xf32, #tt.memory_space>, interleaved>> +# CHECK: tensor<128xf32, #tt.metal_layout<(d0) -> (0, d0), undef, <1x4>, memref<1x32xf32, #tt.memory_space>, interleaved>> print(t3) -# CHECK: #tt.layout<(d0) -> (0, d0), undef, <1x4>, memref<1x1x!tt.tile<32x32, bfp_bf8>, #tt.memory_space>, interleaved> +# CHECK: #tt.metal_layout<(d0) -> (0, d0), undef, <1x4>, memref<1x1x!tt.tile<32x32, bfp_bf8>, #tt.memory_space>, interleaved> print(tilize(t3, tt.DataType.BFP_BFloat8).wrapped()) t4 = createTensorLayout([128], [1, 2, 4], collapseIntervals=[(0, -1)]) -# CHECK: tensor<128xf32, #tt.layout<(d0) -> (0, 0, d0), undef, <1x2x4>, memref<1x1x32xf32, #tt.memory_space>, interleaved>> +# CHECK: tensor<128xf32, #tt.metal_layout<(d0) -> (0, 0, d0), undef, <1x2x4>, memref<1x1x32xf32, #tt.memory_space>, interleaved>> print(t4) -# CHECK: #tt.layout<(d0) -> (0, 0, d0), undef, <1x2x4>, memref<1x1x1x!tt.tile<32x32, bfp_bf8>, #tt.memory_space>, interleaved> +# CHECK: #tt.metal_layout<(d0) -> (0, 0, d0), undef, <1x2x4>, memref<1x1x1x!tt.tile<32x32, bfp_bf8>, #tt.memory_space>, interleaved> print(tilize(t4, tt.DataType.BFP_BFloat8).wrapped()) -# CHECK: #tt.layout<(d0) -> (0, d0), undef, <1x2>, memref<1x64xf32, #tt.memory_space>, interleaved> +# CHECK: #tt.metal_layout<(d0) -> (0, d0), undef, <1x2>, memref<1x64xf32, #tt.memory_space>, interleaved> print(parallelize(t4, [1, 2]).wrapped()) diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/broadcast_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/broadcast_op.mlir index fa6cbb4236..42a26ad15f 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/broadcast_op.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/broadcast_op.mlir @@ -8,3 +8,54 @@ module @jit_broadcast attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replic return %1 : tensor<512x512xf32> } } + +module { + func.func @main(%arg0: tensor<1x23x40x1xf32>, %arg1: tensor<128xf32>) -> tensor<1x23x40x128xf32> { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2, 3] : (tensor<1x23x40x1xf32>) -> tensor<1x23x40x128xf32> + %1 = stablehlo.broadcast_in_dim %arg1, dims = [3] : (tensor<128xf32>) -> tensor<1x23x40x128xf32> + // CHECK: %[[C:.*]] = "ttir.broadcast"[[C:.*]] + %2 = stablehlo.divide %0, %1 : tensor<1x23x40x128xf32> + return %2 : tensor<1x23x40x128xf32> + } +} + +module { + func.func @main(%arg0: tensor<32xi64>, %arg1: tensor<32x1xi64>) -> tensor<32x32xi1> { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<32xi64>) -> tensor<32x32xi64> + %1 = stablehlo.broadcast_in_dim %arg1, dims = [0, 1] : (tensor<32x1xi64>) -> tensor<32x32xi64> + %2 = stablehlo.compare GT, %0, %1, SIGNED : (tensor<32x32xi64>, tensor<32x32xi64>) -> tensor<32x32xi1> + // CHECK: %[[C:.*]] = "ttir.broadcast"[[C:.*]] + return %2 : tensor<32x32xi1> + } +} + +module { + func.func @main(%arg0: tensor<16x1xf32>, %arg1: tensor<1x1x32xi64>) -> tensor<1x16x32xf32> { + %0 = stablehlo.convert %arg1 : (tensor<1x1x32xi64>) -> tensor<1x1x32xf32> + %1 = stablehlo.broadcast_in_dim %arg0, dims = [1, 2] : (tensor<16x1xf32>) -> tensor<1x16x32xf32> + %2 = stablehlo.broadcast_in_dim %0, dims = [0, 1, 2] : (tensor<1x1x32xf32>) -> tensor<1x16x32xf32> + %3 = stablehlo.multiply %1, %2 : tensor<1x16x32xf32> + // CHECK: %[[C:.*]] = "ttir.broadcast"[[C:.*]] + return %3 : tensor<1x16x32xf32> + } +} + +module { + func.func @main(%arg0: tensor<1x10xi64>, %arg1: tensor<10x1xi64>) -> tensor<10x10xi64> { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<1x10xi64>) -> tensor<10x10xi64> + %1 = stablehlo.broadcast_in_dim %arg1, dims = [0, 1] : (tensor<10x1xi64>) -> tensor<10x10xi64> + %2 = stablehlo.subtract %0, %1 : tensor<10x10xi64> + // CHECK: %[[C:.*]] = "ttir.broadcast"[[C:.*]] + return %2 : tensor<10x10xi64> + } +} + +module { + func.func @main(%arg0: tensor<8xf32>, %arg1: tensor<1xf32>) -> tensor<8xf32> { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0] : (tensor<8xf32>) -> tensor<8xf32> + %1 = stablehlo.broadcast_in_dim %arg1, dims = [0] : (tensor<1xf32>) -> tensor<8xf32> + %2 = stablehlo.add %0, %1 : tensor<8xf32> + // CHECK: %[[C:.*]] = "ttir.broadcast"[[C:.*]] + return %2 : tensor<8xf32> + } +} diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/ccl_ops.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/ccl_ops.mlir new file mode 100644 index 0000000000..5fbab794c6 --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/ccl_ops.mlir @@ -0,0 +1,83 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s + +// jax/pjrt sharding target 1x2 for n300 +module @jit_matmul_basic attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<8192x784xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<784x16384xf32> {mhlo.layout_mode = "default"}) -> (tensor<8192x16384xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) { + %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,2]<=[2]}"} : (tensor<8192x784xf32>) -> tensor<8192x784xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x784xf32>) -> tensor<8192x392xf32> + // CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]] + %2 = stablehlo.custom_call @Sharding(%arg1) {backend_config = "", mhlo.sharding = "{devices=[2,1]<=[2]}"} : (tensor<784x16384xf32>) -> tensor<784x16384xf32> + %3 = stablehlo.custom_call @SPMDFullToShardShape(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<784x16384xf32>) -> tensor<392x16384xf32> + // CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]] + %4 = call @shmap_body(%1, %3) : (tensor<8192x392xf32>, tensor<392x16384xf32>) -> tensor<8192x16384xf32> + %5 = stablehlo.custom_call @Sharding(%4) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x16384xf32>) -> tensor<8192x16384xf32> + %6 = stablehlo.custom_call @SPMDShardToFullShape(%5) {backend_config = "", mhlo.sharding = "{replicated}"} : (tensor<8192x16384xf32>) -> tensor<8192x16384xf32> + // CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]] + return %6 : tensor<8192x16384xf32> + } + func.func private @shmap_body(%arg0: tensor<8192x392xf32>, %arg1: tensor<392x16384xf32>) -> (tensor<8192x16384xf32> {jax.result_info = "[('x',), None]"}) { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<8192x392xf32>, tensor<392x16384xf32>) -> tensor<8192x16384xf32> + %1 = "stablehlo.all_reduce"(%0) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, use_global_device_ids}> ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %2 = stablehlo.add %arg2, %arg3 : tensor + stablehlo.return %2 : tensor + }) : (tensor<8192x16384xf32>) -> tensor<8192x16384xf32> + // CHECK: %[[C:.*]] = "ttir.all_reduce"[[C:.*]] + return %1 : tensor<8192x16384xf32> + } +} + +// jax/pjrt sharding target 2x4 for t3k +module @jit_matmul_basic2 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<8192x784xf32>, %arg1: tensor<784x16384xf32>) -> (tensor<8192x16384xf32> {jax.result_info = ""}) { + %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[2,4]<=[8]}"} : (tensor<8192x784xf32>) -> tensor<8192x784xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x784xf32>) -> tensor<4096x196xf32> + // CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]] + %2 = stablehlo.custom_call @Sharding(%arg1) {backend_config = "", mhlo.sharding = "{devices=[4,1,2]<=[2,4]T(1,0) last_tile_dim_replicate}"} : (tensor<784x16384xf32>) -> tensor<784x16384xf32> + %3 = stablehlo.custom_call @SPMDFullToShardShape(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<784x16384xf32>) -> tensor<196x16384xf32> + // CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]] + %4 = call @shmap_body(%1, %3) : (tensor<4096x196xf32>, tensor<196x16384xf32>) -> tensor<4096x16384xf32> + %5 = stablehlo.custom_call @Sharding(%4) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<4096x16384xf32>) -> tensor<4096x16384xf32> + %6 = stablehlo.custom_call @SPMDShardToFullShape(%5) {backend_config = "", mhlo.sharding = "{devices=[2,1,4]<=[8] last_tile_dim_replicate}"} : (tensor<4096x16384xf32>) -> tensor<8192x16384xf32> + // CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]] + return %6 : tensor<8192x16384xf32> + } + func.func private @shmap_body(%arg0: tensor<4096x196xf32>, %arg1: tensor<196x16384xf32>) -> (tensor<4096x16384xf32> {jax.result_info = "[('x',), None]"}) { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4096x196xf32>, tensor<196x16384xf32>) -> tensor<4096x16384xf32> + %1 = "stablehlo.all_reduce"(%0) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>, use_global_device_ids}> ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %2 = stablehlo.add %arg2, %arg3 : tensor + stablehlo.return %2 : tensor + }) : (tensor<4096x16384xf32>) -> tensor<4096x16384xf32> + // CHECK: %[[C:.*]] = "ttir.all_reduce"[[C:.*]] + return %1 : tensor<4096x16384xf32> + } +} + +// jax/pjrt sharding target 1x8 for t3k +module @jit_matmul_basic3 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<8192x784xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<784x16384xf32> {mhlo.layout_mode = "default"}) -> (tensor<8192x16384xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) { + %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,8]<=[8]}"} : (tensor<8192x784xf32>) -> tensor<8192x784xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x784xf32>) -> tensor<8192x98xf32> + // CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]] + %2 = stablehlo.custom_call @Sharding(%arg1) {backend_config = "", mhlo.sharding = "{devices=[8,1]<=[8]}"} : (tensor<784x16384xf32>) -> tensor<784x16384xf32> + %3 = stablehlo.custom_call @SPMDFullToShardShape(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<784x16384xf32>) -> tensor<98x16384xf32> + // CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]] + %4 = call @shmap_body(%1, %3) : (tensor<8192x98xf32>, tensor<98x16384xf32>) -> tensor<8192x16384xf32> + %5 = stablehlo.custom_call @Sharding(%4) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<8192x16384xf32>) -> tensor<8192x16384xf32> + %6 = stablehlo.custom_call @SPMDShardToFullShape(%5) {backend_config = "", mhlo.sharding = "{replicated}"} : (tensor<8192x16384xf32>) -> tensor<8192x16384xf32> + // CHECK: %[[C:.*]] = "ttir.mesh_shard"[[C:.*]] + return %6 : tensor<8192x16384xf32> + } + func.func private @shmap_body(%arg0: tensor<8192x98xf32>, %arg1: tensor<98x16384xf32>) -> (tensor<8192x16384xf32> {jax.result_info = "[('x',), None]"}) { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<8192x98xf32>, tensor<98x16384xf32>) -> tensor<8192x16384xf32> + %1 = "stablehlo.all_reduce"(%0) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, use_global_device_ids}> ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %2 = stablehlo.add %arg2, %arg3 : tensor + stablehlo.return %2 : tensor + }) : (tensor<8192x16384xf32>) -> tensor<8192x16384xf32> + // CHECK: %[[C:.*]] = "ttir.all_reduce"[[C:.*]] + return %1 : tensor<8192x16384xf32> + } +} diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/conv2d_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/conv2d_op.mlir index ce4a6f6565..0c41398cd1 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/conv2d_op.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/conv2d_op.mlir @@ -2,6 +2,8 @@ // RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s module @jit_convolution attributes {} { func.func public @test_convolution(%arg0: tensor<1x128x128x32xf32>, %arg1: tensor<64x32x3x3xf32>) -> tensor<1x128x128x64xf32> { + // CHECK: %[[C:.*]] = tensor.empty[[C:.*]] + // CHECK: %[[C:.*]] = "ttir.convolution"[[C:.*]] %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[o, i, 0, 1]->[b, 0, 1, f], window = { @@ -12,8 +14,26 @@ module @jit_convolution attributes {} { batch_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo] } : (tensor<1x128x128x32xf32>, tensor<64x32x3x3xf32>) -> tensor<1x128x128x64xf32> - // CHECK: %[[C:.*]] = tensor.empty[[C:.*]] - // CHECK: %[[C:.*]] = "ttir.convolution"[[C:.*]] return %0 : tensor<1x128x128x64xf32> } + + // Tests 1d convolution that gets translated to 2d. + func.func @test_convolution_1d(%arg0: tensor<1x256x512xf32>, %arg1: tensor<1024x256x1xf32>) -> tensor<1x1024x512xf32> { + // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]] + // CHECK: %1 = "ttir.convolution"(%arg0, %arg1, [[VAL0]]) + // CHECK: batch_group_count = 1 : i64, convolution_layout = #ttir, weight_dilation = array, window_reversal = array, window_strides = array + // CHECK: : (tensor<1x256x512xf32>, tensor<1024x256x1xf32>, [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + %0 = stablehlo.convolution(%arg0, %arg1) + dim_numbers = [b, f, 0]x[o, i, 0]->[b, f, 0], + window = { + stride = [1], + pad = [[0, 0]], + rhs_dilate = [1] + } { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<1x256x512xf32>, tensor<1024x256x1xf32>) -> tensor<1x1024x512xf32> + return %0 : tensor<1x1024x512xf32> + } } diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/dot_general_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/dot_general/dot_general_2d.mlir similarity index 100% rename from test/ttmlir/Conversion/StableHLOToTTIR/dot_general_op.mlir rename to test/ttmlir/Conversion/StableHLOToTTIR/dot_general/dot_general_2d.mlir diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/dot_general/dot_general_3d.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/dot_general/dot_general_3d.mlir new file mode 100644 index 0000000000..52e2d80016 --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/dot_general/dot_general_3d.mlir @@ -0,0 +1,10 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s +module { + func.func @main(%arg0: tensor<8x1x920xbf16>, %arg1: tensor<8x100x32xbf16>, %arg2: tensor<8x32x920xbf16>) -> tensor<8x100x920xbf16> { + %0 = stablehlo.broadcast_in_dim %arg2, dims = [0, 1, 2] : (tensor<8x32x920xbf16>) -> tensor<8x32x920xbf16> + // CHECK: %[[C:.*]] = "ttir.matmul"[[C:.*]] + %1 = stablehlo.dot_general %arg1, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<8x100x32xbf16>, tensor<8x32x920xbf16>) -> tensor<8x100x920xbf16> + return %1 : tensor<8x100x920xbf16> + } +} diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/dynamic_iota_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/dynamic_iota_op.mlir new file mode 100644 index 0000000000..43241ac6f0 --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/dynamic_iota_op.mlir @@ -0,0 +1,11 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s +#any_device = #tt.operand_constraint +module @jit_dnamic_iota attributes {} { + func.func public @test_dynamic_iota() -> tensor<1x32x128x128xf32> { + // CHECK: %[[C:.*]] = "ttir.arange"[[C:.*]] + %output_shape = stablehlo.constant dense<[1, 32, 128, 128]> : tensor<4xi64> + %0 = "stablehlo.dynamic_iota"(%output_shape) {iota_dimension = 1: i64} : (tensor<4xi64>) -> tensor<1x32x128x128xf32> + return %0 : tensor<1x32x128x128xf32> + } +} diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/gather_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/gather_op.mlir index ba29d123e8..e80bb75886 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/gather_op.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/gather_op.mlir @@ -8,6 +8,7 @@ module @jit_gather attributes {} { // CHECK: %[[C:.*]] = "ttir.gather"[[C:.*]] return %0 : tensor<1x32x1024xf32> } + func.func public @test_gather_1(%operand: tensor<448x384xf32>, %start_indices: tensor<1x2x1xi32>) -> tensor<1x2x384xf32> { %0 = "stablehlo.gather"(%operand, %start_indices) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<448x384xf32>, tensor<1x2x1xi32>) -> tensor<1x2x384xf32> // CHECK: %[[C:.*]] = tensor.empty[[C:.*]] @@ -22,4 +23,20 @@ module @jit_gather attributes {} { return %0 : tensor<1x2x384xf32> } + func.func public @test_gather_3(%arg0: tensor<32128x512xbf16>, %arg1: tensor<1x15xi64>) -> tensor<1x15x512xbf16> { + // CHECK: %[[EMPTY:[0-9]+]] = tensor.empty() : tensor<1x15x512xbf16> + // CHECK: %[[VAL:[0-9]+]] = "ttir.gather"(%arg0, %arg1, %[[EMPTY]]) + // CHECK-SAME: collapsed_slice_dims = array, + // CHECK-SAME: index_vector_dim = 2 : si64, + // CHECK-SAME: indices_are_sorted = false, + // CHECK-SAME: offset_dims = array, + // CHECK-SAME: operand_batching_dims = array, + // CHECK-SAME: slice_sizes = array, + // CHECK-SAME: start_index_map = array, + // CHECK-SAME: start_indices_batching_dims = array + // CHECK-SAME: (tensor<32128x512xbf16>, tensor<1x15xi32>, tensor<1x15x512xbf16>) -> tensor<1x15x512xbf16> + %0 = "stablehlo.gather"(%arg0, %arg1) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<32128x512xbf16>, tensor<1x15xi64>) -> tensor<1x15x512xbf16> + // CEHCK: return %[[VAL]] : tensor<1x15x512xbf16> + return %0 : tensor<1x15x512xbf16> + } } diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/iota_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/iota_op.mlir new file mode 100644 index 0000000000..857a621bb0 --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/iota_op.mlir @@ -0,0 +1,10 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s +#any_device = #tt.operand_constraint +module @jit_iota attributes {} { + func.func public @test_iota() -> tensor<1x32x128x128xf32> { + // CHECK: %[[C:.*]] = "ttir.arange"[[C:.*]] + %0 = "stablehlo.iota"() {iota_dimension = 1: i64} : () -> tensor<1x32x128x128xf32> + return %0 : tensor<1x32x128x128xf32> + } +} diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/scatter_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/scatter_op.mlir new file mode 100644 index 0000000000..92cd8895fd --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/scatter_op.mlir @@ -0,0 +1,16 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s +#any_device = #tt.operand_constraint +module @jit_scatter attributes {} { + func.func public @test_scatter(%arg0: tensor<1x3x320x320xf32>, %arg1: tensor<1x1xi64>, %arg2: tensor<1x3x32x32xf32>) -> tensor<1x3x320x320xf32> { + // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE1:tensor<[0-9]+x[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]] + %result = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg3: tensor, %arg4: tensor): + stablehlo.return %arg4 : tensor + }) : (tensor<1x3x320x320xf32>, tensor<1x1xi64>, tensor<1x3x32x32xf32>) -> tensor<1x3x320x320xf32> + // CHECK: [[VAL1:%[0-9]+]] = "ttir.scatter"(%arg0, %arg1, %arg2, [[VAL0]]) <{index_vector_dim = 1 : i32, indices_are_sorted = false, input_batching_dims = array, inserted_window_dims = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile], scatter_dims_to_operand_dims = array, scatter_indices_batching_dims = array, unique_indices = false, update_window_dims = array} + // CHECK: ([[TENSOR_SIZE1]], tensor<1x1xi32>, tensor<1x3x32x32xf32>, [[TENSOR_SIZE1]]) -> tensor<1x3x320x320xf32> + return %result : tensor<1x3x320x320xf32> + // CHECK: return [[VAL1]] : [[TENSOR_SIZE1]] + } +} diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/unary/log_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/unary/log_op.mlir new file mode 100644 index 0000000000..702bc155da --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/unary/log_op.mlir @@ -0,0 +1,10 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s +module @jit_eltwise_log attributes {} { + func.func public @test_log(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = stablehlo.log %arg0 : tensor<13x21x3xf32> + // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]] + // CHECK: [[VAL1:%[0-9]+]] = "ttir.log"(%arg0, [[VAL0]]) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile]}> : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + return %0 : tensor<13x21x3xf32> + } +} diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/unary/logit_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/unary/logit_op.mlir new file mode 100644 index 0000000000..48c64d12d4 --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/unary/logit_op.mlir @@ -0,0 +1,10 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s +module @jit_eltwise_logit attributes {} { + func.func public @test_logit(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = stablehlo.logistic %arg0 : tensor<13x21x3xf32> + // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]] + // CHECK: [[VAL1:%[0-9]+]] = "ttir.sigmoid"(%arg0, [[VAL0]]) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile]}> : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + return %0 : tensor<13x21x3xf32> + } +} diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/unary/tan_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/unary/tan_op.mlir new file mode 100644 index 0000000000..77b8f3b8bc --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/unary/tan_op.mlir @@ -0,0 +1,10 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s +module @jit_eltwise_tan attributes {} { + func.func public @test_tan(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = stablehlo.tan %arg0 : tensor<13x21x3xf32> + // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]] + // CHECK: [[VAL1:%[0-9]+]] = "ttir.tan"(%arg0, [[VAL0]]) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile]}> : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + return %0 : tensor<13x21x3xf32> + } +} diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/unary/tanh_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/unary/tanh_op.mlir new file mode 100644 index 0000000000..5d420c43c5 --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/unary/tanh_op.mlir @@ -0,0 +1,10 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s +module @jit_eltwise_tanh attributes {} { + func.func public @test_tanh(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = stablehlo.tanh %arg0 : tensor<13x21x3xf32> + // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]] + // CHECK: [[VAL1:%[0-9]+]] = "ttir.tanh"(%arg0, [[VAL0]]) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile]}> : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + return %0 : tensor<13x21x3xf32> + } +} diff --git a/test/ttmlir/Conversion/TosaToTTIR/clamp.mlir b/test/ttmlir/Conversion/TosaToTTIR/clamp.mlir new file mode 100644 index 0000000000..0444fbcffa --- /dev/null +++ b/test/ttmlir/Conversion/TosaToTTIR/clamp.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s +module attributes {} { + func.func @test_clamp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = tosa.clamp %arg0 { min_int = 2 : i64, max_int = 3 : i64, min_fp = 2.0 : f32, max_fp = 3.0 : f32 } : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + // CHECK: %[[OP_OUT:[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]] + // CHECK: %[[VAL:[0-9]+]] = "ttir.clamp"(%arg{{[0-9]+}}, %[[OP_OUT]]) + // CHECK-SAME: max = 3.000000e+00 : f32, min = 2.000000e+00 : f32{{.+}}: ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + return %0 : tensor<13x21x3xf32> + // CHECK: return %[[VAL]] : [[TENSOR_SIZE]] + } +} diff --git a/test/ttmlir/Conversion/TosaToTTIR/compare/equal.mlir b/test/ttmlir/Conversion/TosaToTTIR/compare/equal.mlir new file mode 100644 index 0000000000..20387a6f1a --- /dev/null +++ b/test/ttmlir/Conversion/TosaToTTIR/compare/equal.mlir @@ -0,0 +1,10 @@ +// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s +module attributes {} { + func.func @test_equal(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> { + %0 = tosa.equal %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1> + // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<13x21x3xi1>]] + // CHECK: [[VAL1:%[0-9]+]] = "ttir.eq"(%arg{{[0-9]+}}, %arg{{[0-9]+}}, [[VAL0]]){{.+}}: (tensor<13x21x3xf32>, tensor<13x21x3xf32>, [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + return %0 : tensor<13x21x3xi1> + // CHECK: return [[VAL1]] : [[TENSOR_SIZE]] + } +} diff --git a/test/ttmlir/Conversion/TosaToTTIR/compare/greater.mlir b/test/ttmlir/Conversion/TosaToTTIR/compare/greater.mlir new file mode 100644 index 0000000000..7487492997 --- /dev/null +++ b/test/ttmlir/Conversion/TosaToTTIR/compare/greater.mlir @@ -0,0 +1,10 @@ +// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s +module attributes {} { + func.func @test_greater(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> { + %0 = tosa.greater %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1> + // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<13x21x3xi1>]] + // CHECK: [[VAL1:%[0-9]+]] = "ttir.gt"(%arg{{[0-9]+}}, %arg{{[0-9]+}}, [[VAL0]]){{.+}}: (tensor<13x21x3xf32>, tensor<13x21x3xf32>, [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + return %0 : tensor<13x21x3xi1> + // CHECK: return [[VAL1]] : [[TENSOR_SIZE]] + } +} diff --git a/test/ttmlir/Conversion/TosaToTTIR/compare/greater_equal.mlir b/test/ttmlir/Conversion/TosaToTTIR/compare/greater_equal.mlir new file mode 100644 index 0000000000..479af38156 --- /dev/null +++ b/test/ttmlir/Conversion/TosaToTTIR/compare/greater_equal.mlir @@ -0,0 +1,10 @@ +// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s +module attributes {} { + func.func @test_greater_equal(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> { + %0 = tosa.greater_equal %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1> + // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<13x21x3xi1>]] + // CHECK: [[VAL1:%[0-9]+]] = "ttir.ge"(%arg{{[0-9]+}}, %arg{{[0-9]+}}, [[VAL0]]){{.+}}: (tensor<13x21x3xf32>, tensor<13x21x3xf32>, [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + return %0 : tensor<13x21x3xi1> + // CHECK: return [[VAL1]] : [[TENSOR_SIZE]] + } +} diff --git a/test/ttmlir/Conversion/TosaToTTIR/elementwise_binary/add.mlir b/test/ttmlir/Conversion/TosaToTTIR/elementwise_binary/add.mlir new file mode 100644 index 0000000000..b16e8e40ce --- /dev/null +++ b/test/ttmlir/Conversion/TosaToTTIR/elementwise_binary/add.mlir @@ -0,0 +1,10 @@ +// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s +module attributes {} { + func.func @test_add(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = tosa.add %arg0, %arg1 {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<13x21x3xf32>]] + // CHECK: [[VAL1:%[0-9]+]] = "ttir.add"(%arg{{[0-9]+}}, %arg{{[0-9]+}}, [[VAL0]]){{.+}}: ([[TENSOR_SIZE]], [[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + return %0 : tensor<13x21x3xf32> + // CHECK: return [[VAL1]] : [[TENSOR_SIZE]] + } +} diff --git a/test/ttmlir/Conversion/TosaToTTIR/elementwise_binary/maximum.mlir b/test/ttmlir/Conversion/TosaToTTIR/elementwise_binary/maximum.mlir new file mode 100644 index 0000000000..66691e2f07 --- /dev/null +++ b/test/ttmlir/Conversion/TosaToTTIR/elementwise_binary/maximum.mlir @@ -0,0 +1,10 @@ +// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s +module attributes {} { + func.func @test_maximum(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = tosa.maximum %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + // CHECK: %[[OP_OUT:[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]] + // CHECK: %[[VAL:[0-9]+]] = "ttir.maximum"(%arg{{[0-9]+}}, %arg{{[0-9]+}}, %[[OP_OUT]]){{.+}} : ([[TENSOR_SIZE]], [[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + // CHECK: return %[[VAL]] : [[TENSOR_SIZE]] + return %0 : tensor<13x21x3xf32> + } +} diff --git a/test/ttmlir/Conversion/TosaToTTIR/elementwise_binary/minimum.mlir b/test/ttmlir/Conversion/TosaToTTIR/elementwise_binary/minimum.mlir new file mode 100644 index 0000000000..7bfb100927 --- /dev/null +++ b/test/ttmlir/Conversion/TosaToTTIR/elementwise_binary/minimum.mlir @@ -0,0 +1,10 @@ +// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s +module attributes {} { + func.func @test_minimum(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = tosa.minimum %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + // CHECK: %[[OP_OUT:[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]] + // CHECK: %[[VAL:[0-9]+]] = "ttir.minimum"(%arg{{[0-9]+}}, %arg{{[0-9]+}}, %[[OP_OUT]]){{.+}} : ([[TENSOR_SIZE]], [[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + // CHECK: return %[[VAL]] : [[TENSOR_SIZE]] + return %0 : tensor<13x21x3xf32> + } +} diff --git a/test/ttmlir/Dialect/TTIR/tosa_to_ttir_multiply.mlir b/test/ttmlir/Conversion/TosaToTTIR/elementwise_binary/mul.mlir similarity index 53% rename from test/ttmlir/Dialect/TTIR/tosa_to_ttir_multiply.mlir rename to test/ttmlir/Conversion/TosaToTTIR/elementwise_binary/mul.mlir index fd35f0cd10..137939fcf8 100644 --- a/test/ttmlir/Dialect/TTIR/tosa_to_ttir_multiply.mlir +++ b/test/ttmlir/Conversion/TosaToTTIR/elementwise_binary/mul.mlir @@ -1,10 +1,10 @@ // RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s -#any_device = #tt.operand_constraint module attributes {} { func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %0 = tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> - // CHECK: %[[C:.*]] = tensor.empty[[C:.*]] - // CHECK: %[[C:.*]] = "ttir.multiply"[[C:.*]] + // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<13x21x3xf32>]] + // CHECK: [[VAL1:%[0-9]+]] = "ttir.multiply"(%arg{{[0-9]+}}, %arg{{[0-9]+}}, [[VAL0]]){{.+}}: ([[TENSOR_SIZE]], [[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] return %0 : tensor<13x21x3xf32> + // CHECK: return [[VAL1]] : [[TENSOR_SIZE]] } } diff --git a/test/ttmlir/Conversion/TosaToTTIR/elementwise_binary/sub.mlir b/test/ttmlir/Conversion/TosaToTTIR/elementwise_binary/sub.mlir new file mode 100644 index 0000000000..5f8f5bf849 --- /dev/null +++ b/test/ttmlir/Conversion/TosaToTTIR/elementwise_binary/sub.mlir @@ -0,0 +1,10 @@ +// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s +module attributes {} { + func.func @test_sub(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = tosa.sub %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + // CHECK: %[[OP_OUT:[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]] + // CHECK: %[[VAL:[0-9]+]] = "ttir.subtract"(%arg{{[0-9]+}}, %arg{{[0-9]+}}, %[[OP_OUT]]){{.+}} : ([[TENSOR_SIZE]], [[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + // CHECK: return %[[VAL]] : [[TENSOR_SIZE]] + return %0 : tensor<13x21x3xf32> + } +} diff --git a/test/ttmlir/Conversion/TosaToTTIR/elementwise_ternary/select.mlir b/test/ttmlir/Conversion/TosaToTTIR/elementwise_ternary/select.mlir new file mode 100644 index 0000000000..2e02be5ebf --- /dev/null +++ b/test/ttmlir/Conversion/TosaToTTIR/elementwise_ternary/select.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s +module attributes {} { + func.func @test_select(%arg0: tensor<32x128xi1>, %arg1: tensor<32x128xf32>, %arg2: tensor<32x128xf32>) -> tensor<32x128xf32> { + // CHECK: func.func {{.+}} [[SELECTOR:tensor<[0-9]+x[0-9]+xi1>]] + %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<32x128xi1>, tensor<32x128xf32>, tensor<32x128xf32>) -> tensor<32x128xf32> + // CHECK: %[[OP_OUT:[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+xf[0-9]+>]] + // CHECK: %[[VAL:[0-9]+]] = "ttir.where"(%arg{{[0-9]+}}, %arg{{[0-9]+}}, %arg{{[0-9]+}}, %[[OP_OUT]]){{.+}} : ([[SELECTOR]], [[TENSOR_SIZE]], [[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + // CHECK: return %[[VAL]] : [[TENSOR_SIZE]] + return %0 : tensor<32x128xf32> + } +} diff --git a/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/abs.mlir b/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/abs.mlir new file mode 100644 index 0000000000..9df5a2828b --- /dev/null +++ b/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/abs.mlir @@ -0,0 +1,10 @@ +// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s +module attributes {} { + func.func @test_abs(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = tosa.abs %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<13x21x3xf32>]] + // CHECK: [[VAL1:%[0-9]+]] = "ttir.abs"(%arg{{[0-9]+}}, [[VAL0]]){{.+}}: ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + return %0 : tensor<13x21x3xf32> + // CHECK: return [[VAL1]] : [[TENSOR_SIZE]] + } +} diff --git a/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/cast.mlir b/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/cast.mlir new file mode 100644 index 0000000000..4ee3a742b6 --- /dev/null +++ b/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/cast.mlir @@ -0,0 +1,10 @@ +// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s +module attributes {} { + func.func @test_cast(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xbf16> { + %0 = tosa.cast %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xbf16> + // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<13x21x3xbf16>]] + // CHECK: [[VAL1:%[0-9]+]] = "ttir.typecast"(%arg{{[0-9]+}}, [[VAL0]]){{.+}}: (tensor<13x21x3xf32>, [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + return %0 : tensor<13x21x3xbf16> + // CHECK: return [[VAL1]] : [[TENSOR_SIZE]] + } +} diff --git a/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/ceil.mlir b/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/ceil.mlir new file mode 100644 index 0000000000..77dc60dc30 --- /dev/null +++ b/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/ceil.mlir @@ -0,0 +1,10 @@ +// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s +module attributes {} { + func.func @test_ceil(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = tosa.ceil %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<13x21x3xf32>]] + // CHECK: [[VAL1:%[0-9]+]] = "ttir.ceil"(%arg{{[0-9]+}}, [[VAL0]]){{.+}}: ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + return %0 : tensor<13x21x3xf32> + // CHECK: return [[VAL1]] : [[TENSOR_SIZE]] + } +} diff --git a/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/cos.mlir b/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/cos.mlir new file mode 100644 index 0000000000..1a8aafd6b0 --- /dev/null +++ b/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/cos.mlir @@ -0,0 +1,10 @@ +// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s +module attributes {} { + func.func @test_cos(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = tosa.cos %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<13x21x3xf32>]] + // CHECK: [[VAL1:%[0-9]+]] = "ttir.cos"(%arg{{[0-9]+}}, [[VAL0]]){{.+}}: ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + return %0 : tensor<13x21x3xf32> + // CHECK: return [[VAL1]] : [[TENSOR_SIZE]] + } +} diff --git a/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/exp.mlir b/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/exp.mlir new file mode 100644 index 0000000000..9575640211 --- /dev/null +++ b/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/exp.mlir @@ -0,0 +1,10 @@ +// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s +module attributes {} { + func.func @test_exp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = tosa.exp %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<13x21x3xf32>]] + // CHECK: [[VAL1:%[0-9]+]] = "ttir.exp"(%arg{{[0-9]+}}, [[VAL0]]){{.+}}: ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + return %0 : tensor<13x21x3xf32> + // CHECK: return [[VAL1]] : [[TENSOR_SIZE]] + } +} diff --git a/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/floor.mlir b/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/floor.mlir new file mode 100644 index 0000000000..4653bfd3ee --- /dev/null +++ b/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/floor.mlir @@ -0,0 +1,10 @@ +// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s +module attributes {} { + func.func @test_floor(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = tosa.floor %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<13x21x3xf32>]] + // CHECK: [[VAL1:%[0-9]+]] = "ttir.floor"(%arg{{[0-9]+}}, [[VAL0]]){{.+}}: ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + return %0 : tensor<13x21x3xf32> + // CHECK: return [[VAL1]] : [[TENSOR_SIZE]] + } +} diff --git a/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/negate.mlir b/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/negate.mlir new file mode 100644 index 0000000000..d1c294a848 --- /dev/null +++ b/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/negate.mlir @@ -0,0 +1,10 @@ +// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s +module attributes {} { + func.func @test_negate(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = tosa.negate %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + // CHECK: %[[OP_OUT:[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]] + // CHECK: %[[VAL:[0-9]+]] = "ttir.neg"(%arg{{[0-9]+}}, %[[OP_OUT]]){{.+}} : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + // CHECK: return %[[VAL]] : [[TENSOR_SIZE]] + return %0 : tensor<13x21x3xf32> + } +} diff --git a/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/reciprocal.mlir b/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/reciprocal.mlir new file mode 100644 index 0000000000..ee3251eb63 --- /dev/null +++ b/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/reciprocal.mlir @@ -0,0 +1,10 @@ +// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s +module attributes {} { + func.func @test_reciprocal(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = tosa.reciprocal %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + // CHECK: %[[OP_OUT:[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]] + // CHECK: %[[VAL:[0-9]+]] = "ttir.reciprocal"(%arg{{[0-9]+}}, %[[OP_OUT]]){{.+}} : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + // CHECK: return %[[VAL]] : [[TENSOR_SIZE]] + return %0 : tensor<13x21x3xf32> + } +} diff --git a/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/rsqrt.mlir b/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/rsqrt.mlir new file mode 100644 index 0000000000..2475ffacd5 --- /dev/null +++ b/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/rsqrt.mlir @@ -0,0 +1,10 @@ +// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s +module attributes {} { + func.func @test_rsqrt(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = tosa.rsqrt %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + // CHECK: %[[OP_OUT:[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]] + // CHECK: %[[VAL:[0-9]+]] = "ttir.rsqrt"(%arg{{[0-9]+}}, %[[OP_OUT]]){{.+}} : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + // CHECK: return %[[VAL]] : [[TENSOR_SIZE]] + return %0 : tensor<13x21x3xf32> + } +} diff --git a/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/sigmoid.mlir b/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/sigmoid.mlir new file mode 100644 index 0000000000..18453f71aa --- /dev/null +++ b/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/sigmoid.mlir @@ -0,0 +1,10 @@ +// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s +module attributes {} { + func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = tosa.sigmoid %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + // CHECK: %[[OP_OUT:[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]] + // CHECK: %[[VAL:[0-9]+]] = "ttir.sigmoid"(%arg{{[0-9]+}}, %[[OP_OUT]]){{.+}} : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + // CHECK: return %[[VAL]] : [[TENSOR_SIZE]] + return %0 : tensor<13x21x3xf32> + } +} diff --git a/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/sin.mlir b/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/sin.mlir new file mode 100644 index 0000000000..017e9f366c --- /dev/null +++ b/test/ttmlir/Conversion/TosaToTTIR/elementwise_unary/sin.mlir @@ -0,0 +1,10 @@ +// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s +module attributes {} { + func.func @test_sin(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = tosa.sin %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + // CHECK: %[[OP_OUT:[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]] + // CHECK: %[[VAL:[0-9]+]] = "ttir.sin"(%arg{{[0-9]+}}, %[[OP_OUT]]){{.+}} : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + // CHECK: return %[[VAL]] : [[TENSOR_SIZE]] + return %0 : tensor<13x21x3xf32> + } +} diff --git a/test/ttmlir/Conversion/TosaToTTIR/logical/logical_and.mlir b/test/ttmlir/Conversion/TosaToTTIR/logical/logical_and.mlir new file mode 100644 index 0000000000..adab66f3a2 --- /dev/null +++ b/test/ttmlir/Conversion/TosaToTTIR/logical/logical_and.mlir @@ -0,0 +1,10 @@ +// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s +module attributes {} { + func.func @test_logical_and(%arg0: tensor<13x21x3xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> { + %0 = tosa.logical_and %arg0, %arg1 : (tensor<13x21x3xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1> + // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<13x21x3xi1>]] + // CHECK: [[VAL1:%[0-9]+]] = "ttir.logical_and"(%arg{{[0-9]+}}, %arg{{[0-9]+}}, [[VAL0]]){{.+}}: ([[TENSOR_SIZE]], [[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + return %0 : tensor<13x21x3xi1> + // CHECK: return [[VAL1]] : [[TENSOR_SIZE]] + } +} diff --git a/test/ttmlir/Conversion/TosaToTTIR/logical/logical_not.mlir b/test/ttmlir/Conversion/TosaToTTIR/logical/logical_not.mlir new file mode 100644 index 0000000000..ca74f1ab91 --- /dev/null +++ b/test/ttmlir/Conversion/TosaToTTIR/logical/logical_not.mlir @@ -0,0 +1,10 @@ +// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s +module attributes {} { + func.func @test_logical_not(%arg0: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> { + %0 = tosa.logical_not %arg0 : (tensor<13x21x3xi1>) -> tensor<13x21x3xi1> + // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<13x21x3xi1>]] + // CHECK: [[VAL1:%[0-9]+]] = "ttir.logical_not"(%arg{{[0-9]+}}, [[VAL0]]){{.+}}: ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + return %0 : tensor<13x21x3xi1> + // CHECK: return [[VAL1]] : [[TENSOR_SIZE]] + } +} diff --git a/test/ttmlir/Conversion/TosaToTTIR/logical/logical_or.mlir b/test/ttmlir/Conversion/TosaToTTIR/logical/logical_or.mlir new file mode 100644 index 0000000000..4a4ab6eaef --- /dev/null +++ b/test/ttmlir/Conversion/TosaToTTIR/logical/logical_or.mlir @@ -0,0 +1,10 @@ +// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s +module attributes {} { + func.func @test_logical_or(%arg0: tensor<13x21x3xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> { + %0 = tosa.logical_or %arg0, %arg1 : (tensor<13x21x3xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1> + // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<13x21x3xi1>]] + // CHECK: [[VAL1:%[0-9]+]] = "ttir.logical_or"(%arg{{[0-9]+}}, %arg{{[0-9]+}}, [[VAL0]]){{.+}}: ([[TENSOR_SIZE]], [[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + return %0 : tensor<13x21x3xi1> + // CHECK: return [[VAL1]] : [[TENSOR_SIZE]] + } +} diff --git a/test/ttmlir/Conversion/TosaToTTIR/logical/logical_xor.mlir b/test/ttmlir/Conversion/TosaToTTIR/logical/logical_xor.mlir new file mode 100644 index 0000000000..6492691566 --- /dev/null +++ b/test/ttmlir/Conversion/TosaToTTIR/logical/logical_xor.mlir @@ -0,0 +1,10 @@ +// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s +module attributes {} { + func.func @test_logical_xor(%arg0: tensor<13x21x3xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> { + %0 = tosa.logical_xor %arg0, %arg1 : (tensor<13x21x3xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1> + // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<13x21x3xi1>]] + // CHECK: [[VAL1:%[0-9]+]] = "ttir.logical_xor"(%arg{{[0-9]+}}, %arg{{[0-9]+}}, [[VAL0]]){{.+}}([[TENSOR_SIZE]], [[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + return %0 : tensor<13x21x3xi1> + // CHECK: return [[VAL1]] : [[TENSOR_SIZE]] + } +} diff --git a/test/ttmlir/Conversion/TosaToTTIR/matmul_op.mlir b/test/ttmlir/Conversion/TosaToTTIR/matmul_op.mlir new file mode 100644 index 0000000000..5e12ee0e0b --- /dev/null +++ b/test/ttmlir/Conversion/TosaToTTIR/matmul_op.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s +module attributes {} { + func.func @test_matmul(%arg0: tensor<13x21x16xf32>, %arg1: tensor<13x16x31xf32>) -> tensor<13x21x31xf32> { + // CHECK: func.func {{.+}}%arg{{[0-9]+}}: tensor<[[B:[0-9]+]]x[[I:[0-9]+]]x[[J:[0-9]+]]xf32>, %arg{{[0-9]+}}: tensor<[[B:[0-9]+]]x[[J:[0-9]+]]x[[K:[0-9]+]]xf32> + %0 = tosa.matmul %arg0, %arg1 : (tensor<13x21x16xf32>, tensor<13x16x31xf32>) -> tensor<13x21x31xf32> + // CHECK: %[[OP_OUT:[0-9]+]] = tensor.empty() : tensor<[[B]]x[[I]]x[[K]]xf32> + // CHECK: %[[VAL:[0-9]+]] = "ttir.matmul"(%arg{{[0-9]+}}, %arg{{[0-9]+}}, %[[OP_OUT]]){{.+}} (tensor<[[B]]x[[I]]x[[J]]xf32>, tensor<[[B]]x[[J]]x[[K]]xf32>, tensor<[[B]]x[[I]]x[[K]]xf32>) -> tensor<[[B]]x[[I]]x[[K]]xf32> + // CHECK: return %[[VAL]] : tensor<[[B]]x[[I]]x[[K]]xf32> + return %0 : tensor<13x21x31xf32> + } +} diff --git a/test/ttmlir/Conversion/TosaToTTIR/maxpool2d_op.mlir b/test/ttmlir/Conversion/TosaToTTIR/maxpool2d_op.mlir new file mode 100644 index 0000000000..ff1ef5b4f6 --- /dev/null +++ b/test/ttmlir/Conversion/TosaToTTIR/maxpool2d_op.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s +module attributes {} { + func.func @test_maxpool(%arg0: tensor<32x800x600x6xf32>) -> tensor<32x400x300x6xf32> { + // CHECK: func.func {{.+}} [[IN_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+x[0-9]+xf32>]]{{.*}} -> + %1 = tosa.max_pool2d %arg0 {kernel = array, pad = array, stride = array} : (tensor<32x800x600x6xf32>) -> tensor<32x400x300x6xf32> + // CHECK: %[[OP_OUT:[0-9]+]] = tensor.empty() : [[OUT_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+x[0-9]+xf32>]] + // CHECK: %[[VAL:[0-9]+]] = "ttir.max_pool2d"(%arg{{[0-9]+}}, %[[OP_OUT]]){{.+}} ([[IN_SIZE]], [[OUT_SIZE]]) -> [[OUT_SIZE]] + // CHECK: return %[[VAL]] : [[OUT_SIZE]] + return %1 : tensor<32x400x300x6xf32> + } +} diff --git a/test/ttmlir/Conversion/TosaToTTIR/reductions/max.mlir b/test/ttmlir/Conversion/TosaToTTIR/reductions/max.mlir new file mode 100644 index 0000000000..021d5fd08f --- /dev/null +++ b/test/ttmlir/Conversion/TosaToTTIR/reductions/max.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s +module attributes {} { + func.func @test_max(%arg0: tensor<13x21x3xf32>) -> tensor<13x1x3xf32> { + // CHECK: func.func {{.+}} [[IN_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf32>]]{{.*}} -> + %0 = tosa.reduce_max %arg0 {axis = 1 : i32} : (tensor<13x21x3xf32>) -> tensor<13x1x3xf32> + // CHECK: %[[OP_OUT:[0-9]+]] = tensor.empty() : [[OUT_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf32>]] + // CHECK: %[[VAL:[0-9]+]] = "ttir.max"(%arg{{[0-9]+}}, %[[OP_OUT]]){{.+}} ([[IN_SIZE]], [[OUT_SIZE]]) -> [[OUT_SIZE]] + // CHECK: return %[[VAL]] : [[OUT_SIZE]] + return %0 : tensor<13x1x3xf32> + } +} diff --git a/test/ttmlir/Conversion/TosaToTTIR/reductions/sum.mlir b/test/ttmlir/Conversion/TosaToTTIR/reductions/sum.mlir new file mode 100644 index 0000000000..80f2045914 --- /dev/null +++ b/test/ttmlir/Conversion/TosaToTTIR/reductions/sum.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s +module attributes {} { + func.func @test_sum(%arg0: tensor<13x21x3xf32>) -> tensor<13x1x3xf32> { + // CHECK: func.func {{.+}} [[IN_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf32>]]{{.*}} -> + %0 = tosa.reduce_sum %arg0 {axis = 1 : i32} : (tensor<13x21x3xf32>) -> tensor<13x1x3xf32> + // CHECK: %[[OP_OUT:[0-9]+]] = tensor.empty() : [[OUT_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf32>]] + // CHECK: %[[VAL:[0-9]+]] = "ttir.sum"(%arg{{[0-9]+}}, %[[OP_OUT]]){{.+}} ([[IN_SIZE]], [[OUT_SIZE]]) -> [[OUT_SIZE]] + // CHECK: return %[[VAL]] : [[OUT_SIZE]] + return %0 : tensor<13x1x3xf32> + } +} diff --git a/test/ttmlir/Dialect/TTIR/Decomposition/arange_decomposition.mlir b/test/ttmlir/Dialect/TTIR/Decomposition/arange_decomposition.mlir new file mode 100644 index 0000000000..6f72e56f17 --- /dev/null +++ b/test/ttmlir/Dialect/TTIR/Decomposition/arange_decomposition.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --ttir-to-ttir-decomposition %s | FileCheck %s +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32> { + // CHECK: %[[C:.*]] = "ttir.arange"[[C:.*]] + // CHECK: %[[C:.*]] = "ttir.transpose"[[C:.*]] + // CHECK: %[[C:.*]] = "ttir.broadcast"[[C:.*]] + %1 = "ttir.arange"() <{start = 0: si64, end = 32: si64, step = 1: si64, arange_dimension = 1: i64}> : () -> tensor<1x32x128x128xf32> + return %1 : tensor<1x32x128x128xf32> + } +} diff --git a/test/ttmlir/Dialect/TTIR/Decomposition/select_decomposition_tests.mlir b/test/ttmlir/Dialect/TTIR/Decomposition/select_decomposition_tests.mlir new file mode 100644 index 0000000000..8365bbddd3 --- /dev/null +++ b/test/ttmlir/Dialect/TTIR/Decomposition/select_decomposition_tests.mlir @@ -0,0 +1,26 @@ +// RUN: ttmlir-opt --ttir-to-ttir-decomposition %s | FileCheck %s + +#any_device_tile = #tt.operand_constraint +module attributes {} { + func.func @select_identity(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = tensor.empty() : tensor<4x4xf32> + // CHECK: %{{[0-9]+}} = "ttir.slice" + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 4: si32, stride = 4: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + return %1 : tensor<4x4xf32> + } + + func.func @select_multi_slice(%arg0: tensor<4x2x64x128xf32>) -> tensor<4x2x64x32xf32> { + %0 = tensor.empty() : tensor<4x2x64x32xf32> + + // CHECK: %{{[0-9]+}} = "ttir.slice" + // CHECK: %{{[0-9]+}} = "ttir.slice" + // CHECK: %{{[0-9]+}} = "ttir.slice" + // CHECK: %{{[0-9]+}} = "ttir.slice" + // CHECK: %{{[0-9]+}} = "ttir.concat" + %1 = "ttir.select"(%arg0, %0) <{dim = -1: si32, begin = 0: si32, length = 4: si32, stride = 16: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<4x2x64x128xf32>, tensor<4x2x64x32xf32>) -> tensor<4x2x64x32xf32> + + return %1 : tensor<4x2x64x32xf32> + } +} diff --git a/test/ttmlir/Dialect/TTIR/linear/linear_tests_negative.mlir b/test/ttmlir/Dialect/TTIR/linear/linear_tests_negative.mlir new file mode 100644 index 0000000000..522628160c --- /dev/null +++ b/test/ttmlir/Dialect/TTIR/linear/linear_tests_negative.mlir @@ -0,0 +1,194 @@ +// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s +// Negative tests for linear operation + +// Verify that the parsing fails if either of operands is a scalar +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_1d_1d_scalar_a(%arg0: tensor, %arg1: tensor<64xbf16>) -> tensor<1xbf16> { + // CHECK: error: 'ttir.linear' op Input A must be at least a 1D tensor + %0 = tensor.empty() : tensor<1xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor, tensor<64xbf16>, tensor<1xbf16>) -> tensor<1xbf16> + return %1 : tensor<1xbf16> + } +} + +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_1d_1d_scalar_b(%arg0: tensor<128xbf16>, %arg1: tensor) -> tensor<1xbf16> { + // CHECK: error: 'ttir.linear' op Input B must be at least a 1D tensor + %0 = tensor.empty() : tensor<1xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor, tensor<1xbf16>) -> tensor<1xbf16> + return %1 : tensor<1xbf16> + } +} + +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_1d_1d_scalar_bias(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>, %bias: tensor) -> tensor<1xbf16> { + // CHECK: error: 'ttir.linear' op Bias must be at least a 1D tensor + %0 = tensor.empty() : tensor<1xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<128xbf16>, tensor, tensor<1xbf16>) -> tensor<1xbf16> + return %1 : tensor<1xbf16> + } +} + +// Verifty that the parsing fails if the output is a scalar +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_1d_1d_scalar_output(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor { + // CHECK: error: 'ttir.linear' op Scalar output is not supported, output must be at least a 1D tensor + %0 = tensor.empty() : tensor + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<128xbf16>, tensor) -> tensor + return %1 : tensor + } +} + +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_1d_1d_output_dimension_mismatch(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor<2xbf16> { + // CHECK: error: 'ttir.linear' op Scalar output must be a 1D tensor of size 1 + %0 = tensor.empty() : tensor<2xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<128xbf16>, tensor<2xbf16>) -> tensor<2xbf16> + return %1 : tensor<2xbf16> + } +} + +// Inner dimension mismatch tests +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_1d_1d_inner_dimension_mismatch(%arg0: tensor<128xbf16>, %arg1: tensor<64xbf16>) -> tensor<1xbf16> { + // CHECK: error: 'ttir.linear' op Input A[-1](128) and B[-2](64) must have matching inner dimensions + %0 = tensor.empty() : tensor<1xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<64xbf16>, tensor<1xbf16>) -> tensor<1xbf16> + return %1 : tensor<1xbf16> + } +} + +// ----- +#any_device_tile = #tt.operand_constraint +module { +func.func @linear_negative_1d_2d_inner_dimension_mismatch(%arg0: tensor<64xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64xbf16> { + // CHECK: error: 'ttir.linear' op Input A[-1](64) and B[-2](128) must have matching inner dimensions + %0 = tensor.empty() : tensor<64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64xbf16>, tensor<128x64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> + return %1 : tensor<64xbf16> + } +} + +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_2d_1d_inner_dimension_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<64xbf16>) -> tensor<64xbf16> { + // CHECK: error: 'ttir.linear' op Input A[-1](128) and B[-2](64) must have matching inner dimensions + %0 = tensor.empty() : tensor<64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> + return %1 : tensor<64xbf16> + } +} + +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_2d_2d_inner_dimension_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<64x64xbf16> { + // CHECK: error: 'ttir.linear' op Input A[-1](128) and B[-2](64) must have matching inner dimensions + %0 = tensor.empty() : tensor<64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> + } +} + +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_nd_nd_inner_dimension_mismatch(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<1x64x128xbf16>) -> tensor<7x64x64xbf16> { + // CHECK: error: 'ttir.linear' op Input A[-1](128) and B[-2](64) must have matching inner dimensions + %0 = tensor.empty() : tensor<7x64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<7x64x128xbf16>, tensor<1x64x128xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> + return %1 : tensor<7x64x64xbf16> + } +} + +// Batch dimension mismatch tests +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_nd_nd_same_rank_batch_broadcast_incompatible_1(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<2x128x64xbf16>) -> tensor<7x64x64xbf16> { + // CHECK: error: 'ttir.linear' op Batch dimensions of input A(7) and B(2) are not broadcast compatible + %0 = tensor.empty() : tensor<7x64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<7x64x128xbf16>, tensor<2x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> + return %1 : tensor<7x64x64xbf16> + } +} + +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_nd_nd_same_rank_batch_broadcast_incompatible_2(%arg0: tensor<2x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<7x7x64x64xbf16> { + // CHECK: error: 'ttir.linear' op Batch dimensions of input A(2,7) and B(7,1) are not broadcast compatible + %0 = tensor.empty() : tensor<7x64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<2x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x7x64x64xbf16> + return %1 : tensor<7x7x64x64xbf16> + } +} + +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_nd_nd_different_rank_batch_broadcast_incompatible(%arg0: tensor<12x2x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<12x7x7x64x64xbf16> { + // CHECK: error: 'ttir.linear' op Batch dimensions of input A(12,2,7) and B(7,1) are not broadcast compatible + %0 = tensor.empty() : tensor<12x7x7x64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<12x2x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<12x7x7x64x64xbf16>) -> tensor<12x7x7x64x64xbf16> + return %1 : tensor<12x7x7x64x64xbf16> + } +} + +// Bias shape mismatch tests +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_matmul_bias_broadcast_incompatible(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<2x64xbf16>) -> tensor<64x64xbf16> { + // CHECK: error: 'ttir.linear' op Bias shape(2,64) is not broadcast compatible with the matmul output shape(64,64) + %0 = tensor.empty() : tensor<64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<2x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> + } +} + +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_nd_nd_matmul_bias_broadcast_incompatible(%arg0: tensor<3x64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<2x64x64xbf16>) -> tensor<3x64x64xbf16> { + // CHECK: error: 'ttir.linear' op Bias shape(2,64,64) is not broadcast compatible with the matmul output shape(3,64,64) + %0 = tensor.empty() : tensor<3x64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<3x64x128xbf16>, tensor<128x64xbf16>, tensor<2x64x64xbf16>, tensor<3x64x64xbf16>) -> tensor<3x64x64xbf16> + return %1 : tensor<3x64x64xbf16> + } +} + +// Output shape mismatch tests +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_2d_2d_output_shape_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64xbf16> { + // CHECK: error: 'ttir.linear' op Output shape rank(1) must match the expected output shape rank(2) + %0 = tensor.empty() : tensor<64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> + return %1 : tensor<64xbf16> + } +} + +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_2d_2d_output_shape_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x128xbf16> { + // CHECK: error: 'ttir.linear' op Output shape dimension[1](128) doesn't match the expected output shape dimension[1](64) + %0 = tensor.empty() : tensor<64x128xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + return %1 : tensor<64x128xbf16> + } +} diff --git a/test/ttmlir/Dialect/TTIR/select/select_tests_negative.mlir b/test/ttmlir/Dialect/TTIR/select/select_tests_negative.mlir new file mode 100644 index 0000000000..f505bfcb73 --- /dev/null +++ b/test/ttmlir/Dialect/TTIR/select/select_tests_negative.mlir @@ -0,0 +1,116 @@ +// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s + +#any_device_tile = #tt.operand_constraint +module attributes {} { + func.func @select_negative_invalid_dim(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = tensor.empty() : tensor<4x4xf32> + // CHECK: {{.*error.*Invalid dimension}} + %1 = "ttir.select"(%arg0, %0) <{dim = -3: si32, begin = 0: si32, length = 4: si32, stride = 4: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + return %1 : tensor<4x4xf32> + } +} + +// ----- + +#any_device_tile = #tt.operand_constraint +module attributes {} { + func.func @select_negative_invalid_stride(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = tensor.empty() : tensor<4x4xf32> + // CHECK: {{.*error.*Invalid stride.*}} + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 4: si32, stride = 7: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + return %1 : tensor<4x4xf32> + } +} + +// ----- + +#any_device_tile = #tt.operand_constraint +module attributes {} { + func.func @select_negative_invalid_stride_2(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = tensor.empty() : tensor<4x4xf32> + // CHECK: {{.*error.*Invalid stride.*}} + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 4: si32, stride = -1: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + return %1 : tensor<4x4xf32> + } +} + +// ----- + +#any_device_tile = #tt.operand_constraint +module attributes {} { + func.func @select_negative_invalid_begin(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = tensor.empty() : tensor<4x4xf32> + // CHECK: {{.*error.*Invalid begin index.*}} + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = -3: si32, length = 4: si32, stride = 1: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + return %1 : tensor<4x4xf32> + } +} + +// ----- + +#any_device_tile = #tt.operand_constraint +module attributes {} { + func.func @select_negative_invalid_begin_2(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = tensor.empty() : tensor<4x4xf32> + // CHECK: {{.*error.*Invalid begin index.*}} + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 4: si32, length = 4: si32, stride = 1: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + return %1 : tensor<4x4xf32> + } +} + +// ----- + +#any_device_tile = #tt.operand_constraint +module attributes {} { + func.func @select_negative_invalid_length(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = tensor.empty() : tensor<4x4xf32> + // CHECK: {{.*error.*Invalid length.*}} + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 5: si32, stride = 1: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + return %1 : tensor<4x4xf32> + } +} + +// ----- + +#any_device_tile = #tt.operand_constraint +module attributes {} { + func.func @select_negative_invalid_length_2(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = tensor.empty() : tensor<4x4xf32> + // CHECK: {{.*error.*Invalid length.*}} + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 0: si32, stride = 1: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + return %1 : tensor<4x4xf32> + } +} + +// ----- + +#any_device_tile = #tt.operand_constraint +module attributes {} { + func.func @select_negative_invalid_length_3(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = tensor.empty() : tensor<4x4xf32> + // CHECK: {{.*error.*Invalid length.*}} + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 2: si32, stride = 1: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + return %1 : tensor<4x4xf32> + } +} + +// ----- + +#any_device_tile = #tt.operand_constraint +module attributes {} { + func.func @select_negative_invalid_total_size(%arg0: tensor<4x2x64x48xf32>) -> tensor<4x2x4x48xf32> { + %0 = tensor.empty() : tensor<4x2x4x48xf32> + // CHECK: {{.*error.*Sum of all slices.*}} + %1 = "ttir.select"( %arg0, %0) <{dim = 2: si32, begin = 0: si32, length = 4: si32, stride = 4: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<4x2x64x48xf32>, tensor<4x2x4x48xf32>) -> tensor<4x2x4x48xf32> + return %1 : tensor<4x2x4x48xf32> + } +} diff --git a/test/ttmlir/Dialect/TTIR/select/select_tests_positive.mlir b/test/ttmlir/Dialect/TTIR/select/select_tests_positive.mlir new file mode 100644 index 0000000000..b613c85bf8 --- /dev/null +++ b/test/ttmlir/Dialect/TTIR/select/select_tests_positive.mlir @@ -0,0 +1,44 @@ +// RUN: ttmlir-opt %s | FileCheck %s + +#any_device_tile = #tt.operand_constraint +module attributes {} { + func.func @select_identity(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = tensor.empty() : tensor<4x4xf32> + // CHECK: %{{[0-9]+}} = "ttir.select" + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 4: si32, stride = 4: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + return %1 : tensor<4x4xf32> + } + + func.func @select_half(%arg0: tensor<4x4xf32>) -> tensor<4x2xf32> { + %0 = tensor.empty() : tensor<4x2xf32> + // CHECK: %{{[0-9]+}} = "ttir.select" + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 2: si32, stride = 4: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<4x4xf32>, tensor<4x2xf32>) -> tensor<4x2xf32> + return %1 : tensor<4x2xf32> + } + + func.func @select_single(%arg0: tensor<4x4xf32>) -> tensor<4x1xf32> { + %0 = tensor.empty() : tensor<4x1xf32> + // CHECK: %{{[0-9]+}} = "ttir.select" + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 3: si32, length = 1: si32, stride = 1: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<4x4xf32>, tensor<4x1xf32>) -> tensor<4x1xf32> + return %1 : tensor<4x1xf32> + } + + func.func @select_half_2_no_stride(%arg0: tensor<4x4xf32>) -> tensor<4x2xf32> { + %0 = tensor.empty() : tensor<4x2xf32> + // CHECK: %{{[0-9]+}} = "ttir.select" + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 2: si32, length = 2: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<4x4xf32>, tensor<4x2xf32>) -> tensor<4x2xf32> + return %1 : tensor<4x2xf32> + } + + func.func @select_neg_dim(%arg0: tensor<10x3x128x64xf32>) -> tensor<10x3x8x64xf32> { + %0 = tensor.empty() : tensor<10x3x8x64xf32> + // CHECK: %{{[0-9]+}} = "ttir.select" + %1 = "ttir.select"(%arg0, %0) <{dim = -2: si32, begin = 0: si32, length = 2: si32, stride = 32: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + (tensor<10x3x128x64xf32>, tensor<10x3x8x64xf32>) -> tensor<10x3x8x64xf32> + return %1 : tensor<10x3x8x64xf32> + } +} diff --git a/test/ttmlir/Dialect/TTIR/split_compound_layout.mlir b/test/ttmlir/Dialect/TTIR/split_compound_layout.mlir index 2335fb0df3..42cab3d1f6 100644 --- a/test/ttmlir/Dialect/TTIR/split_compound_layout.mlir +++ b/test/ttmlir/Dialect/TTIR/split_compound_layout.mlir @@ -3,21 +3,21 @@ #dram = #tt.memory_space #l1_ = #tt.memory_space -// CHECK-DAG: #[[row_major1x1:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> -// CHECK-DAG: #[[row_major1x1_T:.*]] = #tt.layout<(d0, d1) -> (d1, d0), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> -// CHECK-DAG: #[[row_major2x2:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<32x64xf32, #l1_>, interleaved> -// CHECK-DAG: #[[tile1x1_f32:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, f32>, #l1_>, interleaved> -// CHECK-DAG: #[[tile1x1_bf16:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #l1_>, interleaved> -// CHECK-DAG: #[[tile1x1_f32_dram:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, f32>, #dram>, interleaved> -// CHECK-DAG: #[[tile2x2_f32:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<1x2x!tt.tile<32x32, f32>, #l1_>, interleaved> +// CHECK-DAG: #[[row_major1x1:.*]] = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> +// CHECK-DAG: #[[row_major1x1_T:.*]] = #tt.metal_layout<(d0, d1) -> (d1, d0), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> +// CHECK-DAG: #[[row_major2x2:.*]] = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<32x64xf32, #l1_>, interleaved> +// CHECK-DAG: #[[tile1x1_f32:.*]] = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, f32>, #l1_>, interleaved> +// CHECK-DAG: #[[tile1x1_bf16:.*]] = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #l1_>, interleaved> +// CHECK-DAG: #[[tile1x1_f32_dram:.*]] = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, f32>, #dram>, interleaved> +// CHECK-DAG: #[[tile2x2_f32:.*]] = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<1x2x!tt.tile<32x32, f32>, #l1_>, interleaved> -#row_major1x1 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> -#row_major1x1_T = #tt.layout<(d0, d1) -> (d1, d0), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> -#row_major2x2 = #tt.layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<32x64xf32, #l1_>, interleaved> -#tile1x1_f32 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, f32>, #l1_>, interleaved> -#tile1x1_bf16 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #l1_>, interleaved> -#tile1x1_f32_dram = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, f32>, #dram>, interleaved> -#tile2x2_f32 = #tt.layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<1x2x!tt.tile<32x32, f32>, #l1_>, interleaved> +#row_major1x1 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> +#row_major1x1_T = #tt.metal_layout<(d0, d1) -> (d1, d0), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> +#row_major2x2 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<32x64xf32, #l1_>, interleaved> +#tile1x1_f32 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, f32>, #l1_>, interleaved> +#tile1x1_bf16 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #l1_>, interleaved> +#tile1x1_f32_dram = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, f32>, #dram>, interleaved> +#tile2x2_f32 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<1x2x!tt.tile<32x32, f32>, #l1_>, interleaved> func.func @noncompound_linear(%in: tensor<64x128xf32, #row_major1x1>) -> tensor<64x128xf32, #row_major1x1_T> { %out = tensor.empty() : tensor<64x128xf32, #row_major1x1_T> diff --git a/test/ttmlir/Dialect/TTIR/test_allocate.mlir b/test/ttmlir/Dialect/TTIR/test_allocate.mlir index a80a8c1c91..5888cf3f62 100644 --- a/test/ttmlir/Dialect/TTIR/test_allocate.mlir +++ b/test/ttmlir/Dialect/TTIR/test_allocate.mlir @@ -1,7 +1,7 @@ // RUN: ttmlir-opt --ttir-load-system-desc --ttir-implicit-device --ttir-allocate %s | FileCheck %s #any_device = #tt.operand_constraint #l1_ = #tt.memory_space -#layout = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> +#layout = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> module attributes {} { func.func @forward(%arg0: tensor<64x128xf32, #layout>, %arg1: tensor<64x128xf32, #layout>) -> tensor<64x128xf32, #layout> { // CHECK: %[[C:.*]] = "ttir.alloc"[[C:.*]] diff --git a/test/ttmlir/Dialect/TTIR/test_remove_dead_values_pass.mlir b/test/ttmlir/Dialect/TTIR/test_remove_dead_values_pass.mlir new file mode 100644 index 0000000000..8b6df4d0f2 --- /dev/null +++ b/test/ttmlir/Dialect/TTIR/test_remove_dead_values_pass.mlir @@ -0,0 +1,22 @@ +// RUN: ttmlir-opt --remove-dead-values %s | FileCheck %s +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: %[[C:.*]] = "ttir.multiply"[[C:.*]] + %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %2 = tensor.empty() : tensor<64x128xf32> + // CHECK-NOT: %[[C:.*]] = "ttir.add"[[C:.*]] + %3 = "ttir.add"(%arg0, %arg1, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %4 = tensor.empty() : tensor<64x128xf32> + // CHECK-NOT: %[[C:.*]] = "ttir.subtract"[[C:.*]] + %5 = "ttir.subtract"(%arg0, %arg1, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %6 = tensor.empty() : tensor<64x128xf32> + // CHECK-NOT: %[[C:.*]] = "ttir.div"[[C:.*]] + %7 = "ttir.div"(%arg0, %arg1, %6) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %8 = tensor.empty() : tensor<64x128xf32> + // CHECK-NOT: %[[C:.*]] = "ttir.eq"[[C:.*]] + %9 = "ttir.eq"(%arg0, %arg1, %8) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + return %1 : tensor<64x128xf32> + } +} diff --git a/test/ttmlir/Dialect/TTIR/ttir_broadcastable_negative.mlir b/test/ttmlir/Dialect/TTIR/ttir_broadcastable_negative.mlir new file mode 100644 index 0000000000..e1454ad0a0 --- /dev/null +++ b/test/ttmlir/Dialect/TTIR/ttir_broadcastable_negative.mlir @@ -0,0 +1,28 @@ +// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s +// Negative tests for Broadcastable interface + +// CHECK: 'ttir.abs' op Result shape must match operand shapes after broadcasting +#any_device_tile = #tt.operand_constraint +func.func @eltwise_unary(%arg0: tensor<1x64xbf16>) -> tensor<2x64xbf16> { + %0 = tensor.empty() : tensor<2x64xbf16> + %1 = "ttir.abs"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<1x64xbf16>, tensor<2x64xbf16>) -> tensor<2x64xbf16> + return %1 : tensor<2x64xbf16> +} + +// ----- +// CHECK: error: 'ttir.add' op Result shape must match operand shapes after broadcasting +#any_device_tile = #tt.operand_constraint +func.func @eltwise_binary(%arg0: tensor<2x3x64xf32>, %arg1: tensor<64xf32>) -> tensor<4x2x3x64xf32> { + %0 = tensor.empty() : tensor<4x2x3x64xf32> + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<2x3x64xf32>, tensor<64xf32>, tensor<4x2x3x64xf32>) -> tensor<4x2x3x64xf32> + return %1 : tensor<4x2x3x64xf32> +} + +// ----- +// CHECK: error: 'ttir.where' op Result shape must match operand shapes after broadcasting +#any_device_tile = #tt.operand_constraint +func.func @eltwise_ternary(%arg0: tensor<3x64xf32>, %arg1: tensor<1x3x64xf32>, %arg2: tensor<2x1x64xf32>) -> tensor<1x2x3x64xf32> { + %0 = tensor.empty() : tensor<1x2x3x64xf32> + %1 = "ttir.where"(%arg0, %arg1, %arg2, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<3x64xf32>, tensor<1x3x64xf32>, tensor<2x1x64xf32>, tensor<1x2x3x64xf32>) -> tensor<1x2x3x64xf32> + return %1 : tensor<1x2x3x64xf32> +} diff --git a/test/ttmlir/Dialect/TTIR/ttir_noperands_negative.mlir b/test/ttmlir/Dialect/TTIR/ttir_noperands_negative.mlir new file mode 100644 index 0000000000..a22dc28370 --- /dev/null +++ b/test/ttmlir/Dialect/TTIR/ttir_noperands_negative.mlir @@ -0,0 +1,37 @@ +// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s +// Negative tests for NOperands trait + +// CHECK: error: 'ttir.abs' op expected 2 operands, but found 3 +#any_device_tile = #tt.operand_constraint +func.func @eltwise_unary(%arg0: tensor<64x64xbf16>) -> tensor<64x64xbf16> { + %0 = tensor.empty() : tensor<64x64xbf16> + %1 = "ttir.abs"(%arg0, %arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> +} + +// ----- +// CHECK: error: 'ttir.add' op expected 3 operands, but found 4 +#any_device_tile = #tt.operand_constraint +func.func @eltwise_binary(%arg0: tensor<64x64xf32>, %arg1: tensor<64x64xf32>) -> tensor<64x64xf32> { + %0 = tensor.empty() : tensor<64x64xf32> + %1 = "ttir.add"(%arg0, %arg1, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32> + return %1 : tensor<64x64xf32> +} + +// ----- +// CHECK: error: 'ttir.add' op expected 3 operands, but found 2 +#any_device_tile = #tt.operand_constraint +func.func @eltwise_binary(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32> { + %0 = tensor.empty() : tensor<64x64xf32> + %1 = "ttir.add"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32> + return %1 : tensor<64x64xf32> +} + +// ----- +// CHECK: error: 'ttir.where' op expected 4 operands, but found 5 +#any_device_tile = #tt.operand_constraint +func.func @eltwise_ternary(%arg0: tensor<64x64xf32>, %arg1: tensor<64x64xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> { + %0 = tensor.empty() : tensor<64x64xf32> + %1 = "ttir.where"(%arg0, %arg1, %arg2, %arg2, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32> + return %1 : tensor<64x64xf32> +} diff --git a/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/simple_workaround.mlir b/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/simple_workaround.mlir new file mode 100644 index 0000000000..e08ffcd405 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/simple_workaround.mlir @@ -0,0 +1,31 @@ +// RUN: ttmlir-opt --ttnn-workaround %s | FileCheck %s +#device = #tt.device (0, d0, d1)>, l1Map = (d0, d1)[s0, s1] -> (0, d0 floordiv s0, d1 floordiv s1, (d0 mod s0) * s1 + d1 mod s1), dramMap = (d0, d1)[s0, s1] -> (0, 0, ((((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 8192) mod 12, (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 98304 + (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) mod 8192), meshShape = , chipIds = [0]> +#dram = #ttnn.buffer_type +#system_memory = #ttnn.buffer_type +#ttnn_layout = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xf32, #system_memory>> +#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xf32, #dram>, > +#ttnn_layout2 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x4x!tt.tile<32x32, f32>, #dram>, > +module attributes {tt.device = #device} { + func.func @forward(%arg0: tensor<64x128xf32, #ttnn_layout>) -> tensor<64x128xf32, #ttnn_layout> { + %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> + // CHECK: %[[DEVICE_OP:.*]] = "ttnn.get_device"[[C:.*]] + %1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<2x4>>, >}> : (tensor<64x128xf32, #ttnn_layout>, !tt.device<#device>) -> tensor<64x128xf32, #ttnn_layout1> + // CHECK-NEXT: %[[RM_DEVICE_LAYOUT_OP:.*]] = "ttnn.to_layout"(%arg0, %[[DEVICE_OP]]) + // CHECK-SAME: layout = #ttnn.layout + // CHECK-SAME: -> tensor<64x128xf32, #ttnn_layout1> + %2 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<64x128>>, >, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xf32, #ttnn_layout2> + // CHECK-NEXT: %[[EMPTY_OP:.*]] = "ttnn.empty"(%[[DEVICE_OP]]) + // CHECK-SAME: layout = #ttnn.layout + // CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<64x128>>, > + // CHECK-SAME: -> tensor<64x128xf32, #ttnn_layout1> + %3 = "ttnn.abs"(%1, %2) <{operandSegmentSizes = array}> : (tensor<64x128xf32, #ttnn_layout1>, tensor<64x128xf32, #ttnn_layout2>) -> tensor<64x128xf32, #ttnn_layout2> + // CHECK-NEXT: %[[TO_LAYOUT_LEFT:.*]] = "ttnn.to_layout"(%[[RM_DEVICE_LAYOUT_OP]], %[[DEVICE_OP]]) + // CHECK-SAME: layout = #ttnn.layout + // CHECK-SAME: -> tensor<64x128xf32, #ttnn_layout2> + // CHECK-NEXT: %[[TO_LAYOUT_RIGHT:.*]] = "ttnn.to_layout"(%[[EMPTY_OP]], %[[DEVICE_OP]]) + // CHECK-SAME: layout = #ttnn.layout + // CHECK-SAME: -> tensor<64x128xf32, #ttnn_layout2> + %4 = "ttnn.to_layout"(%3) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#system_memory, <<64x128>>>}> : (tensor<64x128xf32, #ttnn_layout2>) -> tensor<64x128xf32, #ttnn_layout> + return %4 : tensor<64x128xf32, #ttnn_layout> + } +} diff --git a/test/ttmlir/Dialect/TTNN/Transforms/ttnn_create_input_gens_0.mlir b/test/ttmlir/Dialect/TTNN/Transforms/ttnn_create_input_gens_0.mlir new file mode 100644 index 0000000000..8342c4f5a6 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/Transforms/ttnn_create_input_gens_0.mlir @@ -0,0 +1,36 @@ +// RUN: ttmlir-opt --ttnn-create-input-gens %s | FileCheck %s + +#device = #tt.device (0, d0, d1)>, l1Map = (d0, d1)[s0, s1] -> (0, d0 floordiv s0, d1 floordiv s1, (d0 mod s0) * s1 + d1 mod s1), dramMap = (d0, d1)[s0, s1] -> (0, 0, ((((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 8192) mod 12, (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 98304 + (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) mod 8192), meshShape = , chipIds = [0]> +#dram = #ttnn.buffer_type +#system_desc = #tt.system_desc<[{role = host, target_triple = "x86_64-pc-linux"}], [{arch = , grid = 8x8, l1_size = 1499136, num_dram_channels = 12, dram_channel_size = 1073741824, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32, l1_unreserved_base = 98816, erisc_l1_unreserved_base = 102624, dram_unreserved_base = 32, dram_unreserved_end = 1073083040, physical_cores = {worker = [ 1x1, 1x2, 1x3, 1x4, 1x6, 1x7, 1x8, 1x9, 2x1, 2x2, 2x3, 2x4, 2x6, 2x7, 2x8, 2x9, 3x1, 3x2, 3x3, 3x4, 3x6, 3x7, 3x8, 3x9, 4x1, 4x2, 4x3, 4x4, 4x6, 4x7, 4x8, 4x9, 5x1, 5x2, 5x3, 5x4, 5x6, 5x7, 5x8, 5x9, 7x1, 7x2, 7x3, 7x4, 7x6, 7x7, 7x8, 7x9, 8x1, 8x2, 8x3, 8x4, 8x6, 8x7, 8x8, 8x9, 9x1, 9x2, 9x3, 9x4, 9x6, 9x7, 9x8, 9x9] dram = [ 1x0, 1x5, 2x5, 3x5, 5x0, 5x5, 7x0, 7x5, 8x5, 9x5, 11x0, 11x5] eth_inactive = [ 0x1, 0x2, 0x3, 0x4, 0x6, 0x7, 0x8, 0x9, 6x2, 6x3, 6x6, 6x7, 6x8]}, supported_data_types = [, , , , , , , , , , , ], supported_tile_sizes = [ 4x16, 16x16, 32x16, 4x32, 16x32, 32x32], num_cbs = 32}], [0], [3 : i32], [ 0x0x0x0]> +#system_memory = #ttnn.buffer_type +#ttnn_layout = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<32x32xbf16, #system_memory>> +#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<1x1x!tt.tile<32x32, bf16>, #dram>, > +module attributes {tt.device = #device, tt.system_desc = #system_desc} { + // CHECK: func.func @add(%arg0: [[TENSOR_A:.*]], %arg1: [[TENSOR_B:.*]]) -> [[TENSOR_OUT:.*]] { + func.func @add(%arg0: tensor<32x32xbf16, #ttnn_layout>, %arg1: tensor<32x32xbf16, #ttnn_layout>) -> tensor<32x32xbf16, #ttnn_layout> { + %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> + %1 = "ttnn.to_device"(%arg0, %0) <{memory_config = #ttnn.memory_config<#dram, <<1x1>>, >}> : (tensor<32x32xbf16, #ttnn_layout>, !tt.device<#device>) -> tensor<32x32xbf16, #ttnn_layout1> + %2 = "ttnn.to_layout"(%1) <{layout = #ttnn.layout}> : (tensor<32x32xbf16, #ttnn_layout1>) -> tensor<32x32xbf16, #ttnn_layout1> + %3 = "ttnn.to_device"(%arg1, %0) <{memory_config = #ttnn.memory_config<#dram, <<1x1>>, >}> : (tensor<32x32xbf16, #ttnn_layout>, !tt.device<#device>) -> tensor<32x32xbf16, #ttnn_layout1> + %4 = "ttnn.to_layout"(%3) <{layout = #ttnn.layout}> : (tensor<32x32xbf16, #ttnn_layout1>) -> tensor<32x32xbf16, #ttnn_layout1> + %5 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<1x1>>, >, shape = #ttnn.shape<32x32>}> : (!tt.device<#device>) -> tensor<32x32xbf16, #ttnn_layout1> + %6 = "ttnn.add"(%2, %4, %5) <{operandSegmentSizes = array}> : (tensor<32x32xbf16, #ttnn_layout1>, tensor<32x32xbf16, #ttnn_layout1>, tensor<32x32xbf16, #ttnn_layout1>) -> tensor<32x32xbf16, #ttnn_layout1> + %7 = "ttnn.from_device"(%6) : (tensor<32x32xbf16, #ttnn_layout1>) -> tensor<32x32xbf16, #ttnn_layout> + %8 = "ttnn.to_layout"(%7) <{layout = #ttnn.layout}> : (tensor<32x32xbf16, #ttnn_layout>) -> tensor<32x32xbf16, #ttnn_layout> + return %8 : tensor<32x32xbf16, #ttnn_layout> + } + +// Confirm that the generator func is generated, and that the tensor attrs match: +// +// CHECK: func.func @createInputsFor_add() -> ([[TENSOR_A]], [[TENSOR_B]]) { +// CHECK: {{.*}} -> [[TENSOR_A]] +// CHECK: {{.*}} -> [[TENSOR_B]] +// CHECK: return %0, %1 : [[TENSOR_A]], [[TENSOR_B]] + +// Confirm that the main func is generated, and that the tensor attrs match: +// +// CHECK: func.func @main() -> i32 { +// CHECK: %0:2 = call @createInputsFor_add() : () -> ([[TENSOR_A]], [[TENSOR_B]]) +// CHECK: %1 = call @add(%0#0, %0#1) : ([[TENSOR_A]], [[TENSOR_B]]) -> [[TENSOR_OUT]] +} diff --git a/test/ttmlir/Dialect/TTNN/arange/arange_tests_negative.mlir b/test/ttmlir/Dialect/TTNN/arange/arange_tests_negative.mlir new file mode 100644 index 0000000000..dc3f09fbaf --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/arange/arange_tests_negative.mlir @@ -0,0 +1,12 @@ +// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s +// Negative tests for matmul operation +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32> { + // CHECK: error: 'ttir.arange' op Output tensor shape must be 16 at dim 1 (since start=0, end=32, step=2), but got 32 + %1 = "ttir.arange"() <{start = 0: si64, end = 32: si64, step = 2: si64, arange_dimension = 1: i64}> : () -> tensor<1x32x128x128xf32> + %dps = tensor.empty() : tensor<1x32x128x128xf32> + %2 = "ttir.multiply"(%arg0, %1, %dps) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x128x128xf32>, tensor<1x32x128x128xf32>, tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32> + return %2 : tensor<1x32x128x128xf32> + } +} diff --git a/test/ttmlir/Dialect/TTNN/arange/arange_tests_positive.mlir b/test/ttmlir/Dialect/TTNN/arange/arange_tests_positive.mlir new file mode 100644 index 0000000000..945b6da5b3 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/arange/arange_tests_positive.mlir @@ -0,0 +1,13 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s +// UNSUPPORTED: true +// https://github.com/tenstorrent/tt-mlir/issues/1448 +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32> { + // CHECK: %[[C:.*]] = "ttnn.arange"[[C:.*]] + %1 = "ttir.arange"() <{start = 0: si64, end = 32: si64, step = 1: si64, arange_dimension = 1: i64}> : () -> tensor<1x32x128x128xf32> + %dps = tensor.empty() : tensor<1x32x128x128xf32> + %2 = "ttir.multiply"(%arg0, %1, %dps) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x128x128xf32>, tensor<1x32x128x128xf32>, tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32> + return %2 : tensor<1x32x128x128xf32> + } +} diff --git a/test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir b/test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir index f1f5a5965c..cb2a7ad2b3 100644 --- a/test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir +++ b/test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir @@ -2,7 +2,6 @@ #any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<1x1x32x32xbf16>) -> tensor<1x1x32x128xbf16> { - // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<1x1x32x128xbf16> // CHECK: %[[C:.*]] = "ttnn.all_gather"[[C:.*]] %1 = "ttir.all_gather"(%arg0, %0) <{dim = 3 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x1x32x32xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> diff --git a/test/ttmlir/Dialect/TTNN/convolution/simple_conv1d.mlir b/test/ttmlir/Dialect/TTNN/convolution/simple_conv1d.mlir new file mode 100644 index 0000000000..8f75362a02 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/convolution/simple_conv1d.mlir @@ -0,0 +1,17 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s +#any_device_tile = #tt.operand_constraint +module { + func.func @main(%arg0: tensor<1x256x512xf32>, %arg1: tensor<1024x256x1xf32>, %arg2: tensor<1024xf32>) -> tensor<1x1024x512xf32> { + %0 = tensor.empty() : tensor<1x1024x512xf32> + // CHECK: [[VAL0:%[0-9]+]] = "ttnn.reshape"(%{{.*}}) <{shape = [1 : i32, 256 : i32, 512 : i32, 1 : i32]}> : (tensor<[[TENSOR_SHAPE0:[0-9]+x[0-9]+x[0-9]+xf32]], #{{.*}}) -> tensor<[[TENSOR_SHAPE1:[0-9]+x[0-9]+x[0-9]+x1xf32]], #{{.*}}> + // CHECK: [[VAL1:%[0-9]+]] = "ttnn.reshape"(%{{.*}}) <{shape = [1024 : i32, 256 : i32, 1 : i32, 1 : i32]}> : (tensor<[[TENSOR_SHAPE2:[0-9]+x[0-9]+x[0-9]+xf32]], #{{.*}}>) -> tensor<[[TENSOR_SHAPE3:[0-9]+x[0-9]+x[0-9]+x1xf32]], #{{.*}}> + // CHECK: [[VAL2:%[0-9]+]] = "ttnn.transpose"([[VAL0]]) <{dim0 = 1 : si32, dim1 = 2 : si32}> : (tensor<[[TENSOR_SHAPE1]], #{{.*}}>) -> tensor<[[TENSOR_SHAPE4:[0-9]+x[0-9]+x[0-9]+x1xf32]], #{{.*}}> + // CHECK: [[VAL3:%[0-9]+]] = "ttnn.transpose"([[VAL2]]) <{dim0 = 2 : si32, dim1 = 3 : si32}> : (tensor<[[TENSOR_SHAPE4]], #{{.*}}>) -> tensor<[[TENSOR_SHAPE5:[0-9]+x[0-9]+x[0-9]+x[0-9]+xf32]], #{{.*}}> + // CHECK: [[VAL4:%[0-9]+]] = "ttnn.reshape"([[VAL3]]) <{shape = [1 : i32, 1 : i32, 512 : i32, 256 : i32]}> : (tensor<[[TENSOR_SHAPE5]], #{{.*}}>) -> tensor<[[TENSOR_SHAPE6:[0-9]+x[0-9]+x[0-9]+x[0-9]+xf32]], #{{.*}}> + // CHECK: [[VAL5:%[0-9]+]] = "ttnn.conv2d"([[VAL4]], %10, %{{[0-9]+}}, %{{[0-9]+}}) + // CHECK: (tensor<[[TENSOR_SHAPE6]], #{{.*}}>, tensor<1024x256x1x1xf32, #{{.*}}>, tensor<1x1x512x1024xf32, #{{.*}}>, !tt.device<#device>) -> tensor<1x1x512x1024xf32, #{{.*}}> + %1 = "ttir.convolution"(%arg0, %arg1, %0) <{batch_group_count = 1 : i64, convolution_layout = #ttir, feature_group_count = 1 : i64, input_dilation = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile], padding = array, weight_dilation = array, window_reversal = array, window_strides = array}> : (tensor<1x256x512xf32>, tensor<1024x256x1xf32>, tensor<1x1024x512xf32>) -> tensor<1x1024x512xf32> + // CHECK: return %{{.*}} : tensor<1x1024x512xf32, #ttnn_layout3> + return %1 : tensor<1x1024x512xf32> + } +} diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/isfinite/simple_isfinite.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/isfinite/simple_isfinite.mlir index e819e68f4b..3089da6692 100644 --- a/test/ttmlir/Dialect/TTNN/eltwise/unary/isfinite/simple_isfinite.mlir +++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/isfinite/simple_isfinite.mlir @@ -1,15 +1,15 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s #any_device = #tt.operand_constraint module attributes {} { - func.func @is_finite(%arg0: tensor<64x128xf32>) -> tensor<64x128xbf16> { + func.func @is_finite(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> { // CHECK: %[[C:.*]] = "ttnn.empty" // CHECK-SAME: [[TENSOR:tensor<64x128xbf16,]] %0 = tensor.empty() : tensor<64x128xbf16> // CHECK: %[[C:.*]] = "ttnn.isfinite" - // CHECK-SAME: tensor<64x128xf32, + // CHECK-SAME: tensor<64x128xbf16, // CHECK-SAME: [[TENSOR]] // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.isfinite"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + %1 = "ttir.isfinite"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> return %1 : tensor<64x128xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/relu/simple_relu.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/relu/simple_relu.mlir index ce6887e2a8..1d75b8ee02 100644 --- a/test/ttmlir/Dialect/TTNN/eltwise/unary/relu/simple_relu.mlir +++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/relu/simple_relu.mlir @@ -4,7 +4,7 @@ #system = #ttnn.buffer_type #ttnn_layout = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xf32, #system>> #ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8>, memref<8x16xf32, #system>> -#ttnn_layout2 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8>, memref<8x16xf32, #l1>, interleaved> +#ttnn_layout2 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8>, memref<8x16xf32, #l1>, > module attributes {} { func.func @forward(%arg0: tensor<64x128xf32, #ttnn_layout>) -> tensor<64x128xf32, #ttnn_layout1> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/tan/simple_tan.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/tan/simple_tan.mlir new file mode 100644 index 0000000000..8ae9f0bec1 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/tan/simple_tan.mlir @@ -0,0 +1,12 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}> + %1 = "ttir.tan"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + // CHECK: %{{[0-9]+}} = "ttnn.tan"(%{{[0-9]+}}, [[VAL0]]) <{operandSegmentSizes = array}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}> + return %1 : tensor<64x128xf32> + // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}> + } +} diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/tanh/simple_tanh.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/tanh/simple_tanh.mlir new file mode 100644 index 0000000000..351476448a --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/tanh/simple_tanh.mlir @@ -0,0 +1,12 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}> + %1 = "ttir.tanh"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + // CHECK: %{{[0-9]+}} = "ttnn.tanh"(%{{[0-9]+}}, [[VAL0]]) <{operandSegmentSizes = array}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}> + return %1 : tensor<64x128xf32> + // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}> + } +} diff --git a/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding.mlir b/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding.mlir index dfbf99008d..6404ee6e94 100644 --- a/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding.mlir +++ b/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding.mlir @@ -1,9 +1,9 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s #any_device = #tt.operand_constraint module attributes {} { - func.func @gather_0(%operand: tensor<32000x1024xf32>, %start_indices: tensor<1x32xi32>) -> tensor<1x32x1024xf32> { + func.func @gather_0(%operand: tensor<32000x1024xbf16>, %start_indices: tensor<1x32xi32>) -> tensor<1x32x1024xbf16> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] - %0 = tensor.empty() : tensor<1x32x1024xf32> + %0 = tensor.empty() : tensor<1x32x1024xbf16> // CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]] %1 = "ttir.gather"(%operand, %start_indices, %0) { offset_dims = array, @@ -15,13 +15,13 @@ module attributes {} { slice_sizes = array, indices_are_sorted = false, operand_constraints = [#any_device, #any_device, #any_device] - } : (tensor<32000x1024xf32>, tensor<1x32xi32>, tensor<1x32x1024xf32>) -> tensor<1x32x1024xf32> - return %1 : tensor<1x32x1024xf32> + } : (tensor<32000x1024xbf16>, tensor<1x32xi32>, tensor<1x32x1024xbf16>) -> tensor<1x32x1024xbf16> + return %1 : tensor<1x32x1024xbf16> } - func.func @gather_1(%operand: tensor<448x384xf32>, %start_indices: tensor<1x2x1xi32>) -> tensor<1x2x384xf32> { + func.func @gather_1(%operand: tensor<448x384xbf16>, %start_indices: tensor<1x2x1xi32>) -> tensor<1x2x384xbf16> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] - %0 = tensor.empty() : tensor<1x2x384xf32> + %0 = tensor.empty() : tensor<1x2x384xbf16> // CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]] %1 = "ttir.gather"(%operand, %start_indices, %0) <{ offset_dims = array, @@ -33,13 +33,13 @@ module attributes {} { slice_sizes = array, indices_are_sorted = false, operand_constraints = [#any_device, #any_device, #any_device] - }> : (tensor<448x384xf32>, tensor<1x2x1xi32>, tensor<1x2x384xf32>) -> tensor<1x2x384xf32> - return %1 : tensor<1x2x384xf32> + }> : (tensor<448x384xbf16>, tensor<1x2x1xi32>, tensor<1x2x384xbf16>) -> tensor<1x2x384xbf16> + return %1 : tensor<1x2x384xbf16> } - func.func @gather_2(%operand: tensor<51864x384xf32>, %start_indices: tensor<1x2xi32>) -> tensor<1x2x384xf32> { + func.func @gather_2(%operand: tensor<51864x384xbf16>, %start_indices: tensor<1x2xi32>) -> tensor<1x2x384xbf16> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] - %0 = tensor.empty() : tensor<1x2x384xf32> + %0 = tensor.empty() : tensor<1x2x384xbf16> // CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]] %1 = "ttir.gather"(%operand, %start_indices, %0) <{ offset_dims = array, @@ -51,7 +51,7 @@ module attributes {} { slice_sizes = array, indices_are_sorted = false, operand_constraints = [#any_device, #any_device, #any_device] - }> : (tensor<51864x384xf32>, tensor<1x2xi32>, tensor<1x2x384xf32>) -> tensor<1x2x384xf32> - return %1 : tensor<1x2x384xf32> + }> : (tensor<51864x384xbf16>, tensor<1x2xi32>, tensor<1x2x384xbf16>) -> tensor<1x2x384xbf16> + return %1 : tensor<1x2x384xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding_negative.mlir b/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding_negative.mlir index 2a06bf92b6..44ffea73ef 100644 --- a/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding_negative.mlir +++ b/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding_negative.mlir @@ -110,3 +110,25 @@ module attributes {} { return %1 : tensor<1x2x384xf32> } } + +// Verify that the parsing fails for data type other than bfloat16. +// ----- +#any_device = #tt.operand_constraint +module attributes {} { + func.func @gather_0(%operand: tensor<32000x1024xf32>, %start_indices: tensor<1x32xi32>) -> tensor<1x32x1024xf32> { + %0 = tensor.empty() : tensor<1x32x1024xf32> + // CHECK: error: failed to legalize operation 'ttir.gather' that was explicitly marked illegal + %1 = "ttir.gather"(%operand, %start_indices, %0) { + offset_dims = array, + collapsed_slice_dims = array, + operand_batching_dims = array, + start_indices_batching_dims = array, + start_index_map = array, + index_vector_dim = 1 : si64, + slice_sizes = array, + indices_are_sorted = false, + operand_constraints = [#any_device, #any_device, #any_device] + } : (tensor<32000x1024xf32>, tensor<1x32xi32>, tensor<1x32x1024xf32>) -> tensor<1x32x1024xf32> + return %1 : tensor<1x32x1024xf32> + } +} diff --git a/test/ttmlir/Dialect/TTNN/linear/linear_tests_positive.mlir b/test/ttmlir/Dialect/TTNN/linear/linear_tests_positive.mlir new file mode 100644 index 0000000000..0e248623da --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/linear/linear_tests_positive.mlir @@ -0,0 +1,216 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_1d_1d(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor<1xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<1xbf16 + %0 = tensor.empty() : tensor<1xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<128xbf16 + // CHECK-SAME: tensor<128xbf16 + // CHECK-SAME: tensor<1xbf16 + // CHECK-SAME: tensor<1xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<128xbf16>, tensor<1xbf16>) -> tensor<1xbf16> + return %1 : tensor<1xbf16> + } + + func.func @linear_1d_1d_bias(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>, %bias: tensor<1xbf16>) -> tensor<1xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<1xbf16 + %0 = tensor.empty() : tensor<1xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<128xbf16 + // CHECK-SAME: tensor<128xbf16 + // CHECK-SAME: tensor<1xbf16 + // CHECK-SAME: tensor<1xbf16 + // CHECK-SAME: tensor<1xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<128xbf16>, tensor<1xbf16>, tensor<1xbf16>) -> tensor<1xbf16> + return %1 : tensor<1xbf16> + } + + func.func @linear_1d_1d_bias_broadcast(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>, %bias: tensor<128xbf16>) -> tensor<128xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<128xbf16 + %0 = tensor.empty() : tensor<128xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<128xbf16 + // CHECK-SAME: tensor<128xbf16 + // CHECK-SAME: tensor<128xbf16 + // CHECK-SAME: tensor<128xbf16 + // CHECK-SAME: tensor<128xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<128xbf16>, tensor<128xbf16>, tensor<128xbf16>) -> tensor<128xbf16> + return %1 : tensor<128xbf16> + } + + func.func @linear_2d_1d(%arg0: tensor<64x128xbf16>, %arg1: tensor<128xbf16>) -> tensor<64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<64xbf16 + %0 = tensor.empty() : tensor<64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<128xbf16 + // CHECK-SAME: tensor<64xbf16 + // CHECK-SAME: tensor<64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128xbf16>, tensor<64xbf16>) -> tensor<64xbf16> + return %1 : tensor<64xbf16> + } + + func.func @linear_2d_2d(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<64x64xbf16 + %0 = tensor.empty() : tensor<64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<128x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> + } + + func.func @linear_2d_2d_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<64x64xbf16 + %0 = tensor.empty() : tensor<64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<128x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> + } + + func.func @linear_1d_nd(%arg0: tensor<128xbf16>, %arg1: tensor<12x7x128x64xbf16>) -> tensor<12x7x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<12x7x64xbf16 + %0 = tensor.empty() : tensor<12x7x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<128xbf16 + // CHECK-SAME: tensor<12x7x128x64xbf16 + // CHECK-SAME: tensor<12x7x64xbf16 + // CHECK-SAME: tensor<12x7x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<12x7x128x64xbf16>, tensor<12x7x64xbf16>) -> tensor<12x7x64xbf16> + return %1 : tensor<12x7x64xbf16> + } + + func.func @linear_nd_1d(%arg0: tensor<12x7x128x64xbf16>, %arg1: tensor<64xbf16>) -> tensor<12x7x128xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<12x7x128xbf16 + %0 = tensor.empty() : tensor<12x7x128xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<12x7x128x64xbf16 + // CHECK-SAME: tensor<64xbf16 + // CHECK-SAME: tensor<12x7x128xbf16 + // CHECK-SAME: tensor<12x7x128xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<12x7x128x64xbf16>, tensor<64xbf16>, tensor<12x7x128xbf16>) -> tensor<12x7x128xbf16> + return %1 : tensor<12x7x128xbf16> + } + + func.func @linear_2d_nd(%arg0: tensor<64x128xbf16>, %arg1: tensor<12x7x128x64xbf16>) -> tensor<12x7x64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<12x7x64x64xbf16 + %0 = tensor.empty() : tensor<12x7x64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<12x7x128x64xbf16 + // CHECK-SAME: tensor<12x7x64x64xbf16 + // CHECK-SAME: tensor<12x7x64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<12x7x128x64xbf16>, tensor<12x7x64x64xbf16>) -> tensor<12x7x64x64xbf16> + return %1 : tensor<12x7x64x64xbf16> + } + + func.func @linear_nd_2d(%arg0: tensor<12x7x128x64xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<12x7x128x128xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<12x7x128x128xbf16 + %0 = tensor.empty() : tensor<12x7x128x128xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<12x7x128x64xbf16 + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<12x7x128x128xbf16 + // CHECK-SAME: tensor<12x7x128x128xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<12x7x128x64xbf16>, tensor<64x128xbf16>, tensor<12x7x128x128xbf16>) -> tensor<12x7x128x128xbf16> + return %1 : tensor<12x7x128x128xbf16> + } + + // linear nd - nd tests + func.func @linear_nd_nd_same_rank_same_dims(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<7x128x64xbf16>) -> tensor<7x64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<7x64x64xbf16 + %0 = tensor.empty() : tensor<7x64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<7x64x128xbf16 + // CHECK-SAME: tensor<7x128x64xbf16 + // CHECK-SAME: tensor<7x64x64xbf16 + // CHECK-SAME: tensor<7x64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<7x64x128xbf16>, tensor<7x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> + return %1 : tensor<7x64x64xbf16> + } + + func.func @linear_nd_nd_same_rank_broadcastable_dims_1(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<1x128x64xbf16>) -> tensor<7x64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<7x64x64xbf16 + %0 = tensor.empty() : tensor<7x64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<7x64x128xbf16 + // CHECK-SAME: tensor<1x128x64xbf16 + // CHECK-SAME: tensor<7x64x64xbf16 + // CHECK-SAME: tensor<7x64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<7x64x128xbf16>, tensor<1x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> + return %1 : tensor<7x64x64xbf16> + } + + func.func @linear_nd_nd_same_rank_broadcastable_dims_2(%arg0: tensor<1x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<7x7x64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<7x7x64x64xbf16 + %0 = tensor.empty() : tensor<7x7x64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<1x7x64x128xbf16 + // CHECK-SAME: tensor<7x1x128x64xbf16 + // CHECK-SAME: tensor<7x7x64x64xbf16 + // CHECK-SAME: tensor<7x7x64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<1x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<7x7x64x64xbf16>) -> tensor<7x7x64x64xbf16> + return %1 : tensor<7x7x64x64xbf16> + } + + func.func @linear_nd_nd_different_rank_broadcastable_dims_2(%arg0: tensor<12x1x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<12x7x7x64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<12x7x7x64x64xbf16 + %0 = tensor.empty() : tensor<12x7x7x64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<12x1x7x64x128xbf16 + // CHECK-SAME: tensor<7x1x128x64xbf16 + // CHECK-SAME: tensor<12x7x7x64x64xbf16 + // CHECK-SAME: tensor<12x7x7x64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<12x1x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<12x7x7x64x64xbf16>) -> tensor<12x7x7x64x64xbf16> + return %1 : tensor<12x7x7x64x64xbf16> + } + + func.func @linear_nd_nd_bias_broadcast_bias(%arg0: tensor<14x7x32x32xbf16>, %arg1:tensor<14x1x32x64xbf16>, %bias: tensor<64xbf16>) -> tensor<14x7x32x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<14x7x32x64xbf16 + %0 = tensor.empty() : tensor<14x7x32x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<14x7x32x32xbf16 + // CHECK-SAME: tensor<14x1x32x64xbf16 + // CHECK-SAME: tensor<64xbf16 + // CHECK-SAME: tensor<14x7x32x64xbf16 + // CHECK-SAME: tensor<14x7x32x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<14x7x32x32xbf16>, tensor<14x1x32x64xbf16>, tensor<64xbf16>, tensor<14x7x32x64xbf16>) -> tensor<14x7x32x64xbf16> + return %1 : tensor<14x7x32x64xbf16> + } + + func.func @linear_nd_nd_bias_broadcast_matmul(%arg0: tensor<3x64x128xbf16>, %arg1: tensor<4x3x128x32xbf16>, %bias: tensor<14x4x3x64x32xbf16>) -> tensor<14x4x3x64x32xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<14x4x3x64x32xbf16 + %0 = tensor.empty() : tensor<14x4x3x64x32xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<3x64x128xbf16 + // CHECK-SAME: tensor<4x3x128x32xbf16 + // CHECK-SAME: tensor<14x4x3x64x32xbf16 + // CHECK-SAME: tensor<14x4x3x64x32xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<3x64x128xbf16>, tensor<4x3x128x32xbf16>, tensor<14x4x3x64x32xbf16>, tensor<14x4x3x64x32xbf16>) -> tensor<14x4x3x64x32xbf16> + return %1 : tensor<14x4x3x64x32xbf16> + } +} diff --git a/test/ttmlir/Dialect/TTNN/linear/simple_linear.mlir b/test/ttmlir/Dialect/TTNN/linear/simple_linear.mlir new file mode 100644 index 0000000000..56728eb52b --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/linear/simple_linear.mlir @@ -0,0 +1,31 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s +#any_device_tile = #tt.operand_constraint + +module { + func.func @simple_linear_without_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<64x64xbf16 + %0 = tensor.empty() : tensor<64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<128x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> + } + + func.func @simple_linear_with_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<64x64xbf16 + %0 = tensor.empty() : tensor<64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<128x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> + } +} diff --git a/test/ttmlir/Dialect/TTNN/matmul/simple_matmul.mlir b/test/ttmlir/Dialect/TTNN/matmul/simple_matmul.mlir index 63af0b5b49..f82ed85752 100644 --- a/test/ttmlir/Dialect/TTNN/matmul/simple_matmul.mlir +++ b/test/ttmlir/Dialect/TTNN/matmul/simple_matmul.mlir @@ -1,6 +1,6 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s #any_device_tile = #tt.operand_constraint -// CHECK: #[[TILED_LAYOUT:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #dram>, interleaved> +// CHECK: #[[TILED_LAYOUT:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #dram>, > module attributes {} { func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>) -> tensor<64x96xbf16> { %0 = tensor.empty() : tensor<64x96xbf16> diff --git a/test/ttmlir/Dialect/TTNN/optimizer/input_layout_loc_override.mlir b/test/ttmlir/Dialect/TTNN/optimizer/input_layout_loc_override.mlir index eb3bc04956..4a4575f8d1 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/input_layout_loc_override.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/input_layout_loc_override.mlir @@ -4,13 +4,13 @@ // CHECK-DAG: #[[LOC_MATMUL_IN0:.*]] = loc("matmul_1_in_0_layout"(#loc3)) // CHECK-DAG: #[[LOC_MATMUL_IN1:.*]] = loc("matmul_1_in_1_layout"(#loc3)) // CHECK-DAG: #[[LOC_MATMUL:.*]] = loc("matmul_1"(#loc3)) -// CHECK-DAG: #[[IN_1_LAYOUT:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<4x3x!tt.tile<32x32, bf16>, #l1_>, interleaved> +// CHECK-DAG: #[[IN_1_LAYOUT:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<4x3x!tt.tile<32x32, bf16>, #l1_>, > module attributes {} { func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>) -> tensor<64x96xbf16> { %0 = tensor.empty() : tensor<64x96xbf16> loc(#loc2) // CHECK-DAG: %{{.*}} = "ttnn.to_device"{{.*}} loc(#[[LOC_MATMUL_IN0]]) - // CHECK-DAG: %{{.*}} = "ttnn.to_device"{{.*}} <{memory_config = #ttnn.memory_config<, #l1_, <<4x3>>>}> : {{.*}} -> tensor<128x96xbf16, #[[IN_1_LAYOUT]]> loc(#[[LOC_MATMUL_IN1]]) + // CHECK-DAG: %{{.*}} = "ttnn.to_device"{{.*}} <{memory_config = #ttnn.memory_config<#l1_, <<4x3>>, >}> : {{.*}} -> tensor<128x96xbf16, #[[IN_1_LAYOUT]]> loc(#[[LOC_MATMUL_IN1]]) // CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} loc(#[[LOC_MATMUL]]) %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> loc(#loc2) return %1 : tensor<64x96xbf16> diff --git a/test/ttmlir/Dialect/TTNN/optimizer/insert_memreconfig_override.mlir b/test/ttmlir/Dialect/TTNN/optimizer/insert_memreconfig_override.mlir index 76ebd31cfa..ec03a6ad59 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/insert_memreconfig_override.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/insert_memreconfig_override.mlir @@ -4,8 +4,8 @@ module attributes {} { func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> tensor<1x32x32xf32> { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type - // CHECK-DAG: #[[LAYOUT_1:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1>, memref<1x1x!tt.tile<32x32, f32>, #l1_>, width_sharded> - // CHECK-DAG: #[[LAYOUT_2:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1>, memref<32x32xf32, #dram>, interleaved> + // CHECK-DAG: #[[LAYOUT_1:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1>, memref<1x1x!tt.tile<32x32, f32>, #l1_>, > + // CHECK-DAG: #[[LAYOUT_2:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1>, memref<32x32xf32, #dram>, > %0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5) // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_2]]> %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5) diff --git a/test/ttmlir/Silicon/TTNN/optimizer/all_l1_interleaved_policy.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/all_l1_interleaved_policy.mlir similarity index 82% rename from test/ttmlir/Silicon/TTNN/optimizer/all_l1_interleaved_policy.mlir rename to test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/all_l1_interleaved_policy.mlir index 1d278958d3..5c34fe8548 100644 --- a/test/ttmlir/Silicon/TTNN/optimizer/all_l1_interleaved_policy.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/all_l1_interleaved_policy.mlir @@ -1,12 +1,9 @@ -// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% enable-optimizer=true memory-layout-analysis-enabled=true memory-layout-analysis-policy=L1Interleaved" %s > %t.mlir -// RUN: FileCheck %s --input-file=%t.mlir -// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true memory-layout-analysis-policy=L1Interleaved" %s | FileCheck %s #any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>, %arg2: tensor<64x96xbf16>, %arg3: tensor<96x32xbf16>, %arg4: tensor<64x32xbf16>) -> tensor<64x32xbf16> { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type - // CHECK: #[[LAYOUT_L1:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<{{.*}}, #l1_>, interleaved> - // CHECK: #[[LAYOUT_DRAM:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<{{.*}}, #dram>, interleaved> + // CHECK: #[[LAYOUT_L1:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<{{.*}}, #l1_>, > %0 = tensor.empty() : tensor<64x96xbf16> // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x96xbf16, #[[LAYOUT_L1]]> %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> @@ -23,7 +20,7 @@ module attributes {} { // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_L1]]> %9 = "ttir.add"(%7, %arg4, %8) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x32xbf16>, tensor<64x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16> %10 = tensor.empty() : tensor<64x32xbf16> - // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_DRAM]]> + // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_L1]]> %11 = "ttir.relu"(%9, %10) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16> return %11 : tensor<64x32xbf16> } diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/fork_join.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/fork_join.mlir new file mode 100644 index 0000000000..67c480d8c9 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/fork_join.mlir @@ -0,0 +1,44 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true memory-layout-analysis-policy=L1Interleaved" %s | FileCheck %s +// +// A +// | +// B +// / \ +// C D +// | | +// | E +// \ / +// F +// | +// G +// +// This tests two things: +// 1. Output of op B (fork op) should be in DRAM. +// 2. Even though both precedence [C, E] and [E, C] for op F are legal, +// the optimizer should choose the one with lower requiredL1Usage. In +// this case, [E, C] should be chosen. +// +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<64x64xbf16>, %arg1: tensor<64x32xbf16>) -> tensor<64x32xbf16> { + // CHECK: #[[L1_:.*]] = #ttnn.buffer_type + // CHECK: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x1x!tt.tile<32x32, bf16>, #dram>, > + // CHECK: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x1x!tt.tile<32x32, bf16>, #l1_>, > + %0 = tensor.empty() : tensor<64x64xbf16> + // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x64xbf16, #[[LAYOUT_3]]> + %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + %2 = tensor.empty() : tensor<64x64xbf16> + %3 = "ttir.relu"(%1, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + %4 = tensor.empty() : tensor<64x32xbf16> + %5 = "ttir.matmul"(%1, %arg1, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x64xbf16>, tensor<64x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16> + %6 = tensor.empty() : tensor<64x32xbf16> + %7 = "ttir.relu"(%5, %6) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16> + %8 = tensor.empty() : tensor<64x32xbf16> + // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_5]]> + // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_5]]> + // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x64xbf16, #[[LAYOUT_5]]> + // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_5]]> + %9 = "ttir.matmul"(%3, %7, %8) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x64xbf16>, tensor<64x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16> + return %9 : tensor<64x32xbf16> + } +} diff --git a/test/ttmlir/Silicon/TTNN/optimizer/mnist_l1_interleaved.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/mnist_l1_interleaved.mlir similarity index 85% rename from test/ttmlir/Silicon/TTNN/optimizer/mnist_l1_interleaved.mlir rename to test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/mnist_l1_interleaved.mlir index 4bd0867a23..f45c11c624 100644 --- a/test/ttmlir/Silicon/TTNN/optimizer/mnist_l1_interleaved.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/mnist_l1_interleaved.mlir @@ -1,12 +1,9 @@ -// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% enable-optimizer=true memory-layout-analysis-enabled=true memory-layout-analysis-policy=L1Interleaved" %s > %t.mlir -// RUN: FileCheck %s --input-file=%t.mlir -// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true memory-layout-analysis-policy=L1Interleaved" %s | FileCheck %s #any_device = #tt.operand_constraint #loc = loc("MNISTLinear":4294967295:0) module @"tt-forge-graph" attributes {} { func.func @main(%arg0: tensor<1x784xf32> loc("MNISTLinear":4294967295:0), %arg1: tensor<1x10xf32> loc("MNISTLinear":4294967295:0), %arg2: tensor<256x10xf32> loc("MNISTLinear":4294967295:0), %arg3: tensor<1x256xf32> loc("MNISTLinear":4294967295:0), %arg4: tensor<784x256xf32> loc("MNISTLinear":4294967295:0)) -> tensor<1x10xf32> { - // CHECK: #[[LAYOUT_L1:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<{{.*}}, #l1_>, interleaved> - // CHECK: #[[LAYOUT_DRAM:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<{{.*}}, #dram>, interleaved> + // CHECK: #[[LAYOUT_L1:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<{{.*}}, #l1_>, > %0 = tensor.empty() : tensor<1x256xf32> loc(#loc8) // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<1x256xf32, #[[LAYOUT_L1]]> %1 = "ttir.matmul"(%arg0, %arg4, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x784xf32>, tensor<784x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc8) @@ -23,7 +20,7 @@ module @"tt-forge-graph" attributes {} { // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x10xf32, #[[LAYOUT_L1]]> %9 = "ttir.add"(%7, %arg1, %8) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc12) %10 = tensor.empty() : tensor<1x10xf32> loc(#loc13) - // CHECK: %{{.*}} = "ttnn.softmax"{{.*}} -> tensor<1x10xf32, #[[LAYOUT_DRAM]]> + // CHECK: %{{.*}} = "ttnn.softmax"{{.*}} -> tensor<1x10xf32, #[[LAYOUT_L1]]> %11 = "ttir.softmax"(%9, %10) <{dimension = 1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc13) return %11 : tensor<1x10xf32> loc(#loc7) } loc(#loc) diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_ABC_l1_None.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_ABC_l1_None.mlir new file mode 100644 index 0000000000..e5a4f3fa66 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_ABC_l1_None.mlir @@ -0,0 +1,28 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true memory-layout-analysis-policy=L1Interleaved" %s | FileCheck %s +// +// A B +// \ / +// C +// | +// D +// +// (A > L1) AND (B > L1) AND (C > L1) +// => +// DRAM: ABC; L1: None +// +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<8192x8192xbf16>, %arg1: tensor<8192x8192xbf16>, %arg2: tensor<8192x8192xbf16>, %arg3: tensor<8192x8192xbf16>) -> tensor<8192x8192xbf16> { + // CHECK-DAG: #[[LAYOUT_2:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<{{.*}}>, #dram>, > + %0 = tensor.empty() : tensor<8192x8192xbf16> + // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<8192x8192xbf16, #[[LAYOUT_2]]> + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<8192x8192xbf16>, tensor<8192x8192xbf16>, tensor<8192x8192xbf16>) -> tensor<8192x8192xbf16> + %2 = tensor.empty() : tensor<8192x8192xbf16> + // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<8192x8192xbf16, #[[LAYOUT_2]]> + %3 = "ttir.add"(%arg2, %arg3, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<8192x8192xbf16>, tensor<8192x8192xbf16>, tensor<8192x8192xbf16>) -> tensor<8192x8192xbf16> + %4 = tensor.empty() : tensor<8192x8192xbf16> + // CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<8192x8192xbf16, #[[LAYOUT_2]]> + %5 = "ttir.matmul"(%1, %3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<8192x8192xbf16>, tensor<8192x8192xbf16>, tensor<8192x8192xbf16>) -> tensor<8192x8192xbf16> + return %5 : tensor<8192x8192xbf16> + } +} diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_AB_l1_C.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_AB_l1_C.mlir new file mode 100644 index 0000000000..ceca628400 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_AB_l1_C.mlir @@ -0,0 +1,31 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true memory-layout-analysis-policy=L1Interleaved" %s | FileCheck %s +// +// A B +// \ / +// C +// | +// D +// +// (A + C > L1) AND (B + C > L1) AND (A + B > L1) AND (A < C) AND (B < C) AND (C <= L1) +// => +// DRAM: AB; L1: C +// +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<5120x4096xbf16>, %arg1: tensor<5120x4096xbf16>, %arg2: tensor<4096x5120xbf16>, %arg3: tensor<4096x5120xbf16>) -> tensor<5120x5120xbf16> { + // CHECK: #[[L1_:.*]] = #ttnn.buffer_type + // CHECK-DAG: #[[LAYOUT_4:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<20x16x!tt.tile<32x32, bf16>, #dram>, > + // CHECK-DAG: #[[LAYOUT_6:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<16x20x!tt.tile<32x32, bf16>, #dram>, > + // CHECK-DAG: #[[LAYOUT_7:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<20x20x!tt.tile<32x32, bf16>, #l1_>, > + %0 = tensor.empty() : tensor<5120x4096xbf16> + // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<5120x4096xbf16, #[[LAYOUT_4]]> + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<5120x4096xbf16>, tensor<5120x4096xbf16>, tensor<5120x4096xbf16>) -> tensor<5120x4096xbf16> + %2 = tensor.empty() : tensor<4096x5120xbf16> + // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<4096x5120xbf16, #[[LAYOUT_6]]> + %3 = "ttir.add"(%arg2, %arg3, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<4096x5120xbf16>, tensor<4096x5120xbf16>, tensor<4096x5120xbf16>) -> tensor<4096x5120xbf16> + %4 = tensor.empty() : tensor<5120x5120xbf16> + // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<5120x5120xbf16, #[[LAYOUT_7]]> + %5 = "ttir.matmul"(%1, %3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<5120x4096xbf16>, tensor<4096x5120xbf16>, tensor<5120x5120xbf16>) -> tensor<5120x5120xbf16> + return %5 : tensor<5120x5120xbf16> + } +} diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_AC_l1_B.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_AC_l1_B.mlir new file mode 100644 index 0000000000..74675e4e0b --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_AC_l1_B.mlir @@ -0,0 +1,30 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true memory-layout-analysis-policy=L1Interleaved" %s | FileCheck %s +// +// A B +// \ / +// C +// | +// D +// +// (A + C > L1) AND (B + C > L1) AND (A + B > L1) AND (A < B) AND (C < B) AND (B <= L1) +// => +// DRAM: AC; L1: B +// +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<4096x5120xbf16>, %arg1: tensor<4096x5120xbf16>, %arg2: tensor<5120x5120xbf16>, %arg3: tensor<5120x5120xbf16>) -> tensor<4096x5120xbf16> { + // CHECK: #[[L1_:.*]] = #ttnn.buffer_type + // CHECK-DAG: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<16x20x!tt.tile<32x32, bf16>, #dram>, > + // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<20x20x!tt.tile<32x32, bf16>, #l1_>, > + %0 = tensor.empty() : tensor<4096x5120xbf16> + // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<4096x5120xbf16, #[[LAYOUT_3]]> + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<4096x5120xbf16>, tensor<4096x5120xbf16>, tensor<4096x5120xbf16>) -> tensor<4096x5120xbf16> + %2 = tensor.empty() : tensor<5120x5120xbf16> + // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<5120x5120xbf16, #[[LAYOUT_5]]> + %3 = "ttir.add"(%arg2, %arg3, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<5120x5120xbf16>, tensor<5120x5120xbf16>, tensor<5120x5120xbf16>) -> tensor<5120x5120xbf16> + %4 = tensor.empty() : tensor<4096x5120xbf16> + // CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<4096x5120xbf16, #[[LAYOUT_3]]> + %5 = "ttir.matmul"(%1, %3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<4096x5120xbf16>, tensor<5120x5120xbf16>, tensor<4096x5120xbf16>) -> tensor<4096x5120xbf16> + return %5 : tensor<4096x5120xbf16> + } +} diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_A_l1_BC.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_A_l1_BC.mlir new file mode 100644 index 0000000000..c3cd2740bc --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_A_l1_BC.mlir @@ -0,0 +1,30 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true memory-layout-analysis-policy=L1Interleaved" %s | FileCheck %s +// +// A B +// \ / +// C +// | +// D +// +// (A + B + C > L1) AND (A + C < B + C) AND (A + B < B + C) AND (B + C <= L1) +// => +// DRAM: A; L1: BC +// +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<2048x2048xbf16>, %arg1: tensor<2048x2048xbf16>, %arg2: tensor<2048x8192xbf16>, %arg3: tensor<2048x8192xbf16>) -> tensor<2048x8192xbf16> { + // CHECK: #[[L1_:.*]] = #ttnn.buffer_type + // CHECK-DAG: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<8x8x!tt.tile<32x32, bf16>, #dram>, > + // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<8x32x!tt.tile<32x32, bf16>, #l1_>, > + %0 = tensor.empty() : tensor<2048x2048xbf16> + // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<2048x2048xbf16, #[[LAYOUT_3]]> + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<2048x2048xbf16>, tensor<2048x2048xbf16>, tensor<2048x2048xbf16>) -> tensor<2048x2048xbf16> + %2 = tensor.empty() : tensor<2048x8192xbf16> + // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<2048x8192xbf16, #[[LAYOUT_5]]> + %3 = "ttir.add"(%arg2, %arg3, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<2048x8192xbf16>, tensor<2048x8192xbf16>, tensor<2048x8192xbf16>) -> tensor<2048x8192xbf16> + %4 = tensor.empty() : tensor<2048x8192xbf16> + // CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<2048x8192xbf16, #[[LAYOUT_5]]> + %5 = "ttir.matmul"(%1, %3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<2048x2048xbf16>, tensor<2048x8192xbf16>, tensor<2048x8192xbf16>) -> tensor<2048x8192xbf16> + return %5 : tensor<2048x8192xbf16> + } +} diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_BC_l1_A.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_BC_l1_A.mlir new file mode 100644 index 0000000000..c9cd33f1c9 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_BC_l1_A.mlir @@ -0,0 +1,30 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true memory-layout-analysis-policy=L1Interleaved" %s | FileCheck %s +// +// A B +// \ / +// C +// | +// D +// +// (A + C > L1) AND (B + C > L1) AND (A + B > L1) AND (B < A) AND (C < A) AND (A <= L1) +// => +// DRAM: BC; L1: A +// +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<5120x5120xbf16>, %arg1: tensor<5120x5120xbf16>, %arg2: tensor<5120x4096xbf16>, %arg3: tensor<5120x4096xbf16>) -> tensor<5120x4096xbf16> { + // CHECK: #[[L1_:.*]] = #ttnn.buffer_type + // CHECK-DAG: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<20x16x!tt.tile<32x32, bf16>, #dram>, > + // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<20x20x!tt.tile<32x32, bf16>, #l1_>, > + %0 = tensor.empty() : tensor<5120x5120xbf16> + // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<5120x5120xbf16, #[[LAYOUT_5]]> + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<5120x5120xbf16>, tensor<5120x5120xbf16>, tensor<5120x5120xbf16>) -> tensor<5120x5120xbf16> + %2 = tensor.empty() : tensor<5120x4096xbf16> + // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<5120x4096xbf16, #[[LAYOUT_3]]> + %3 = "ttir.add"(%arg2, %arg3, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<5120x4096xbf16>, tensor<5120x4096xbf16>, tensor<5120x4096xbf16>) -> tensor<5120x4096xbf16> + %4 = tensor.empty() : tensor<5120x4096xbf16> + // CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<5120x4096xbf16, #[[LAYOUT_3]]> + %5 = "ttir.matmul"(%1, %3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<5120x5120xbf16>, tensor<5120x4096xbf16>, tensor<5120x4096xbf16>) -> tensor<5120x4096xbf16> + return %5 : tensor<5120x4096xbf16> + } +} diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_B_l1_AC.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_B_l1_AC.mlir new file mode 100644 index 0000000000..760ea2b8a5 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_B_l1_AC.mlir @@ -0,0 +1,30 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true memory-layout-analysis-policy=L1Interleaved" %s | FileCheck %s +// +// A B +// \ / +// C +// | +// D +// +// (A + B + C > L1) AND (B + C < A + C) AND (A + B < A + C) AND (A + C <= L1) +// => +// DRAM: B; L1: AC +// +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<8192x2048xbf16>, %arg1: tensor<8192x2048xbf16>, %arg2: tensor<2048x2048xbf16>, %arg3: tensor<2048x2048xbf16>) -> tensor<8192x2048xbf16> { + // CHECK: #[[L1_:.*]] = #ttnn.buffer_type + // CHECK-DAG: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<8x8x!tt.tile<32x32, bf16>, #dram>, > + // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<32x8x!tt.tile<32x32, bf16>, #l1_>, > + %0 = tensor.empty() : tensor<8192x2048xbf16> + // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<8192x2048xbf16, #[[LAYOUT_5]]> + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<8192x2048xbf16>, tensor<8192x2048xbf16>, tensor<8192x2048xbf16>) -> tensor<8192x2048xbf16> + %2 = tensor.empty() : tensor<2048x2048xbf16> + // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<2048x2048xbf16, #[[LAYOUT_3]]> + %3 = "ttir.add"(%arg2, %arg3, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<2048x2048xbf16>, tensor<2048x2048xbf16>, tensor<2048x2048xbf16>) -> tensor<2048x2048xbf16> + %4 = tensor.empty() : tensor<8192x2048xbf16> + // CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<8192x2048xbf16, #[[LAYOUT_5]]> + %5 = "ttir.matmul"(%1, %3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<8192x2048xbf16>, tensor<2048x2048xbf16>, tensor<8192x2048xbf16>) -> tensor<8192x2048xbf16> + return %5 : tensor<8192x2048xbf16> + } +} diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_C_l1_AB.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_C_l1_AB.mlir new file mode 100644 index 0000000000..5d95a6204a --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_C_l1_AB.mlir @@ -0,0 +1,31 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true memory-layout-analysis-policy=L1Interleaved" %s | FileCheck %s +// +// A B +// \ / +// C +// | +// D +// +// (A + B + C > L1) AND (A + C < A + B) AND (B + C < A + B) AND (A + B <= L1) +// => +// DRAM: C; L1: AB +// +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<2048x8192xbf16>, %arg1: tensor<2048x8192xbf16>, %arg2: tensor<8192x2048xbf16>, %arg3: tensor<8192x2048xbf16>) -> tensor<2048x2048xbf16> { + // CHECK: #[[L1_:.*]] = #ttnn.buffer_type + // CHECK-DAG: #[[LAYOUT_4:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<8x32x!tt.tile<32x32, bf16>, #l1_>, > + // CHECK-DAG: #[[LAYOUT_6:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<32x8x!tt.tile<32x32, bf16>, #l1_>, > + // CHECK-DAG: #[[LAYOUT_7:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<8x8x!tt.tile<32x32, bf16>, #dram>, > + %0 = tensor.empty() : tensor<2048x8192xbf16> + // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<2048x8192xbf16, #[[LAYOUT_4]]> + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<2048x8192xbf16>, tensor<2048x8192xbf16>, tensor<2048x8192xbf16>) -> tensor<2048x8192xbf16> + %2 = tensor.empty() : tensor<8192x2048xbf16> + // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<8192x2048xbf16, #[[LAYOUT_6]]> + %3 = "ttir.add"(%arg2, %arg3, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<8192x2048xbf16>, tensor<8192x2048xbf16>, tensor<8192x2048xbf16>) -> tensor<8192x2048xbf16> + %4 = tensor.empty() : tensor<2048x2048xbf16> + // CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<2048x2048xbf16, #[[LAYOUT_7]]> + %5 = "ttir.matmul"(%1, %3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<2048x8192xbf16>, tensor<8192x2048xbf16>, tensor<2048x2048xbf16>) -> tensor<2048x2048xbf16> + return %5 : tensor<2048x2048xbf16> + } +} diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_None_l1_ABC.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_None_l1_ABC.mlir new file mode 100644 index 0000000000..75b876dbf3 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_None_l1_ABC.mlir @@ -0,0 +1,29 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true memory-layout-analysis-policy=L1Interleaved" %s | FileCheck %s +// +// A B +// \ / +// C +// | +// D +// +// (A + B + C <= L1) +// => +// DRAM: None; L1: ABC +// +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>, %arg2: tensor<32x32xbf16>, %arg3: tensor<32x32xbf16>) -> tensor<32x32xbf16> { + // CHECK: #[[L1_:.*]] = #ttnn.buffer_type + // CHECK-DAG: #[[LAYOUT_2:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x1x!tt.tile<32x32, bf16>, #l1_>, > + %0 = tensor.empty() : tensor<32x32xbf16> + // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<32x32xbf16, #[[LAYOUT_2]]> + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> + %2 = tensor.empty() : tensor<32x32xbf16> + // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<32x32xbf16, #[[LAYOUT_2]]> + %3 = "ttir.add"(%arg2, %arg3, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> + %4 = tensor.empty() : tensor<32x32xbf16> + // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<32x32xbf16, #[[LAYOUT_2]]> + %5 = "ttir.add"(%1, %3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> + return %5 : tensor<32x32xbf16> + } +} diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/single_op.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/single_op.mlir new file mode 100644 index 0000000000..4820799936 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/single_op.mlir @@ -0,0 +1,10 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true memory-layout-analysis-policy=L1Interleaved" %s | FileCheck %s +// UNSUPPORTED: true +#any_device_tile = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<5120x5120xbf16>) -> tensor<5120x5120xbf16> { + %0 = tensor.empty() : tensor<5120x5120xbf16> + %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<5120x5120xbf16>, tensor<5120x5120xbf16>) -> tensor<5120x5120xbf16> + return %1 : tensor<5120x5120xbf16> + } +} diff --git a/test/ttmlir/Dialect/TTNN/optimizer/multiple_add_with_loc.mlir b/test/ttmlir/Dialect/TTNN/optimizer/multiple_add_with_loc.mlir index 3c89324cb6..8e25f97ca0 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/multiple_add_with_loc.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/multiple_add_with_loc.mlir @@ -3,7 +3,7 @@ #loc = loc("test_ops.py:17_0_0":0:0) module attributes {} { func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> (tensor<1x32x32xf32>, tensor<1x32x32xf32>) { - // CHECK: #[[LAYOUT:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x1x!tt.tile<32x32, f32>, #dram>, interleaved> + // CHECK: #[[LAYOUT:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x1x!tt.tile<32x32, f32>, #dram>, > %0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5) // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT]]> %1 = "ttir.add"(%arg1, %arg2, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5) diff --git a/test/ttmlir/Dialect/TTNN/optimizer/output_layout_override.mlir b/test/ttmlir/Dialect/TTNN/optimizer/output_layout_override.mlir index 0996b86fb1..79bbae2753 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/output_layout_override.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/output_layout_override.mlir @@ -5,9 +5,9 @@ module attributes {} { func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> (tensor<1x32x32xf32>, tensor<1x32x32xf32>) { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type // CHECK: #[[LAYOUT_0:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1>, memref<32x32xf32, #system_memory>> - // CHECK: #[[LAYOUT_1:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <4x4>, memref<8x8xbf16, #dram>, interleaved> - // CHECK: #[[LAYOUT_2:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <4x4>, memref<1x1x!tt.tile<32x32, f32>, #l1_>, interleaved> - // CHECK: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<{{.*}} #dram>, interleaved> + // CHECK: #[[LAYOUT_1:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <4x4>, memref<8x8xbf16, #dram>, > + // CHECK: #[[LAYOUT_2:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <4x4>, memref<1x1x!tt.tile<32x32, f32>, #l1_>, > + // CHECK: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<{{.*}} #dram>, > %0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5) // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_1]]> %1 = "ttir.add"(%arg1, %arg2, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5) diff --git a/test/ttmlir/Dialect/TTNN/optimizer/test_override_reshard_edges.mlir b/test/ttmlir/Dialect/TTNN/optimizer/test_override_reshard_edges.mlir index ed83ba0aff..08e6da1165 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/test_override_reshard_edges.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/test_override_reshard_edges.mlir @@ -3,24 +3,25 @@ #dram = #ttnn.buffer_type #system_memory = #ttnn.buffer_type #ttnn_layout = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1>, memref<32x32xf32, #system_memory>> -#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1>, memref<32x32xf32, #dram>, interleaved> +#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1>, memref<32x32xf32, #dram>, > module attributes {tt.device = #device} { func.func @main(%arg0: tensor<1x32x32xf32, #ttnn_layout>, %arg1: tensor<1x32x32xf32, #ttnn_layout>, %arg2: tensor<1x32x32xf32, #ttnn_layout>) -> tensor<1x32x32xf32, #ttnn_layout> { - // CHECK: #[[LAYOUT_1:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1>, memref<32x32xf32, #dram>, interleaved> - // CHECK: #[[LAYOUT_2:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1>, memref<1x1x!tt.tile<32x32, f32>, #l1_>, width_sharded> + // CHECK: #[[LAYOUT_1:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1>, memref<32x32xf32, #dram>, > + // CHECK: #[[LAYOUT_2:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1>, memref<1x1x!tt.tile<32x32, f32>, #l1_>, > + // CHECK: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x1x!tt.tile<32x32, f32>, #dram>, > %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> - %1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<32x32>>>}> : (tensor<1x32x32xf32, #ttnn_layout>, !tt.device<#device>) -> tensor<1x32x32xf32, #ttnn_layout1> - %2 = "ttnn.to_layout"(%arg1, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<32x32>>>}> : (tensor<1x32x32xf32, #ttnn_layout>, !tt.device<#device>) -> tensor<1x32x32xf32, #ttnn_layout1> - %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<32x32>>>, shape = #ttnn.shape<1x32x32>}> : (!tt.device<#device>) -> tensor<1x32x32xf32, #ttnn_layout1> loc(#loc1) - // CHECK: %[[IDX:.*]] = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_1]]> + %1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, <<32x32>>, >}> : (tensor<1x32x32xf32, #ttnn_layout>, !tt.device<#device>) -> tensor<1x32x32xf32, #ttnn_layout1> + %2 = "ttnn.to_layout"(%arg1, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, <<32x32>>, >}> : (tensor<1x32x32xf32, #ttnn_layout>, !tt.device<#device>) -> tensor<1x32x32xf32, #ttnn_layout1> + %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, <<32x32>>, >, shape = #ttnn.shape<1x32x32>}> : (!tt.device<#device>) -> tensor<1x32x32xf32, #ttnn_layout1> loc(#loc1) + // CHECK: %[[C:.*]] = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_1]]> %4 = "ttnn.add"(%1, %2, %3) <{operandSegmentSizes = array}> : (tensor<1x32x32xf32, #ttnn_layout1>, tensor<1x32x32xf32, #ttnn_layout1>, tensor<1x32x32xf32, #ttnn_layout1>) -> tensor<1x32x32xf32, #ttnn_layout1> loc(#loc1) - %5 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<32x32>>>}> : (tensor<1x32x32xf32, #ttnn_layout>, !tt.device<#device>) -> tensor<1x32x32xf32, #ttnn_layout1> - %6 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<32x32>>>, shape = #ttnn.shape<1x32x32>}> : (!tt.device<#device>) -> tensor<1x32x32xf32, #ttnn_layout1> loc(#loc2) - // CHECK: %{{.*}} = "ttnn.to_layout"(%[[IDX]], %0) {{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_2]]> + %5 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, <<32x32>>, >}> : (tensor<1x32x32xf32, #ttnn_layout>, !tt.device<#device>) -> tensor<1x32x32xf32, #ttnn_layout1> + %6 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, <<32x32>>, >, shape = #ttnn.shape<1x32x32>}> : (!tt.device<#device>) -> tensor<1x32x32xf32, #ttnn_layout1> loc(#loc2) + // CHECK: %{{.*}} = "ttnn.to_layout"(%[[C]], %0) {{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_2]]> %7 = "ttnn.add"(%4, %6, %6) <{operandSegmentSizes = array}> : (tensor<1x32x32xf32, #ttnn_layout1>, tensor<1x32x32xf32, #ttnn_layout1>, tensor<1x32x32xf32, #ttnn_layout1>) -> tensor<1x32x32xf32, #ttnn_layout1> loc(#loc2) - %8 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<32x32>>>, shape = #ttnn.shape<1x32x32>}> : (!tt.device<#device>) -> tensor<1x32x32xf32, #ttnn_layout1> loc(#loc3) + %8 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, <<32x32>>, >, shape = #ttnn.shape<1x32x32>}> : (!tt.device<#device>) -> tensor<1x32x32xf32, #ttnn_layout1> loc(#loc3) %9 = "ttnn.relu"(%7, %8) <{operandSegmentSizes = array}> : (tensor<1x32x32xf32, #ttnn_layout1>, tensor<1x32x32xf32, #ttnn_layout1>) -> tensor<1x32x32xf32, #ttnn_layout1> loc(#loc3) - %10 = "ttnn.to_layout"(%9) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<32x32>>>}> : (tensor<1x32x32xf32, #ttnn_layout1>) -> tensor<1x32x32xf32, #ttnn_layout> + %10 = "ttnn.to_layout"(%9) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, <<32x32>>>}> : (tensor<1x32x32xf32, #ttnn_layout1>) -> tensor<1x32x32xf32, #ttnn_layout> return %10 : tensor<1x32x32xf32, #ttnn_layout> } } diff --git a/test/ttmlir/Dialect/TTNN/reshape/reshape_folding_test.mlir b/test/ttmlir/Dialect/TTNN/reshape/reshape_folding_test.mlir new file mode 100644 index 0000000000..c7f4442f0b --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/reshape/reshape_folding_test.mlir @@ -0,0 +1,12 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s| FileCheck %s +#any_device_tile = #tt.operand_constraint +// Tests if we fold when translating from "ttir.reshape" which is called on the two same shapes. +module @reshape_test { + func.func @main(%arg0: tensor<1xi32>) -> (tensor<1xi32> {jax.result_info = ""}) { + %0 = tensor.empty() : tensor<1xi32> + %1 = "ttir.reshape"(%arg0, %0) <{operand_constraints = [#any_device_tile, #any_device_tile], shape = [1 : i32]}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + // CHECK-NOT: %[[C:.*]] = "ttnn.reshape"[C:.*]] + // CHECK: return %arg0 : tensor<1xi32, #{{.*}}> + return %1 : tensor<1xi32> + } +} diff --git a/test/ttmlir/Dialect/TTNN/simple_constant.mlir b/test/ttmlir/Dialect/TTNN/simple_constant.mlir index 88df7aad24..017a1baf0c 100644 --- a/test/ttmlir/Dialect/TTNN/simple_constant.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_constant.mlir @@ -3,31 +3,31 @@ module attributes {} { func.func @test_empty_int8() -> tensor<64x128xi8> { %0 = "ttir.constant"() <{value = dense<0> : tensor<64x128xi8>}> : () -> tensor<64x128xi8> - // CHECK: %{{[0-9]+}} = "ttnn.empty" + // CHECK: %{{[0-9]+}} = "ttnn.full" return %0 : tensor<64x128xi8> } func.func @test_empty_int16() -> tensor<64x128xi16> { %0 = "ttir.constant"() <{value = dense<0> : tensor<64x128xi16>}> : () -> tensor<64x128xi16> - // CHECK: %{{[0-9]+}} = "ttnn.empty" + // CHECK: %{{[0-9]+}} = "ttnn.full" return %0 : tensor<64x128xi16> } func.func @test_empty_int() -> tensor<64x128xi32> { %0 = "ttir.constant"() <{value = dense<0> : tensor<64x128xi32>}> : () -> tensor<64x128xi32> - // CHECK: %{{[0-9]+}} = "ttnn.empty" + // CHECK: %{{[0-9]+}} = "ttnn.full" return %0 : tensor<64x128xi32> } func.func @test_empty_bfloat16() -> tensor<64x128xbf16> { %0 = "ttir.constant"() <{value = dense<0.000000e+00> : tensor<64x128xbf16>}> : () -> tensor<64x128xbf16> - // CHECK: %{{[0-9]+}} = "ttnn.empty" + // CHECK: %{{[0-9]+}} = "ttnn.full" return %0 : tensor<64x128xbf16> } func.func @test_empty_float() -> tensor<64x128xf32> { %0 = "ttir.constant"() <{value = dense<0.000000e+00> : tensor<64x128xf32>}> : () -> tensor<64x128xf32> - // CHECK: %{{[0-9]+}} = "ttnn.empty" + // CHECK: %{{[0-9]+}} = "ttnn.full" return %0 : tensor<64x128xf32> } diff --git a/test/ttmlir/Dialect/TTNN/simple_scatter.mlir b/test/ttmlir/Dialect/TTNN/simple_scatter.mlir new file mode 100644 index 0000000000..5991efeabe --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/simple_scatter.mlir @@ -0,0 +1,16 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s +#any_device_tile = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<1x3x320x320xf32>, %arg1: tensor<1x3x32x32xf32>) -> tensor<1x3x320x320xf32> { + %0 = tensor.empty() : tensor<1x3x320x320xf32> + %1 = tensor.empty() : tensor<1x1xi32> + // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, shape = #ttnn.shape<[[TENSOR_SHAPE0:[0-9]+x[0-9]+x[0-9]+x[0-9]+]]>}> : (!tt.device<#device>) -> tensor<[[TENSOR_SHAPE1:[0-9]+x[0-9]+x[0-9]+x[0-9]+xf[0-9]+]], {{.*}}> + %2 = "ttir.scatter"(%arg0, %1, %arg1, %0) <{index_vector_dim = 1 : i32, indices_are_sorted = false, input_batching_dims = array, inserted_window_dims = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile], scatter_dims_to_operand_dims = array, scatter_indices_batching_dims = array, unique_indices = false, update_window_dims = array}> ({ + ^bb0(%arg3: tensor<1xf32>, %arg4: tensor<1xf32>): + "ttir.yield"(%arg4) : (tensor<1xf32>) -> () + }) : (tensor<1x3x320x320xf32>, tensor<1x1xi32>, tensor<1x3x32x32xf32>, tensor<1x3x320x320xf32>) -> tensor<1x3x320x320xf32> + // CHECK: {{[0-9]+}} = "ttnn.scatter"(%4, %2, %5) <{operandSegmentSizes = array}> : (tensor<1x3x32x32xf32, {{.*}}>, tensor<[[TENSOR_SHAPE1]], {{.*}}>, tensor<[[TENSOR_SHAPE1]], {{.*}}>) -> tensor<[[TENSOR_SHAPE1]], {{.*}}> + return %2 : tensor<1x3x320x320xf32> + // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE1]], {{.*}}> + } +} diff --git a/test/ttmlir/Dialect/TTNN/test_remove_dead_values_pass.mlir b/test/ttmlir/Dialect/TTNN/test_remove_dead_values_pass.mlir new file mode 100644 index 0000000000..ea59aae1c0 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/test_remove_dead_values_pass.mlir @@ -0,0 +1,77 @@ +// RUN: ttmlir-opt --remove-dead-values %s | FileCheck %s +#device = #tt.device (0, d0, d1)>, l1Map = (d0, d1)[s0, s1] -> (0, d0 floordiv s0, d1 floordiv s1, (d0 mod s0) * s1 + d1 mod s1), dramMap = (d0, d1)[s0, s1] -> (0, 0, ((((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 8192) mod 12, (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 98304 + (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) mod 8192), meshShape = , chipIds = [0]> +#dram = #ttnn.buffer_type +#system_desc = #tt.system_desc<[{role = host, target_triple = "x86_64-pc-linux-gnu"}], [{arch = , grid = 8x8, l1_size = 1499136, num_dram_channels = 12, dram_channel_size = 1073741824, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32, l1_unreserved_base = 1024, erisc_l1_unreserved_base = 1024, dram_unreserved_base = 1024, dram_unreserved_end = 1073741824, physical_cores = {worker = [ 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 1x0, 1x1, 1x2, 1x3, 1x4, 1x5, 1x6, 1x7, 2x0, 2x1, 2x2, 2x3, 2x4, 2x5, 2x6, 2x7, 3x0, 3x1, 3x2, 3x3, 3x4, 3x5, 3x6, 3x7, 4x0, 4x1, 4x2, 4x3, 4x4, 4x5, 4x6, 4x7, 5x0, 5x1, 5x2, 5x3, 5x4, 5x5, 5x6, 5x7, 6x0, 6x1, 6x2, 6x3, 6x4, 6x5, 6x6, 6x7, 7x0, 7x1, 7x2, 7x3, 7x4, 7x5, 7x6, 7x7] dram = [ 8x0, 9x0, 10x0, 8x1, 9x1, 10x1, 8x2, 9x2, 10x2, 8x3, 9x3, 10x3]}, supported_data_types = [, , , , , , , , , , , ], supported_tile_sizes = [ 4x16, 16x16, 32x16, 4x32, 16x32, 32x32], num_cbs = 32}], [0], [3 : i32], [ 0x0x0x0]> +#system_memory = #ttnn.buffer_type +#ttnn_layout = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xf32, #system_memory>> +#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x4x!tt.tile<32x32, f32>, #dram>, > +#ttnn_layout2 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xf32, #dram>, > +module attributes {tt.device = #device, tt.system_desc = #system_desc} { + func.func @forward(%arg0: tensor<64x128xf32, #ttnn_layout>, %arg1: tensor<64x128xf32, #ttnn_layout>) -> tensor<64x128xf32, #ttnn_layout> { + %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> + %1 = "ttnn.to_layout"(%arg0) <{layout = #ttnn.layout}> : (tensor<64x128xf32, #ttnn_layout>) -> tensor<64x128xf32, #ttnn_layout1> + %2 = "ttnn.to_device"(%1, %0) <{memory_config = #ttnn.memory_config<#dram, <<2x4>>, >}> : (tensor<64x128xf32, #ttnn_layout1>, !tt.device<#device>) -> tensor<64x128xf32, #ttnn_layout1> + "ttnn.deallocate"(%1) <{force = false}> : (tensor<64x128xf32, #ttnn_layout1>) -> () + %3 = "ttnn.to_layout"(%arg1) <{layout = #ttnn.layout}> : (tensor<64x128xf32, #ttnn_layout>) -> tensor<64x128xf32, #ttnn_layout1> + %4 = "ttnn.to_device"(%3, %0) <{memory_config = #ttnn.memory_config<#dram, <<2x4>>, >}> : (tensor<64x128xf32, #ttnn_layout1>, !tt.device<#device>) -> tensor<64x128xf32, #ttnn_layout1> + "ttnn.deallocate"(%3) <{force = false}> : (tensor<64x128xf32, #ttnn_layout1>) -> () + %5 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<64x128>>, >, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xf32, #ttnn_layout2> + // CHECK: %[[C:.*]] = "ttnn.multiply"[[C:.*]] + %6 = "ttnn.multiply"(%2, %4, %5) <{operandSegmentSizes = array}> : (tensor<64x128xf32, #ttnn_layout1>, tensor<64x128xf32, #ttnn_layout1>, tensor<64x128xf32, #ttnn_layout2>) -> tensor<64x128xf32, #ttnn_layout2> + "ttnn.deallocate"(%4) <{force = false}> : (tensor<64x128xf32, #ttnn_layout1>) -> () + "ttnn.deallocate"(%2) <{force = false}> : (tensor<64x128xf32, #ttnn_layout1>) -> () + %7 = "ttnn.to_layout"(%arg0) <{layout = #ttnn.layout}> : (tensor<64x128xf32, #ttnn_layout>) -> tensor<64x128xf32, #ttnn_layout1> + %8 = "ttnn.to_device"(%7, %0) <{memory_config = #ttnn.memory_config<#dram, <<2x4>>, >}> : (tensor<64x128xf32, #ttnn_layout1>, !tt.device<#device>) -> tensor<64x128xf32, #ttnn_layout1> + "ttnn.deallocate"(%7) <{force = false}> : (tensor<64x128xf32, #ttnn_layout1>) -> () + %9 = "ttnn.to_layout"(%arg1) <{layout = #ttnn.layout}> : (tensor<64x128xf32, #ttnn_layout>) -> tensor<64x128xf32, #ttnn_layout1> + %10 = "ttnn.to_device"(%9, %0) <{memory_config = #ttnn.memory_config<#dram, <<2x4>>, >}> : (tensor<64x128xf32, #ttnn_layout1>, !tt.device<#device>) -> tensor<64x128xf32, #ttnn_layout1> + "ttnn.deallocate"(%9) <{force = false}> : (tensor<64x128xf32, #ttnn_layout1>) -> () + %11 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<64x128>>, >, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xf32, #ttnn_layout2> + // CHECK-NOT: %[[C:.*]] = "ttnn.add"[[C:.*]] + %12 = "ttnn.add"(%8, %10, %11) <{operandSegmentSizes = array}> : (tensor<64x128xf32, #ttnn_layout1>, tensor<64x128xf32, #ttnn_layout1>, tensor<64x128xf32, #ttnn_layout2>) -> tensor<64x128xf32, #ttnn_layout2> + "ttnn.deallocate"(%11) <{force = false}> : (tensor<64x128xf32, #ttnn_layout2>) -> () + "ttnn.deallocate"(%10) <{force = false}> : (tensor<64x128xf32, #ttnn_layout1>) -> () + "ttnn.deallocate"(%8) <{force = false}> : (tensor<64x128xf32, #ttnn_layout1>) -> () + %13 = "ttnn.to_layout"(%arg0) <{layout = #ttnn.layout}> : (tensor<64x128xf32, #ttnn_layout>) -> tensor<64x128xf32, #ttnn_layout1> + %14 = "ttnn.to_device"(%13, %0) <{memory_config = #ttnn.memory_config<#dram, <<2x4>>, >}> : (tensor<64x128xf32, #ttnn_layout1>, !tt.device<#device>) -> tensor<64x128xf32, #ttnn_layout1> + "ttnn.deallocate"(%13) <{force = false}> : (tensor<64x128xf32, #ttnn_layout1>) -> () + %15 = "ttnn.to_layout"(%arg1) <{layout = #ttnn.layout}> : (tensor<64x128xf32, #ttnn_layout>) -> tensor<64x128xf32, #ttnn_layout1> + %16 = "ttnn.to_device"(%15, %0) <{memory_config = #ttnn.memory_config<#dram, <<2x4>>, >}> : (tensor<64x128xf32, #ttnn_layout1>, !tt.device<#device>) -> tensor<64x128xf32, #ttnn_layout1> + "ttnn.deallocate"(%15) <{force = false}> : (tensor<64x128xf32, #ttnn_layout1>) -> () + %17 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<64x128>>, >, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xf32, #ttnn_layout2> + // CHECK-NOT: %[[C:.*]] = "ttnn.subtract"[[C:.*]] + %18 = "ttnn.subtract"(%14, %16, %17) <{operandSegmentSizes = array}> : (tensor<64x128xf32, #ttnn_layout1>, tensor<64x128xf32, #ttnn_layout1>, tensor<64x128xf32, #ttnn_layout2>) -> tensor<64x128xf32, #ttnn_layout2> + "ttnn.deallocate"(%17) <{force = false}> : (tensor<64x128xf32, #ttnn_layout2>) -> () + "ttnn.deallocate"(%16) <{force = false}> : (tensor<64x128xf32, #ttnn_layout1>) -> () + "ttnn.deallocate"(%14) <{force = false}> : (tensor<64x128xf32, #ttnn_layout1>) -> () + %19 = "ttnn.to_layout"(%arg0) <{layout = #ttnn.layout}> : (tensor<64x128xf32, #ttnn_layout>) -> tensor<64x128xf32, #ttnn_layout1> + %20 = "ttnn.to_device"(%19, %0) <{memory_config = #ttnn.memory_config<#dram, <<2x4>>, >}> : (tensor<64x128xf32, #ttnn_layout1>, !tt.device<#device>) -> tensor<64x128xf32, #ttnn_layout1> + "ttnn.deallocate"(%19) <{force = false}> : (tensor<64x128xf32, #ttnn_layout1>) -> () + %21 = "ttnn.to_layout"(%arg1) <{layout = #ttnn.layout}> : (tensor<64x128xf32, #ttnn_layout>) -> tensor<64x128xf32, #ttnn_layout1> + %22 = "ttnn.to_device"(%21, %0) <{memory_config = #ttnn.memory_config<#dram, <<2x4>>, >}> : (tensor<64x128xf32, #ttnn_layout1>, !tt.device<#device>) -> tensor<64x128xf32, #ttnn_layout1> + "ttnn.deallocate"(%21) <{force = false}> : (tensor<64x128xf32, #ttnn_layout1>) -> () + %23 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<64x128>>, >, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xf32, #ttnn_layout2> + // CHECK-NOT: %[[C:.*]] = "ttnn.div"[[C:.*]] + %24 = "ttnn.div"(%20, %22, %23) <{operandSegmentSizes = array}> : (tensor<64x128xf32, #ttnn_layout1>, tensor<64x128xf32, #ttnn_layout1>, tensor<64x128xf32, #ttnn_layout2>) -> tensor<64x128xf32, #ttnn_layout2> + "ttnn.deallocate"(%23) <{force = false}> : (tensor<64x128xf32, #ttnn_layout2>) -> () + "ttnn.deallocate"(%22) <{force = false}> : (tensor<64x128xf32, #ttnn_layout1>) -> () + "ttnn.deallocate"(%20) <{force = false}> : (tensor<64x128xf32, #ttnn_layout1>) -> () + %25 = "ttnn.to_layout"(%arg0) <{layout = #ttnn.layout}> : (tensor<64x128xf32, #ttnn_layout>) -> tensor<64x128xf32, #ttnn_layout1> + %26 = "ttnn.to_device"(%25, %0) <{memory_config = #ttnn.memory_config<#dram, <<2x4>>, >}> : (tensor<64x128xf32, #ttnn_layout1>, !tt.device<#device>) -> tensor<64x128xf32, #ttnn_layout1> + "ttnn.deallocate"(%25) <{force = false}> : (tensor<64x128xf32, #ttnn_layout1>) -> () + %27 = "ttnn.to_layout"(%arg1) <{layout = #ttnn.layout}> : (tensor<64x128xf32, #ttnn_layout>) -> tensor<64x128xf32, #ttnn_layout1> + %28 = "ttnn.to_device"(%27, %0) <{memory_config = #ttnn.memory_config<#dram, <<2x4>>, >}> : (tensor<64x128xf32, #ttnn_layout1>, !tt.device<#device>) -> tensor<64x128xf32, #ttnn_layout1> + "ttnn.deallocate"(%27) <{force = false}> : (tensor<64x128xf32, #ttnn_layout1>) -> () + %29 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<64x128>>, >, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xf32, #ttnn_layout2> + // CHECK-NOT: %[[C:.*]] = "ttnn.eq"[[C:.*]] + %30 = "ttnn.eq"(%26, %28, %29) <{operandSegmentSizes = array}> : (tensor<64x128xf32, #ttnn_layout1>, tensor<64x128xf32, #ttnn_layout1>, tensor<64x128xf32, #ttnn_layout2>) -> tensor<64x128xf32, #ttnn_layout2> + "ttnn.deallocate"(%29) <{force = false}> : (tensor<64x128xf32, #ttnn_layout2>) -> () + "ttnn.deallocate"(%28) <{force = false}> : (tensor<64x128xf32, #ttnn_layout1>) -> () + "ttnn.deallocate"(%26) <{force = false}> : (tensor<64x128xf32, #ttnn_layout1>) -> () + %31 = "ttnn.from_device"(%6) : (tensor<64x128xf32, #ttnn_layout2>) -> tensor<64x128xf32, #ttnn_layout> + "ttnn.deallocate"(%5) <{force = false}> : (tensor<64x128xf32, #ttnn_layout2>) -> () + %32 = "ttnn.to_layout"(%31) <{layout = #ttnn.layout}> : (tensor<64x128xf32, #ttnn_layout>) -> tensor<64x128xf32, #ttnn_layout> + "ttnn.deallocate"(%31) <{force = false}> : (tensor<64x128xf32, #ttnn_layout>) -> () + return %32 : tensor<64x128xf32, #ttnn_layout> + } +} diff --git a/test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline_custom_opt.mlir b/test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline_custom_opt.mlir index d1e846bd6a..112a941a81 100644 --- a/test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline_custom_opt.mlir +++ b/test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline_custom_opt.mlir @@ -2,7 +2,7 @@ #any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { - // CHECK: #[[LAYOUT_1:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xf32, #dram>, interleaved> + // CHECK: #[[LAYOUT_1:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xf32, #dram>, > // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.multiply"[[C:.*]] -> tensor<64x128xf32, #[[LAYOUT_1:.*]]> diff --git a/test/ttmlir/Runtime/TTNN/runtime_stitching/eltwise_binary_op_chain.mlir b/test/ttmlir/Runtime/TTNN/runtime_stitching/eltwise_binary_op_chain.mlir new file mode 100644 index 0000000000..35b4d90634 --- /dev/null +++ b/test/ttmlir/Runtime/TTNN/runtime_stitching/eltwise_binary_op_chain.mlir @@ -0,0 +1,49 @@ +// RUN: ttmlir-opt --ttir-load-system-desc="path=%system_desc_path%" %s > %t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +// TODO: this is a workaround for compiler assuming input tensors are always on host. The ideal is to directly compile ttir graphs. +#device = #tt.device (0, d0, d1)>, l1Map = (d0, d1)[s0, s1] -> (0, d0 floordiv s0, d1 floordiv s1, (d0 mod s0) * s1 + d1 mod s1), dramMap = (d0, d1)[s0, s1] -> (0, 0, ((((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 8192) mod 12, (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 98304 + (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) mod 8192), meshShape = , chipIds = [0]> +#system_memory = #ttnn.buffer_type +#dram = #ttnn.buffer_type +#ttnn_layout = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xbf16, #system_memory>> +#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #dram>, > +#ttnn_layout2 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xbf16, #dram>, > + +module attributes {tt.device = #device} { + func.func @add(%arg0: tensor<64x128xbf16, #ttnn_layout1>, %arg1: tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout> { + %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> + %1 = "ttnn.to_layout"(%arg0) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1> + %2 = "ttnn.to_layout"(%arg1) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1> + %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<64x128>>, >, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout2> + %4 = "ttnn.add"(%1, %2, %3) <{operandSegmentSizes = array}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout2> + %5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout> + %6 = "ttnn.to_layout"(%5) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout>) -> tensor<64x128xbf16, #ttnn_layout> + return %6 : tensor<64x128xbf16, #ttnn_layout> + } +} + +module attributes {tt.device = #device} { + func.func @multiply(%arg0: tensor<64x128xbf16, #ttnn_layout1>, %arg1: tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout> { + %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> + %1 = "ttnn.to_layout"(%arg0) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1> + %2 = "ttnn.to_layout"(%arg1) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1> + %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<64x128>>, >, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout2> + %4 = "ttnn.multiply"(%1, %2, %3) <{operandSegmentSizes = array}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout2> + %5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout> + %6 = "ttnn.to_layout"(%5) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout>) -> tensor<64x128xbf16, #ttnn_layout> + return %6 : tensor<64x128xbf16, #ttnn_layout> + } +} + +module attributes {tt.device = #device} { + func.func @subtract(%arg0: tensor<64x128xbf16, #ttnn_layout1>, %arg1: tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout> { + %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> + %1 = "ttnn.to_layout"(%arg0) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1> + %2 = "ttnn.to_layout"(%arg1) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1> + %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<64x128>>, >, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout2> + %4 = "ttnn.subtract"(%1, %2, %3) <{operandSegmentSizes = array}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout2> + %5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout> + %6 = "ttnn.to_layout"(%5) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout>) -> tensor<64x128xbf16, #ttnn_layout> + return %6 : tensor<64x128xbf16, #ttnn_layout> + } +} diff --git a/test/ttmlir/Silicon/StableHLO/Constant/constant_bf16.mlir b/test/ttmlir/Silicon/StableHLO/Constant/constant_bf16.mlir index 1a24e07595..636ea27167 100644 --- a/test/ttmlir/Silicon/StableHLO/Constant/constant_bf16.mlir +++ b/test/ttmlir/Silicon/StableHLO/Constant/constant_bf16.mlir @@ -18,7 +18,7 @@ module @jit_constant attributes {} { func.func public @test_bfloat16_scalar_empty() -> tensor { // CHECK-LABEL: func.func public @test_bfloat16_scalar_empty - // CHECK: ttnn.empty + // CHECK: ttnn.full // CHECK-SAME: -> tensor<1xbf16 %0 = stablehlo.constant dense<0.0> : tensor return %0 : tensor @@ -26,7 +26,7 @@ module @jit_constant attributes {} { func.func public @test_bfloat16_empty() -> tensor<64x128xbf16> { // CHECK-LABEL: func.func public @test_bfloat16_empty - // CHECK: ttnn.empty + // CHECK: ttnn.full // CHECK-SAME: -> tensor<64x128xbf16 %0 = stablehlo.constant dense<0.0> : tensor<64x128xbf16> return %0 : tensor<64x128xbf16> diff --git a/test/ttmlir/Silicon/StableHLO/Constant/constant_bool.mlir b/test/ttmlir/Silicon/StableHLO/Constant/constant_bool.mlir index 0c51294e3a..6486ff99c6 100644 --- a/test/ttmlir/Silicon/StableHLO/Constant/constant_bool.mlir +++ b/test/ttmlir/Silicon/StableHLO/Constant/constant_bool.mlir @@ -18,7 +18,7 @@ module @jit_constant attributes {} { func.func public @test_boolean_scalar_empty() -> tensor { // CHECK-LABEL: func.func public @test_boolean_scalar_empty - // CHECK: ttnn.empty + // CHECK: ttnn.full // CHECK-SAME: -> tensor<1xbf16 %0 = stablehlo.constant dense : tensor return %0 : tensor @@ -26,7 +26,7 @@ module @jit_constant attributes {} { func.func public @test_boolean_empty() -> tensor<64x128xi1> { // CHECK-LABEL: func.func public @test_boolean_empty - // CHECK: ttnn.empty + // CHECK: ttnn.full // CHECK-SAME: -> tensor<64x128xbf16 %0 = stablehlo.constant dense : tensor<64x128xi1> return %0 : tensor<64x128xi1> diff --git a/test/ttmlir/Silicon/StableHLO/Constant/constant_f32.mlir b/test/ttmlir/Silicon/StableHLO/Constant/constant_f32.mlir index 5a29facc78..3fecd90fb0 100644 --- a/test/ttmlir/Silicon/StableHLO/Constant/constant_f32.mlir +++ b/test/ttmlir/Silicon/StableHLO/Constant/constant_f32.mlir @@ -18,7 +18,7 @@ module @jit_constant attributes {} { func.func public @test_float_scalar_empty() -> tensor { // CHECK-LABEL: func.func public @test_float_scalar_empty - // CHECK: ttnn.empty + // CHECK: ttnn.full // CHECK-SAME: -> tensor<1xf32 %0 = stablehlo.constant dense<0.0> : tensor return %0 : tensor @@ -26,7 +26,7 @@ module @jit_constant attributes {} { func.func public @test_float_empty() -> tensor<64x128xf32> { // CHECK-LABEL: func.func public @test_float_empty - // CHECK: ttnn.empty + // CHECK: ttnn.full // CHECK-SAME: -> tensor<64x128xf32 %0 = stablehlo.constant dense<0.0> : tensor<64x128xf32> return %0 : tensor<64x128xf32> diff --git a/test/ttmlir/Silicon/StableHLO/Constant/constant_f64.mlir b/test/ttmlir/Silicon/StableHLO/Constant/constant_f64.mlir index cc39178165..c286745a09 100644 --- a/test/ttmlir/Silicon/StableHLO/Constant/constant_f64.mlir +++ b/test/ttmlir/Silicon/StableHLO/Constant/constant_f64.mlir @@ -18,7 +18,7 @@ module @jit_constant attributes {} { func.func public @test_f64_scalar_empty() -> tensor { // CHECK-LABEL: func.func public @test_f64_scalar_empty - // CHECK: ttnn.empty + // CHECK: ttnn.full // CHECK-SAME: -> tensor<1xf32 %0 = stablehlo.constant dense<0.0> : tensor return %0 : tensor @@ -26,7 +26,7 @@ module @jit_constant attributes {} { func.func public @test_f64_empty() -> tensor<64x128xf64> { // CHECK-LABEL: func.func public @test_f64_empty - // CHECK: ttnn.empty + // CHECK: ttnn.full // CHECK-SAME: -> tensor<64x128xf32 %0 = stablehlo.constant dense<0.0> : tensor<64x128xf64> return %0 : tensor<64x128xf64> diff --git a/test/ttmlir/Silicon/StableHLO/Constant/constant_i16.mlir b/test/ttmlir/Silicon/StableHLO/Constant/constant_i16.mlir index 8f4dc247f1..792cdc9d0f 100644 --- a/test/ttmlir/Silicon/StableHLO/Constant/constant_i16.mlir +++ b/test/ttmlir/Silicon/StableHLO/Constant/constant_i16.mlir @@ -18,7 +18,7 @@ module @jit_constant attributes {} { func.func public @test_int16_scalar_empty() -> tensor { // CHECK-LABEL: func.func public @test_int16_scalar_empty - // CHECK: ttnn.empty + // CHECK: ttnn.full // CHECK-SAME: -> tensor<1xi16 %0 = stablehlo.constant dense<0> : tensor return %0 : tensor @@ -26,7 +26,7 @@ module @jit_constant attributes {} { func.func public @test_int16_empty() -> tensor<64x128xi16> { // CHECK-LABEL: func.func public @test_int16_empty - // CHECK: ttnn.empty + // CHECK: ttnn.full // CHECK-SAME: -> tensor<64x128xi16 %0 = stablehlo.constant dense<0> : tensor<64x128xi16> return %0 : tensor<64x128xi16> diff --git a/test/ttmlir/Silicon/StableHLO/Constant/constant_i32.mlir b/test/ttmlir/Silicon/StableHLO/Constant/constant_i32.mlir index b5c73da0b9..813b08bcf8 100644 --- a/test/ttmlir/Silicon/StableHLO/Constant/constant_i32.mlir +++ b/test/ttmlir/Silicon/StableHLO/Constant/constant_i32.mlir @@ -18,7 +18,7 @@ module @jit_constant attributes {} { func.func public @test_int32_scalar_empty() -> tensor { // CHECK-LABEL: func.func public @test_int32_scalar_empty - // CHECK: ttnn.empty + // CHECK: ttnn.full // CHECK-SAME: -> tensor<1xi32 %0 = stablehlo.constant dense<0> : tensor return %0 : tensor @@ -26,7 +26,7 @@ module @jit_constant attributes {} { func.func public @test_int32_empty() -> tensor<64x128xi32> { // CHECK-LABEL: func.func public @test_int32_empty - // CHECK: ttnn.empty + // CHECK: ttnn.full // CHECK-SAME: -> tensor<64x128xi32 %0 = stablehlo.constant dense<0> : tensor<64x128xi32> return %0 : tensor<64x128xi32> diff --git a/test/ttmlir/Silicon/StableHLO/Constant/constant_i64.mlir b/test/ttmlir/Silicon/StableHLO/Constant/constant_i64.mlir index bf4a3e8cb2..0bcae491b5 100644 --- a/test/ttmlir/Silicon/StableHLO/Constant/constant_i64.mlir +++ b/test/ttmlir/Silicon/StableHLO/Constant/constant_i64.mlir @@ -18,7 +18,7 @@ module @jit_constant attributes {} { func.func public @test_int64_scalar_empty() -> tensor { // CHECK-LABEL: func.func public @test_int64_scalar_empty - // CHECK: ttnn.empty + // CHECK: ttnn.full // CHECK-SAME: -> tensor<1xi32 %0 = stablehlo.constant dense<0> : tensor return %0 : tensor @@ -26,7 +26,7 @@ module @jit_constant attributes {} { func.func public @test_int64_empty() -> tensor<64x128xi64> { // CHECK-LABEL: func.func public @test_int64_empty - // CHECK: ttnn.empty + // CHECK: ttnn.full // CHECK-SAME: -> tensor<64x128xi32 %0 = stablehlo.constant dense<0> : tensor<64x128xi64> return %0 : tensor<64x128xi64> diff --git a/test/ttmlir/Silicon/StableHLO/Iota/simple_device_dynamic_iota_dim2.mlir b/test/ttmlir/Silicon/StableHLO/Iota/simple_device_dynamic_iota_dim2.mlir new file mode 100644 index 0000000000..d911ec6fe2 --- /dev/null +++ b/test/ttmlir/Silicon/StableHLO/Iota/simple_device_dynamic_iota_dim2.mlir @@ -0,0 +1,15 @@ +// REQUIRES: stablehlo +// RUN: rm -rf %t.ttnn +// RUN: rm -rf %t.mlir +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +// RUN: FileCheck --input-file=%t.mlir %s +module attributes {} { + func.func @forward(%arg0: tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> { + // CHECK: ttnn.arange + %0 = "stablehlo.iota"() {iota_dimension = 2: i64} : () -> tensor<1x1x32x128xbf16> + %2 = "stablehlo.multiply"(%arg0, %0) : (tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> + return %2 : tensor<1x1x32x128xbf16> + } +} diff --git a/test/ttmlir/Silicon/StableHLO/Iota/simple_device_dynamic_iota_dim3.mlir b/test/ttmlir/Silicon/StableHLO/Iota/simple_device_dynamic_iota_dim3.mlir new file mode 100644 index 0000000000..01aa0e91b3 --- /dev/null +++ b/test/ttmlir/Silicon/StableHLO/Iota/simple_device_dynamic_iota_dim3.mlir @@ -0,0 +1,16 @@ +// REQUIRES: stablehlo +// RUN: rm -rf %t.ttnn +// RUN: rm -rf %t.mlir +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +// RUN: FileCheck --input-file=%t.mlir %s +module attributes {} { + func.func @forward(%arg0: tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> { + %output_shape = stablehlo.constant dense<[1, 1, 32, 128]> : tensor<4xi64> + // CHECK: ttnn.arange + %0 = "stablehlo.dynamic_iota"(%output_shape) {iota_dimension = 3: i64} : (tensor<4xi64>) -> tensor<1x1x32x128xbf16> + %2 = "stablehlo.multiply"(%arg0, %0) : (tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> + return %2 : tensor<1x1x32x128xbf16> + } +} diff --git a/test/ttmlir/Silicon/StableHLO/Iota/simple_device_iota_dim2.mlir b/test/ttmlir/Silicon/StableHLO/Iota/simple_device_iota_dim2.mlir new file mode 100644 index 0000000000..d911ec6fe2 --- /dev/null +++ b/test/ttmlir/Silicon/StableHLO/Iota/simple_device_iota_dim2.mlir @@ -0,0 +1,15 @@ +// REQUIRES: stablehlo +// RUN: rm -rf %t.ttnn +// RUN: rm -rf %t.mlir +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +// RUN: FileCheck --input-file=%t.mlir %s +module attributes {} { + func.func @forward(%arg0: tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> { + // CHECK: ttnn.arange + %0 = "stablehlo.iota"() {iota_dimension = 2: i64} : () -> tensor<1x1x32x128xbf16> + %2 = "stablehlo.multiply"(%arg0, %0) : (tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> + return %2 : tensor<1x1x32x128xbf16> + } +} diff --git a/test/ttmlir/Silicon/StableHLO/Iota/simple_device_iota_dim3.mlir b/test/ttmlir/Silicon/StableHLO/Iota/simple_device_iota_dim3.mlir new file mode 100644 index 0000000000..a231432abc --- /dev/null +++ b/test/ttmlir/Silicon/StableHLO/Iota/simple_device_iota_dim3.mlir @@ -0,0 +1,15 @@ +// REQUIRES: stablehlo +// RUN: rm -rf %t.ttnn +// RUN: rm -rf %t.mlir +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +// RUN: FileCheck --input-file=%t.mlir %s +module attributes {} { + func.func @forward(%arg0: tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> { + // CHECK: ttnn.arange + %0 = "stablehlo.iota"() {iota_dimension = 3: i64} : () -> tensor<1x1x32x128xbf16> + %2 = "stablehlo.multiply"(%arg0, %0) : (tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> + return %2 : tensor<1x1x32x128xbf16> + } +} diff --git a/test/ttmlir/Silicon/StableHLO/Unary/isfinite_op.mlir b/test/ttmlir/Silicon/StableHLO/Unary/isfinite_op.mlir index 04b9f1fefb..35682c8c0b 100644 --- a/test/ttmlir/Silicon/StableHLO/Unary/isfinite_op.mlir +++ b/test/ttmlir/Silicon/StableHLO/Unary/isfinite_op.mlir @@ -7,14 +7,14 @@ // RUN: FileCheck --input-file=%t.mlir %s module @jit_eltwise_isfinite attributes {} { - func.func public @test_isfinite(%arg0: tensor<64x128xf32>) -> tensor<64x128xi1> { + func.func public @test_isfinite(%arg0: tensor<64x128xbf16>) -> tensor<64x128xi1> { // CHECK-LABEL: func.func public @test_isfinite // CHECK: ttnn.empty // CHECK: ttnn.isfinite - // CHECK-SAME: tensor<64x128xf32, + // CHECK-SAME: tensor<64x128xbf16, // CHECK-SAME: tensor<64x128xbf16, // CHECK-SAME: -> tensor<64x128xbf16, - %0 = stablehlo.is_finite %arg0 : (tensor<64x128xf32>) -> tensor<64x128xi1> + %0 = stablehlo.is_finite %arg0 : (tensor<64x128xbf16>) -> tensor<64x128xi1> return %0 : tensor<64x128xi1> } } diff --git a/test/ttmlir/Silicon/StableHLO/dot_general_op.mlir b/test/ttmlir/Silicon/StableHLO/dot_general/dot_general_op_2d.mlir similarity index 82% rename from test/ttmlir/Silicon/StableHLO/dot_general_op.mlir rename to test/ttmlir/Silicon/StableHLO/dot_general/dot_general_op_2d.mlir index 57a0bdcd8d..179f112b49 100644 --- a/test/ttmlir/Silicon/StableHLO/dot_general_op.mlir +++ b/test/ttmlir/Silicon/StableHLO/dot_general/dot_general_op_2d.mlir @@ -6,8 +6,8 @@ // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn // RUN: FileCheck --input-file=%t.mlir %s -module @jit_dot_general attributes {} { - func.func public @test_dot_general(%arg0 : tensor<16x32xf32>, %arg1 : tensor<32x8xf32>) -> tensor<16x8xf32> { +module @jit_dot_general_2d attributes {} { + func.func public @test_dot_general_2d(%arg0 : tensor<16x32xf32>, %arg1 : tensor<32x8xf32>) -> tensor<16x8xf32> { // CHECK-LABEL: func.func public @test_dot_general // CHECK: ttnn.empty // CHECK: ttnn.matmul diff --git a/test/ttmlir/Silicon/StableHLO/dot_general/dot_general_op_batch_matmul.mlir b/test/ttmlir/Silicon/StableHLO/dot_general/dot_general_op_batch_matmul.mlir new file mode 100644 index 0000000000..f23ece73ff --- /dev/null +++ b/test/ttmlir/Silicon/StableHLO/dot_general/dot_general_op_batch_matmul.mlir @@ -0,0 +1,21 @@ +// REQUIRES: stablehlo +// RUN: rm -rf %t.ttnn +// RUN: rm -rf %t.mlir +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +// RUN: FileCheck --input-file=%t.mlir %s + +module @jit_dot_general_4d attributes {} { + func.func public @test_dot_general_4d(%arg0 : tensor<1x128x16x32xf32>, %arg1 : tensor<1x128x32x8xf32>) -> tensor<1x128x16x8xf32> { + // CHECK-LABEL: func.func public @test_dot_general + // CHECK: ttnn.empty + // CHECK: ttnn.matmul + // CHECK-SAME: tensor<1x128x16x32xf32, + // CHECK-SAME: tensor<1x128x32x8xf32, + // CHECK-SAME: tensor<1x128x16x8xf32, + // CHECK-SAME: -> tensor<1x128x16x8xf32 + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [2] : (tensor<1x128x16x32xf32>, tensor<1x128x32x8xf32>) -> tensor<1x128x16x8xf32> + return %0 : tensor<1x128x16x8xf32> + } +} diff --git a/test/ttmlir/Silicon/StableHLO/gather_op.mlir b/test/ttmlir/Silicon/StableHLO/gather_op.mlir new file mode 100644 index 0000000000..9a4a90b1b6 --- /dev/null +++ b/test/ttmlir/Silicon/StableHLO/gather_op.mlir @@ -0,0 +1,45 @@ +// REQUIRES: stablehlo +// RUN: rm -rf %t.ttnn +// RUN: rm -rf %t.mlir +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +// RU1N: FileCheck --input-file=%t.mlir %s + +module @jit_gather attributes {} { + func.func public @test_gather_0(%operand: tensor<32000x1024xbf16>, %start_indices: tensor<1x32xi32>) -> tensor<1x32x1024xbf16> { + // CHECK-LABEL: func.func public @test_gather_0 + // CHECK: ttnn.empty + // CHECK: ttnn.embedding + // CHECK-SAME: tensor<1x32xi32, + // CHECK-SAME: tensor<1x32x1024xbf16 + // CHECK-SAME: tensor<32000x1024xbf16, + // CHECK-SAME: -> tensor<1x32x1024xbf16 + %0 = "stablehlo.gather"(%operand, %start_indices) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<32000x1024xbf16>, tensor<1x32xi32>) -> tensor<1x32x1024xbf16> + return %0 : tensor<1x32x1024xbf16> + } + + func.func public @test_gather_1(%operand: tensor<51864x384xbf16>, %start_indices: tensor<1x2xi32>) -> tensor<1x2x384xbf16> { + // CHECK-LABEL: func.func public @test_gather_1 + // CHECK: ttnn.empty + // CHECK: ttnn.embedding + // CHECK-SAME: tensor<1x2xi32, + // CHECK-SAME: tensor<1x2x384xbf16 + // CHECK-SAME: tensor<51864x384xbf16, + // CHECK-SAME: -> tensor<1x2x384xbf16 + %0 = "stablehlo.gather"(%operand, %start_indices) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<51864x384xbf16>, tensor<1x2xi32>) -> tensor<1x2x384xbf16> + return %0 : tensor<1x2x384xbf16> + } + + func.func public @test_gather_2(%operand: tensor<32128x512xbf16>, %start_indices: tensor<1x15xi64>) -> tensor<1x15x512xbf16> { + // CHECK-LABEL: func.func public @test_gather_2 + // CHECK: ttnn.empty + // CHECK: ttnn.embedding + // CHECK-SAME: tensor<1x16xi32, + // CHECK-SAME: tensor<1x15x512xbf16 + // CHECK-SAME: tensor<32128x512xbf16, + // CHECK-SAME: -> tensor<1x15x512xbf16 + %0 = "stablehlo.gather"(%operand, %start_indices) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<32128x512xbf16>, tensor<1x15xi64>) -> tensor<1x15x512xbf16> + return %0 : tensor<1x15x512xbf16> + } +} diff --git a/test/ttmlir/Silicon/StableHLO/select_op.mlir b/test/ttmlir/Silicon/StableHLO/select_op.mlir index 23b7182ce0..1cdc5e9d05 100644 --- a/test/ttmlir/Silicon/StableHLO/select_op.mlir +++ b/test/ttmlir/Silicon/StableHLO/select_op.mlir @@ -6,23 +6,23 @@ // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn module @jit_eltwise_select attributes {} { - func.func public @test_select(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { + func.func public @test_select(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<64x128xbf16> { // CHECK-LABEL: func.func public @test_select // CHECK: tensor.empty // CHECK: [[EQ:{{0-9}}+]] = "ttnn.eq" - // CHECK-SAME: tensor<64x128xf32 - // CHECK-SAME: tensor<64x128xf32 + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<64x128xbf16 // CHECK-SAME: tensor<64x128xbf16 // CHECK-SAME: -> tensor<64x128xbf16 - %0 = stablehlo.compare EQ, %arg0, %arg1 : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xi1> + %0 = stablehlo.compare EQ, %arg0, %arg1 : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xi1> // CHECK: ttnn.where // CHECK-SAME: [[EQ]] // CHECK-SAME: tensor<64x128xbf16 - // CHECK-SAME: tensor<64x128xf32 - // CHECK-SAME: tensor<64x128xf32 - // CHECK-SAME: tensor<64x128xf32 - // CHECK-SAME: -> tensor<64x128xf32 - %1 = stablehlo.select %0, %arg0, %arg1 : (tensor<64x128xi1>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - return %1 : tensor<64x128xf32> + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: -> tensor<64x128xbf16 + %1 = stablehlo.select %0, %arg0, %arg1 : (tensor<64x128xi1>, tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + return %1 : tensor<64x128xbf16> } } diff --git a/test/ttmlir/Silicon/TTMetal/simple_max.mlir b/test/ttmlir/Silicon/TTMetal/simple_max.mlir new file mode 100644 index 0000000000..92bdbe72c7 --- /dev/null +++ b/test/ttmlir/Silicon/TTMetal/simple_max.mlir @@ -0,0 +1,13 @@ +// RUN: ttmlir-opt --ttir-to-ttmetal-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttmetal-to-flatbuffer %t.mlir > %t.ttm + +#any_device = #tt.operand_constraint + +func.func @maximum(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { + // CHECK: %[[C:.*]] = "ttmetal.alloc"[[C:.*]] + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] + %1 = "ttir.maximum"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + return %1 : tensor<64x128xf32> +} diff --git a/test/ttmlir/Silicon/TTMetal/simple_reduce.mlir b/test/ttmlir/Silicon/TTMetal/simple_reduce.mlir index 1674ae0d32..cdde621c2a 100644 --- a/test/ttmlir/Silicon/TTMetal/simple_reduce.mlir +++ b/test/ttmlir/Silicon/TTMetal/simple_reduce.mlir @@ -1,8 +1,8 @@ // RUN: ttmlir-opt --ttir-to-ttmetal-backend-pipeline="system-desc-path=%system_desc_path%" %s | FileCheck %s #any_device = #tt.operand_constraint #l1_ = #tt.memory_space -#layout1 = #tt.layout<(d0, d1) -> (d0, d1), undef, <4x4>, memref<64x96xf32, #l1_>> -#layout2 = #tt.layout<(d0, d1) -> (d0, d1), undef, <4x1>, memref<64x32xf32, #l1_>> +#layout1 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <4x4>, memref<64x96xf32, #l1_>> +#layout2 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <4x1>, memref<64x32xf32, #l1_>> func.func @reduceW(%arg0: tensor<256x384xf32, #layout1>) -> tensor<256x32xf32, #layout2> { %0 = tensor.empty() : tensor<256x32xf32, #layout2> @@ -15,7 +15,7 @@ func.func @reduceW(%arg0: tensor<256x384xf32, #layout1>) -> tensor<256x32xf32, # return %1 : tensor<256x32xf32, #layout2> } -#layout3 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x4>, memref<32x96xf32, #l1_>> +#layout3 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x4>, memref<32x96xf32, #l1_>> func.func @reduceH(%arg0: tensor<256x384xf32, #layout1>) -> tensor<32x384xf32, #layout3> { %0 = tensor.empty() : tensor<32x384xf32, #layout3> // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] @@ -27,7 +27,7 @@ func.func @reduceH(%arg0: tensor<256x384xf32, #layout1>) -> tensor<32x384xf32, # return %1 : tensor<32x384xf32, #layout3> } -#layout4 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<32x32xf32, #l1_>> +#layout4 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<32x32xf32, #l1_>> func.func @reduceWH(%arg0: tensor<256x384xf32, #layout1>) -> tensor<32x32xf32, #layout4> { %0 = tensor.empty() : tensor<32x32xf32, #layout4> // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] diff --git a/test/ttmlir/Silicon/TTMetal/tiled_reblock.mlir b/test/ttmlir/Silicon/TTMetal/tiled_reblock.mlir index 64cf5f57a6..d7d3cea1dd 100644 --- a/test/ttmlir/Silicon/TTMetal/tiled_reblock.mlir +++ b/test/ttmlir/Silicon/TTMetal/tiled_reblock.mlir @@ -4,10 +4,10 @@ #l1_ = #tt.memory_space -#untilized = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>> -#tilized = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32 x 32, f32>, #l1_>> -#tilized2x2 = #tt.layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<1x2x!tt.tile<32 x 32, f32>, #l1_>> -#untilized2x2 = #tt.layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<32x64xf32, #l1_>> +#untilized = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>> +#tilized = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32 x 32, f32>, #l1_>> +#tilized2x2 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<1x2x!tt.tile<32 x 32, f32>, #l1_>> +#untilized2x2 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<32x64xf32, #l1_>> func.func @tilize_reblock_2D(%arg0: tensor<64x128xf32, #untilized>) -> tensor<64x128xf32, #untilized2x2> { // CHECK: %[[C:.*]] = "ttmetal.alloc"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32, #tilized> @@ -25,10 +25,10 @@ func.func @tilize_reblock_2D(%arg0: tensor<64x128xf32, #untilized>) -> tensor<64 } -#untilized4D = #tt.layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <1x1>, memref<384x128xf32, #l1_>> -#tilized4D = #tt.layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <1x1>, memref<12x4x!tt.tile<32 x 32, f32>, #l1_>> -#tilized4D_2x2 = #tt.layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <2x2>, memref<6x2x!tt.tile<32 x 32, f32>, #l1_>> -#untilized4D_2x2 = #tt.layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <2x2>, memref<192x64xf32, #l1_>> +#untilized4D = #tt.metal_layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <1x1>, memref<384x128xf32, #l1_>> +#tilized4D = #tt.metal_layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <1x1>, memref<12x4x!tt.tile<32 x 32, f32>, #l1_>> +#tilized4D_2x2 = #tt.metal_layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <2x2>, memref<6x2x!tt.tile<32 x 32, f32>, #l1_>> +#untilized4D_2x2 = #tt.metal_layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), undef, <2x2>, memref<192x64xf32, #l1_>> func.func @tilize_reblock_4D(%arg0: tensor<2x3x64x128xf32, #untilized4D>) -> tensor<2x3x64x128xf32, #untilized4D_2x2> { // CHECK: %[[C:.*]] = "ttmetal.alloc"[[C:.*]] %0 = tensor.empty() : tensor<2x3x64x128xf32, #tilized4D> @@ -48,10 +48,10 @@ func.func @tilize_reblock_4D(%arg0: tensor<2x3x64x128xf32, #untilized4D>) -> ten return %5 : tensor<2x3x64x128xf32, #untilized4D_2x2> } -#untilized_big = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<96x192xf32, #l1_>> -#tilized_big = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<3x6x!tt.tile<32 x 32, f32>, #l1_>> -#tilized_big_3x2 = #tt.layout<(d0, d1) -> (d0, d1), undef, <3x2>, memref<1x3x!tt.tile<32 x 32, f32>, #l1_>> -#tilized_big_3x6 = #tt.layout<(d0, d1) -> (d0, d1), undef, <3x6>, memref<1x1x!tt.tile<32 x 32, f32>, #l1_>> +#untilized_big = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<96x192xf32, #l1_>> +#tilized_big = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<3x6x!tt.tile<32 x 32, f32>, #l1_>> +#tilized_big_3x2 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <3x2>, memref<1x3x!tt.tile<32 x 32, f32>, #l1_>> +#tilized_big_3x6 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <3x6>, memref<1x1x!tt.tile<32 x 32, f32>, #l1_>> func.func @tilize_reblock_big(%arg0: tensor<96x192xf32, #untilized_big>) -> tensor<96x192xf32, #untilized_big> { // move to tilized 1x1 // CHECK: %[[C:.*]] = "ttmetal.alloc"[[C:.*]] diff --git a/test/ttmlir/Silicon/TTMetal/to_layout.mlir b/test/ttmlir/Silicon/TTMetal/to_layout.mlir index 015e651750..e5318c6c1d 100644 --- a/test/ttmlir/Silicon/TTMetal/to_layout.mlir +++ b/test/ttmlir/Silicon/TTMetal/to_layout.mlir @@ -5,8 +5,8 @@ #l1_ = #tt.memory_space #dram = #tt.memory_space -#layout = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<4x16xf32, #l1_>> -#layout1 = #tt.layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<2x8xf32, #l1_>> +#layout = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<4x16xf32, #l1_>> +#layout1 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<2x8xf32, #l1_>> func.func @simple(%arg0: tensor<4x16xf32, #layout>) -> tensor<4x16xf32, #layout1> { %0 = tensor.empty() : tensor<4x16xf32, #layout1> // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] @@ -14,8 +14,8 @@ func.func @simple(%arg0: tensor<4x16xf32, #layout>) -> tensor<4x16xf32, #layout1 return %1 : tensor<4x16xf32, #layout1> } -#untilized = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>> -#tilized = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32 x 32, f32>, #l1_>> +#untilized = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>> +#tilized = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32 x 32, f32>, #l1_>> func.func @tilize(%arg0: tensor<64x128xf32, #untilized>) -> tensor<64x128xf32, #untilized> { %0 = tensor.empty() : tensor<64x128xf32, #tilized> // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] @@ -26,11 +26,11 @@ func.func @tilize(%arg0: tensor<64x128xf32, #untilized>) -> tensor<64x128xf32, # return %3 : tensor<64x128xf32, #untilized> } -#untilized_dram = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<16x64xf32, #dram>> -#untilized_l1 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<16x64xf32, #l1_>> -#untilized2x2_dram = #tt.layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<8x32xf32, #dram>> -#untilized2x2_l1 = #tt.layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<8x32xf32, #l1_>> -#untilized1x4_l1 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x4>, memref<16x16xf32, #l1_>> +#untilized_dram = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<16x64xf32, #dram>> +#untilized_l1 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<16x64xf32, #l1_>> +#untilized2x2_dram = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<8x32xf32, #dram>> +#untilized2x2_l1 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <2x2>, memref<8x32xf32, #l1_>> +#untilized1x4_l1 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x4>, memref<16x16xf32, #l1_>> func.func @dram_to_l1(%arg0: tensor<16x64xf32, #untilized_dram>) -> tensor<16x64xf32, #untilized_l1> { %0 = tensor.empty() : tensor<16x64xf32, #untilized_l1> // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] diff --git a/test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim2.mlir b/test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim2.mlir new file mode 100644 index 0000000000..f3affc69d4 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim2.mlir @@ -0,0 +1,15 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +// UNSUPPORTED: true +// https://github.com/tenstorrent/tt-mlir/issues/1448 +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> { + // CHECK: %[[C:.*]] = "ttnn.arange"[[C:.*]] + %0 = "ttir.arange"() <{start = 0: si64, end = 64: si64, step = 2: si64, arange_dimension = 2: i64}> : () -> tensor<1x1x32x128xbf16> + %1 = tensor.empty() : tensor<1x1x32x128xbf16> + %2 = "ttir.multiply"(%arg0, %0, %1) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> + return %2 : tensor<1x1x32x128xbf16> + } +} diff --git a/test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim3.mlir b/test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim3.mlir new file mode 100644 index 0000000000..196e757096 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim3.mlir @@ -0,0 +1,13 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> { + // CHECK: %[[C:.*]] = "ttnn.arange"[[C:.*]] + %0 = "ttir.arange"() <{start = 0: si64, end = 128: si64, step = 1: si64, arange_dimension = 3: i64}> : () -> tensor<1x1x32x128xbf16> + %1 = tensor.empty() : tensor<1x1x32x128xbf16> + %2 = "ttir.multiply"(%arg0, %0, %1) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> + return %2 : tensor<1x1x32x128xbf16> + } +} diff --git a/test/ttmlir/Silicon/TTNN/emitc/simple_add.mlir b/test/ttmlir/Silicon/TTNN/emitc/simple_add.mlir index 84e424cbc8..33645730ab 100644 --- a/test/ttmlir/Silicon/TTNN/emitc/simple_add.mlir +++ b/test/ttmlir/Silicon/TTNN/emitc/simple_add.mlir @@ -1,7 +1,7 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint +#any_device = #tt.operand_constraint func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> { %0 = tensor.empty() : tensor<32x32xbf16> diff --git a/test/ttmlir/Silicon/TTNN/emitc/two_fns.mlir b/test/ttmlir/Silicon/TTNN/emitc/two_fns.mlir new file mode 100644 index 0000000000..3f304969c8 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/emitc/two_fns.mlir @@ -0,0 +1,16 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +#any_device = #tt.operand_constraint + +func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> { + %0 = tensor.empty() : tensor<32x32xbf16> + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> + return %1 : tensor<32x32xbf16> +} + +func.func @subtract(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> { + %0 = tensor.empty() : tensor<32x32xbf16> + %1 = "ttir.subtract"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> + return %1 : tensor<32x32xbf16> +} diff --git a/test/ttmlir/Silicon/TTNN/kv_cache/fill_cache.mlir b/test/ttmlir/Silicon/TTNN/kv_cache/fill_cache.mlir new file mode 100644 index 0000000000..67bf8387b1 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/kv_cache/fill_cache.mlir @@ -0,0 +1,14 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +#any_device = #tt.operand_constraint +module { + func.func @forward(%arg0: tensor<1x32x64x512xbf16>, %arg1: tensor<1x32x3x512xbf16>) -> tensor<1x32x64x512xbf16> { + // CHECK: "ttnn.fill_cache"[[C:.*]] + %1 = "ttir.fill_cache"(%arg0, %arg1) <{batch_offset = 0: i32, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x64x512xbf16>, tensor<1x32x3x512xbf16>) -> tensor<1x32x64x512xbf16> + %cst = "ttir.constant"() <{value = dense<1.000000e+00> : tensor<1x32x64x512xbf16>}> : () -> tensor<1x32x64x512xbf16> + %addition_dps = tensor.empty() : tensor<1x32x64x512xbf16> + %2 = "ttir.add"(%1, %cst, %addition_dps) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x64x512xbf16>, tensor<1x32x64x512xbf16>, tensor<1x32x64x512xbf16>) -> tensor<1x32x64x512xbf16> + return %2 : tensor<1x32x64x512xbf16> + } +} diff --git a/test/ttmlir/Silicon/TTNN/kv_cache/update_cache.mlir b/test/ttmlir/Silicon/TTNN/kv_cache/update_cache.mlir new file mode 100644 index 0000000000..63a08b3023 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/kv_cache/update_cache.mlir @@ -0,0 +1,15 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +#any_device = #tt.operand_constraint +module { + func.func @forward(%arg0: tensor<1x32x64x512xbf16>, %arg1: tensor<1x32x1x512xbf16>) -> tensor<1x32x64x512xbf16> { + // CHECK: "ttnn.update_cache"[[C:.*]] + %update_index = "ttir.constant"() <{value = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> + %1 = "ttir.update_cache"(%arg0, %arg1, %update_index) <{batch_offset = 0: i32, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x64x512xbf16>, tensor<1x32x1x512xbf16>, tensor<1xi32>) -> tensor<1x32x64x512xbf16> + %cst = "ttir.constant"() <{value = dense<1.000000e+00> : tensor<1x32x64x512xbf16>}> : () -> tensor<1x32x64x512xbf16> + %addition_dps = tensor.empty() : tensor<1x32x64x512xbf16> + %2 = "ttir.add"(%1, %cst, %addition_dps) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x64x512xbf16>, tensor<1x32x64x512xbf16>, tensor<1x32x64x512xbf16>) -> tensor<1x32x64x512xbf16> + return %2 : tensor<1x32x64x512xbf16> + } +} diff --git a/test/ttmlir/Silicon/TTNN/optimizer/large_tensors.mlir b/test/ttmlir/Silicon/TTNN/optimizer/large_tensors.mlir deleted file mode 100644 index fb71dae8d7..0000000000 --- a/test/ttmlir/Silicon/TTNN/optimizer/large_tensors.mlir +++ /dev/null @@ -1,19 +0,0 @@ -// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% enable-optimizer=true memory-layout-analysis-enabled=true memory-layout-analysis-policy=L1Interleaved" %s > %t.mlir -// RUN: FileCheck %s --input-file=%t.mlir -// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -module attributes {} { - func.func @forward(%arg0: tensor<8192x8192xbf16>, %arg1: tensor<8192x8192xbf16>, %arg2: tensor<8192x8192xbf16>) -> tensor<8192x8192xbf16> { - // CHECK: #[[LAYOUT_2:ttnn_layout2]] = #ttnn.ttnn_layout<{{.*}}, memref<{{.*}}, #dram>, {{.*}}> - %0 = tensor.empty() : tensor<8192x8192xbf16> - // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<8192x8192xbf16, #[[LAYOUT_2]]> - %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<8192x8192xbf16>, tensor<8192x8192xbf16>, tensor<8192x8192xbf16>) -> tensor<8192x8192xbf16> - %2 = tensor.empty() : tensor<8192x8192xbf16> - // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<8192x8192xbf16, #[[LAYOUT_2]]> - %3 = "ttir.add"(%1, %arg2, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<8192x8192xbf16>, tensor<8192x8192xbf16>, tensor<8192x8192xbf16>) -> tensor<8192x8192xbf16> - %4 = tensor.empty() : tensor<8192x8192xbf16> - // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<8192x8192xbf16, #[[LAYOUT_2]]> - %7 = "ttir.relu"(%3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<8192x8192xbf16>, tensor<8192x8192xbf16>) -> tensor<8192x8192xbf16> - return %7 : tensor<8192x8192xbf16> - } -} diff --git a/test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding.mlir b/test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding.mlir index bccd45fcec..3cf9c45817 100644 --- a/test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding.mlir +++ b/test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding.mlir @@ -5,8 +5,8 @@ #loc = loc("MNISTLinear":4294967295:0) module @"tt-forge-graph" attributes {} { func.func @main(%arg0: tensor<1x784xf32> loc("MNISTLinear":4294967295:0), %arg1: tensor<1x10xf32> loc("MNISTLinear":4294967295:0), %arg2: tensor<256x10xf32> loc("MNISTLinear":4294967295:0), %arg3: tensor<1x256xf32> loc("MNISTLinear":4294967295:0), %arg4: tensor<784x256xf32> loc("MNISTLinear":4294967295:0)) -> tensor<1x10xf32> { - // CHECK-DAG: #[[LAYOUT_10:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x8>, memref<1x1x!tt.tile<32x32, f32>, #l1_>, width_sharded> - // CHECK-DAG: #[[LAYOUT_11:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<1x1x!tt.tile<32x32, f32>, #l1_>, width_sharded> + // CHECK-DAG: #[[LAYOUT_10:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x8>, memref<1x1x!tt.tile<32x32, f32>, #l1_>, > + // CHECK-DAG: #[[LAYOUT_11:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<1x1x!tt.tile<32x32, f32>, #l1_>, > %0 = tensor.empty() : tensor<1x256xf32> loc(#loc8) // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<1x256xf32, #[[LAYOUT_10]]> %1 = "ttir.matmul"(%arg0, %arg4, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x784xf32>, tensor<784x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc8) diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/mnist.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/mnist.mlir index ba995925d5..0193ec36b1 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/mnist.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/mnist.mlir @@ -4,8 +4,8 @@ #loc = loc("MNISTLinear":4294967295:0) module @"tt-forge-graph" attributes {} { func.func @main(%arg0: tensor<1x784xf32> loc("MNISTLinear":4294967295:0), %arg1: tensor<1x10xf32> loc("MNISTLinear":4294967295:0), %arg2: tensor<256x10xf32> loc("MNISTLinear":4294967295:0), %arg3: tensor<1x256xf32> loc("MNISTLinear":4294967295:0), %arg4: tensor<784x256xf32> loc("MNISTLinear":4294967295:0)) -> tensor<1x10xf32> { - // CHECK: #[[LAYOUT_10:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x8>, memref<1x32xf32, #l1_>, block_sharded> - // CHECK: #[[LAYOUT_11:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<1x10xf32, #l1_>, block_sharded> + // CHECK: #[[LAYOUT_10:.*]] = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x8>, memref<1x32xf32, #l1_>, block_sharded> + // CHECK: #[[LAYOUT_11:.*]] = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<1x10xf32, #l1_>, block_sharded> %0 = tensor.empty() : tensor<1x256xf32> loc(#loc8) // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] -> tensor<1x256xf32, #[[LAYOUT_10]]> %1 = "ttir.matmul"(%arg0, %arg4, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x784xf32>, tensor<784x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc8) diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_ceil.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_ceil.mlir index c31c789f44..2e7f55428c 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_ceil.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_ceil.mlir @@ -5,9 +5,9 @@ #any_device_tile = #tt.operand_constraint func.func @ceil(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { - // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> - // CHECK: %[[C:.*]] = "ttnn.ceil"[[C:.*]] + // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) + // CHECK: %{{[0-9]+}} = "ttnn.ceil"(%{{[0-9]+}}, [[VAL0]]) %1 = "ttir.ceil"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_cosine.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_cosine.mlir index 91a7fea47d..ede823439e 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_cosine.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_cosine.mlir @@ -5,9 +5,9 @@ #any_device_tile = #tt.operand_constraint func.func @cosine(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { - // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> - // CHECK: %[[C:.*]] = "ttnn.cos"[[C:.*]] + // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) + // CHECK: %{{[0-9]+}} = "ttnn.cos"(%{{[0-9]+}}, [[VAL0]]) %1 = "ttir.cos"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_isfinite.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_isfinite.mlir index ce0146be40..f1489a5ebd 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_isfinite.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_isfinite.mlir @@ -4,14 +4,14 @@ #any_device = #tt.operand_constraint #any_device_tile = #tt.operand_constraint -func.func @is_finite(%arg0: tensor<64x128xf32>) -> tensor<64x128xbf16> { +func.func @is_finite(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> { // CHECK: %[[C:.*]] = "ttnn.empty" // CHECK-SAME: [[TENSOR:tensor<64x128xbf16,]] %0 = tensor.empty() : tensor<64x128xbf16> // CHECK: %[[C:.*]] = "ttnn.isfinite" - // CHECK-SAME: tensor<64x128xf32, + // CHECK-SAME: tensor<64x128xbf16, // CHECK-SAME: [[TENSOR]] // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.isfinite"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + %1 = "ttir.isfinite"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> return %1 : tensor<64x128xbf16> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_le.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_le.mlir deleted file mode 100644 index 79de8c062d..0000000000 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_le.mlir +++ /dev/null @@ -1,21 +0,0 @@ -// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir -// RUN: FileCheck %s --input-file=%t.mlir -// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn - -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - -module attributes {} { - func.func @less_equal(%arg0: tensor<13x31xf32>, %arg1: tensor<13x31xf32>) -> tensor<13x31xf32> { - // CHECK: %[[C:.*]] = "ttnn.empty - // CHECK-SAME: [[TENSOR:tensor<13x31xf32,]] - %0 = tensor.empty() : tensor<13x31xf32> - // CHECK: %[[C:.*]] = "ttnn.le" - // CHECK-SAME: [[TENSOR]] - // CHECK-SAME: [[TENSOR]] - // CHECK-SAME: [[TENSOR]] - // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.le"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> - return %1 : tensor<13x31xf32> - } -} diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_linear.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_linear.mlir new file mode 100644 index 0000000000..6da5d3910e --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_linear.mlir @@ -0,0 +1,20 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +#any_device_tile = #tt.operand_constraint +module { + func.func @linear(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<64x64xbf16 + %0 = tensor.empty() : tensor<64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<128x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> + } +} diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_log.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_log.mlir index b1ca157c61..b3de1bba4d 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_log.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_log.mlir @@ -4,10 +4,10 @@ #any_device = #tt.operand_constraint #any_device_tile = #tt.operand_constraint -func.func @sqrt(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { - // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] +func.func @log(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { %0 = tensor.empty() : tensor<64x128xf32> - // CHECK: %[[C:.*]] = "ttnn.log"[[C:.*]] + // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) + // CHECK: %{{[0-9]+}} = "ttnn.log"(%{{[0-9]+}}, [[VAL0]]) %1 = "ttir.log"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_matmul.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_matmul.mlir index e1c672a6ec..9c240b0ab7 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_matmul.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_matmul.mlir @@ -2,7 +2,7 @@ // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn #any_device_tile = #tt.operand_constraint -// CHECK: #[[TILED_LAYOUT:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #dram>, interleaved> +// CHECK: #[[TILED_LAYOUT:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #dram>, > module attributes {} { func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>) -> tensor<64x96xbf16> { %0 = tensor.empty() : tensor<64x96xbf16> diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sine.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sine.mlir index e72d57ffa9..36f71d8e6a 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sine.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sine.mlir @@ -5,9 +5,9 @@ #any_device_tile = #tt.operand_constraint func.func @sine(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { - // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> - // CHECK: %[[C:.*]] = "ttnn.sin"[[C:.*]] + // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) + // CHECK: %{{[0-9]+}} = "ttnn.sin"(%{{[0-9]+}}, [[VAL0]]) %1 = "ttir.sin"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_tan.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_tan.mlir new file mode 100644 index 0000000000..aa7b972983 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_tan.mlir @@ -0,0 +1,13 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +#any_device = #tt.operand_constraint +#any_device_tile = #tt.operand_constraint + +func.func @tan(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) + // CHECK: %{{[0-9]+}} = "ttnn.tan"(%{{[0-9]+}}, [[VAL0]]) + %1 = "ttir.tan"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + return %1 : tensor<64x128xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_tanh.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_tanh.mlir new file mode 100644 index 0000000000..ecb7266c96 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_tanh.mlir @@ -0,0 +1,13 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +#any_device = #tt.operand_constraint +#any_device_tile = #tt.operand_constraint + +func.func @tanh(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) + // CHECK: %{{[0-9]+}} = "ttnn.tanh"(%{{[0-9]+}}, [[VAL0]]) + %1 = "ttir.tanh"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + return %1 : tensor<64x128xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_where.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_where.mlir index 3bed0528c6..647f94e61e 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_where.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_where.mlir @@ -4,13 +4,13 @@ #any_device = #tt.operand_constraint #any_device_tile = #tt.operand_constraint -func.func @test_where(%arg0: tensor<13x37xf32>, %arg1: tensor<13x37xf32>) -> tensor<13x37xf32> { +func.func @test_where(%arg0: tensor<13x37xbf16>, %arg1: tensor<13x37xbf16>) -> tensor<13x37xbf16> { %0 = tensor.empty() : tensor<13x37xbf16> - %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xbf16>) -> tensor<13x37xbf16> - %2 = tensor.empty() : tensor<13x37xf32> - %3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xbf16>, tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32> + %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>) -> tensor<13x37xbf16> + %2 = tensor.empty() : tensor<13x37xbf16> + %3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>) -> tensor<13x37xbf16> // CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}} // CHECK: %[[VAL1:[0-9]+]] = "ttnn.eq"(%{{[0-9]+}}, %{{[0-9]+}}, %[[EMPTY]]) // CHECK: %{{[0-9]+}} = "ttnn.where"(%[[VAL1]], %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}) - return %3 : tensor<13x37xf32> + return %3 : tensor<13x37xbf16> } diff --git a/test/ttmlir/Silicon/TTNN/simple_broadcast.mlir b/test/ttmlir/Silicon/TTNN/simple_broadcast.mlir deleted file mode 100644 index 1d88725d1d..0000000000 --- a/test/ttmlir/Silicon/TTNN/simple_broadcast.mlir +++ /dev/null @@ -1,14 +0,0 @@ -// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir -// RUN: FileCheck %s --input-file=%t.mlir -// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint - -func.func public @broadcast() -> (tensor<32xf32>) { - %0 = "ttir.constant"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> - %1 = tensor.empty() : tensor<32xf32> - %2 = "ttir.broadcast"(%0, %1) <{dimension = [0], operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1xf32>, tensor<32xf32>) -> tensor<32xf32> - %3 = tensor.empty() : tensor<32xf32> - %4 = "ttir.broadcast"(%2, %3) <{dimension = [0], operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32> - // CHECK-NOT: %[[C:.*]] = "ttir.broadcast"[[C:.*]] - return %4 : tensor<32xf32> -} diff --git a/test/ttmlir/Silicon/TTNN/simple_constant.mlir b/test/ttmlir/Silicon/TTNN/simple_constant.mlir index 4f33870c0f..35728f0a93 100644 --- a/test/ttmlir/Silicon/TTNN/simple_constant.mlir +++ b/test/ttmlir/Silicon/TTNN/simple_constant.mlir @@ -4,19 +4,19 @@ module @sysmem_creation attributes {} { func.func @test_empty_int() -> tensor<64x128xi32> { %0 = "ttir.constant"() <{value = dense<0> : tensor<64x128xi32>}> : () -> tensor<64x128xi32> - // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] + // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]] return %0 : tensor<64x128xi32> } func.func @test_empty_float() -> tensor<64x128xf32> { %0 = "ttir.constant"() <{value = dense<0.000000e+00> : tensor<64x128xf32>}> : () -> tensor<64x128xf32> - // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] + // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]] return %0 : tensor<64x128xf32> } func.func @test_empty_float_scalar() -> tensor<1x1xf32> { %0 = "ttir.constant"() <{value = dense<0.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> - // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] + // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]] return %0 : tensor<1x1xf32> } diff --git a/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir b/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir index 976f2867db..b0fb94cc6d 100644 --- a/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir +++ b/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir @@ -14,7 +14,8 @@ func.func @add(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<6 func.func @ceil(%arg0: tensor<32x32xf32>) -> tensor<32x32xf32> { %0 = tensor.empty() : tensor<32x32xf32> - // CHECK: %[[C:.*]] = "ttnn.ceil"[[C:.*]] + // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) + // CHECK: %{{[0-9]+}} = "ttnn.ceil"(%{{[0-9]+}}, [[VAL0]]) %1 = "ttir.ceil"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> return %1 : tensor<32x32xf32> } @@ -40,7 +41,8 @@ func.func @concat(%arg0: tensor<32x32xf32>, %arg1: tensor<32x64xf32>) -> tensor< func.func @cosine(%arg0: tensor<32x32xf32>) -> tensor<32x32xf32> { %0 = tensor.empty() : tensor<32x32xf32> - // CHECK: %[[C:.*]] = "ttnn.cos"[[C:.*]] + // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) + // CHECK: %{{[0-9]+}} = "ttnn.cos"(%{{[0-9]+}}, [[VAL0]]) %1 = "ttir.cos"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> return %1 : tensor<32x32xf32> } @@ -65,15 +67,15 @@ func.func @floor(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { return %1 : tensor<64x128xf32> } -func.func @is_finite(%arg0: tensor<64x128xf32>) -> tensor<64x128xbf16> { +func.func @is_finite(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> { // CHECK: %[[C:.*]] = "ttnn.empty" // CHECK-SAME: [[TENSOR:tensor<64x128xbf16,]] %0 = tensor.empty() : tensor<64x128xbf16> // CHECK: %[[C:.*]] = "ttnn.isfinite" - // CHECK-SAME: tensor<64x128xf32, + // CHECK-SAME: tensor<64x128xbf16, // CHECK-SAME: [[TENSOR]] // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.isfinite"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + %1 = "ttir.isfinite"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> return %1 : tensor<64x128xbf16> } @@ -193,7 +195,8 @@ func.func @sqrt(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { func.func @sine(%arg0: tensor<32x32xf32>) -> tensor<32x32xf32> { %0 = tensor.empty() : tensor<32x32xf32> - // CHECK: %[[C:.*]] = "ttnn.sin"[[C:.*]] + // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) + // CHECK: %{{[0-9]+}} = "ttnn.sin"(%{{[0-9]+}}, [[VAL0]]) %1 = "ttir.sin"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> return %1 : tensor<32x32xf32> } @@ -278,15 +281,15 @@ func.func @get_dimension_size(%arg0: tensor<13x21x3xf32>) -> tensor<1xi32> { // CHECK: return [[VAL]] : tensor<1xi32, {{.*}}> } -func.func @test_where(%arg0: tensor<13x37xf32>, %arg1: tensor<13x37xf32>) -> tensor<13x37xf32> { +func.func @test_where(%arg0: tensor<13x37xbf16>, %arg1: tensor<13x37xbf16>) -> tensor<13x37xbf16> { %0 = tensor.empty() : tensor<13x37xbf16> - %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xbf16>) -> tensor<13x37xbf16> - %2 = tensor.empty() : tensor<13x37xf32> - %3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xbf16>, tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32> + %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>) -> tensor<13x37xbf16> + %2 = tensor.empty() : tensor<13x37xbf16> + %3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>) -> tensor<13x37xbf16> // CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}} // CHECK: %[[VAL1:[0-9]+]] = "ttnn.eq"(%{{[0-9]+}}, %{{[0-9]+}}, %[[EMPTY]]) // CHECK: %{{[0-9]+}} = "ttnn.where"(%[[VAL1]], %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}) - return %3 : tensor<13x37xf32> + return %3 : tensor<13x37xbf16> } func.func @gelu(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { @@ -301,8 +304,34 @@ func.func @gelu(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { return %1 : tensor<64x128xf32> } +func.func @tan(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> { + %0 = tensor.empty() : tensor<64x128xbf16> + // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) + // CHECK: %{{[0-9]+}} = "ttnn.tan"(%{{[0-9]+}}, [[VAL0]]) + %1 = "ttir.tan"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + return %1 : tensor<64x128xbf16> +} + +func.func @tanh(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> { + %0 = tensor.empty() : tensor<64x128xbf16> + // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) + // CHECK: %{{[0-9]+}} = "ttnn.tanh"(%{{[0-9]+}}, [[VAL0]]) + %1 = "ttir.tanh"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + return %1 : tensor<64x128xbf16> +} + func.func @addint32(%arg0: tensor<64x128xi32>, %arg1: tensor<64x128xi32>) -> tensor<64x128xi32> { %0 = tensor.empty() : tensor<64x128xi32> %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xi32>, tensor<64x128xi32>, tensor<64x128xi32>) -> tensor<64x128xi32> return %1 : tensor<64x128xi32> } + +func.func @scatter(%arg0: tensor<1x3x320x320xf32>, %arg1: tensor<1x3x32x32xf32>) -> tensor<1x3x320x320xf32> { + %0 = tensor.empty() : tensor<1x3x320x320xf32> + %1 = tensor.empty() : tensor<1x1xi32> + %2 = "ttir.scatter"(%arg0, %1, %arg1, %0) <{index_vector_dim = 1 : i32, indices_are_sorted = false, input_batching_dims = array, inserted_window_dims = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile], scatter_dims_to_operand_dims = array, scatter_indices_batching_dims = array, unique_indices = false, update_window_dims = array}> ({ + ^bb0(%arg3: tensor<1xf32>, %arg4: tensor<1xf32>): + "ttir.yield"(%arg4) : (tensor<1xf32>) -> () + }) : (tensor<1x3x320x320xf32>, tensor<1x1xi32>, tensor<1x3x32x32xf32>, tensor<1x3x320x320xf32>) -> tensor<1x3x320x320xf32> + return %2 : tensor<1x3x320x320xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/simple_linear.mlir b/test/ttmlir/Silicon/TTNN/simple_linear.mlir new file mode 100644 index 0000000000..f53de38cf3 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/simple_linear.mlir @@ -0,0 +1,33 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +#any_device_tile = #tt.operand_constraint +module { + func.func @simple_linear_without_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<64x64xbf16 + %0 = tensor.empty() : tensor<64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<128x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> + } + + func.func @simple_linear_with_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<64x64xbf16 + %0 = tensor.empty() : tensor<64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<128x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> + } +} diff --git a/test/ttmlir/Silicon/TTNN/simple_matmul.mlir b/test/ttmlir/Silicon/TTNN/simple_matmul.mlir index e1c672a6ec..9c240b0ab7 100644 --- a/test/ttmlir/Silicon/TTNN/simple_matmul.mlir +++ b/test/ttmlir/Silicon/TTNN/simple_matmul.mlir @@ -2,7 +2,7 @@ // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn #any_device_tile = #tt.operand_constraint -// CHECK: #[[TILED_LAYOUT:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #dram>, interleaved> +// CHECK: #[[TILED_LAYOUT:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #dram>, > module attributes {} { func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>) -> tensor<64x96xbf16> { %0 = tensor.empty() : tensor<64x96xbf16> diff --git a/test/unittests/Optimizer/CMakeLists.txt b/test/unittests/Optimizer/CMakeLists.txt index 4e6ee799a7..b05c8ae294 100644 --- a/test/unittests/Optimizer/CMakeLists.txt +++ b/test/unittests/Optimizer/CMakeLists.txt @@ -1,11 +1,13 @@ add_mlir_unittest(OptimizerTests TestShardSolver.cpp TestOptimizerOverrides.cpp + TestL1InterleavedPolicy.cpp ) target_link_libraries(OptimizerTests PRIVATE MLIR MLIRTTDialect + MLIRTTNNAnalysis MLIRTTNNPipelines ) diff --git a/test/unittests/Optimizer/TestL1InterleavedPolicy.cpp b/test/unittests/Optimizer/TestL1InterleavedPolicy.cpp new file mode 100644 index 0000000000..b09b65245d --- /dev/null +++ b/test/unittests/Optimizer/TestL1InterleavedPolicy.cpp @@ -0,0 +1,195 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "llvm/ADT/SmallVector.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" + +#include "ttmlir/Dialect/TTNN/IR/TTNN.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" + +#include "ttmlir/Dialect/TTNN/Analysis/L1InterleavedPolicy.h" + +using namespace mlir::tt::ttnn; + +constexpr int TensorDimX = 128; +constexpr int TensorDimY = 128; + +class L1InterleavedPolicyBase : public ::testing::Test { +public: + mlir::MLIRContext context; + mlir::OwningOpRef module; + mlir::OpBuilder builder = mlir::OpBuilder(&context); + mlir::func::FuncOp func; + mlir::tt::DeviceAttr deviceAttr; + + using OpMemSpec = L1InterleavedPolicy::OpMemSpec; + using OpConfig = L1InterleavedPolicy::OpConfig; + using L1Usage = L1InterleavedPolicy::L1Usage; + + void SetUp() override { + context.loadDialect(); + module = mlir::ModuleOp::create(builder.getUnknownLoc()); + builder.setInsertionPointToStart(&module->getBodyRegion().front()); + createFuncOp(); + deviceAttr = mlir::tt::getCurrentScopeDevice(func); + } + + llvm::SmallVector getTensorShape() { + return {TensorDimX, TensorDimY}; + } + + mlir::RankedTensorType getTensorRankedType() { + return mlir::RankedTensorType::get(getTensorShape(), builder.getF32Type()); + } + + mlir::Value createEmptyTensor() { + ShapeAttr shapeAttr = ShapeAttr::get(&context, getTensorShape()); + return builder.create(builder.getUnknownLoc(), + getTensorRankedType(), nullptr, shapeAttr, + nullptr, nullptr, nullptr); + } + + mlir::func::FuncOp createFuncOp() { + mlir::SmallVector input; + input.push_back(getTensorRankedType()); + + mlir::SmallVector output; + output.push_back(getTensorRankedType()); + + auto funcType = builder.getType( + mlir::TypeRange(input), mlir::TypeRange(output)); + func = builder.create(builder.getUnknownLoc(), "test", + funcType); + + mlir::Block *block = func.addEntryBlock(); + block->addArgument(getTensorRankedType(), builder.getUnknownLoc()); + block->addArgument(getTensorRankedType(), builder.getUnknownLoc()); + + builder.setInsertionPointToStart(block); + + return func; + } + + void addLayoutForOp(mlir::Operation *op, + llvm::DenseMap> &legalLayouts, + BufferType memorySpace, + TensorMemoryLayout tensorMemoryLayout) { + TensorMemoryLayoutAttr tensorMemoryLayoutAttr = + TensorMemoryLayoutAttr::get(&context, tensorMemoryLayout); + if (legalLayouts.find(op) == legalLayouts.end()) { + legalLayouts[op] = std::vector{TTNNLayoutAttr::get( + &context, getTensorRankedType().getShape(), builder.getF32Type(), + memorySpace, mlir::tt::GridAttr::get(&context, {8, 8}), + tensorMemoryLayoutAttr)}; + } else { + legalLayouts[op].push_back(TTNNLayoutAttr::get( + &context, getTensorRankedType().getShape(), builder.getF32Type(), + memorySpace, mlir::tt::GridAttr::get(&context, {8, 8}), + tensorMemoryLayoutAttr)); + } + } + + void prepareOpForGreedyConfigPicker( + mlir::Operation *op, uint64_t outputL1Usage, uint64_t requiredL1Usage, + llvm::DenseMap> + &legalLayouts, + llvm::DenseMap &opsL1Usage) { + + // Add two legal layouts for the op with different buffer + // types: DRAM and L1. + addLayoutForOp(op, legalLayouts, BufferType::DRAM, + TensorMemoryLayout::Interleaved); + addLayoutForOp(op, legalLayouts, BufferType::L1, + TensorMemoryLayout::Interleaved); + + L1Usage l1Usage; + l1Usage.outputL1Usage = outputL1Usage; + l1Usage.requiredL1Usage = requiredL1Usage; + opsL1Usage[op] = l1Usage; + } + + void TearDown() override {} +}; + +TEST_F(L1InterleavedPolicyBase, VerifyGreedyPolicy) { + std::vector l1ChainConfigs; + llvm::DenseMap> legalLayouts; + llvm::DenseMap> + schedule; + llvm::DenseMap opsL1Usage; + constexpr uint64_t usableL1CacheSize = 15; + + // Create operand A + mlir::Value dest = createEmptyTensor(); + mlir::Value lhs = func.getBody().getBlocks().front().getArgument(0); + mlir::Value rhs = func.getBody().getBlocks().front().getArgument(1); + mlir::Operation *opA = + builder.create(builder.getUnknownLoc(), lhs, rhs, dest); + uint64_t outputL1Usage = 2; + uint64_t requiredL1Usage = 8; + prepareOpForGreedyConfigPicker(opA, outputL1Usage, requiredL1Usage, + legalLayouts, opsL1Usage); + + // Create operand B + dest = createEmptyTensor(); + lhs = func.getBody().getBlocks().front().getArgument(0); + rhs = func.getBody().getBlocks().front().getArgument(1); + mlir::Operation *opB = + builder.create(builder.getUnknownLoc(), lhs, rhs, dest); + outputL1Usage = 3; + requiredL1Usage = 7; + prepareOpForGreedyConfigPicker(opB, outputL1Usage, requiredL1Usage, + legalLayouts, opsL1Usage); + + // Create operand C + dest = createEmptyTensor(); + lhs = func.getBody().getBlocks().front().getArgument(0); + rhs = func.getBody().getBlocks().front().getArgument(1); + mlir::Operation *opC = + builder.create(builder.getUnknownLoc(), lhs, rhs, dest); + outputL1Usage = 1; + requiredL1Usage = 9; + prepareOpForGreedyConfigPicker(opC, outputL1Usage, requiredL1Usage, + legalLayouts, opsL1Usage); + + // Create base op D + dest = createEmptyTensor(); + lhs = func.getBody().getBlocks().front().getArgument(0); + rhs = func.getBody().getBlocks().front().getArgument(1); + mlir::Operation *opD = + builder.create(builder.getUnknownLoc(), lhs, rhs, dest); + outputL1Usage = 4; + requiredL1Usage = 0; + prepareOpForGreedyConfigPicker(opD, outputL1Usage, requiredL1Usage, + legalLayouts, opsL1Usage); + + // Run greedy config picker policy + L1InterleavedPolicy l1InterleavedPolicy(nullptr, l1ChainConfigs, legalLayouts, + schedule, usableL1CacheSize); + OpConfig greedyConfig = l1InterleavedPolicy.getGreedyConfig(opD, opsL1Usage); + + // Sanity checks + ASSERT_TRUE(greedyConfig.baseOp == opD); + ASSERT_TRUE(greedyConfig.layouts.size() == 4); + ASSERT_TRUE(greedyConfig.precedence.size() == 3); + + // All layouts should be using L1 buffer type + for (const auto &[op, layout] : greedyConfig.layouts) { + ASSERT_TRUE(layout.hasL1BufferType()); + } + + // Precedence order for op D should be: C, A, B + ASSERT_EQ(greedyConfig.precedence[0], opC); + ASSERT_EQ(greedyConfig.precedence[1], opA); + ASSERT_EQ(greedyConfig.precedence[2], opB); +} diff --git a/test/unittests/Optimizer/TestOptimizerOverrides.cpp b/test/unittests/Optimizer/TestOptimizerOverrides.cpp index dbbea2de64..31118262f5 100644 --- a/test/unittests/Optimizer/TestOptimizerOverrides.cpp +++ b/test/unittests/Optimizer/TestOptimizerOverrides.cpp @@ -122,3 +122,424 @@ TEST_F(OutputLayoutOverrideTest, ParseMultipleOps) { ASSERT_TRUE(params2.dataType.has_value()); ASSERT_EQ(params2.dataType.value(), mlir::tt::DataType::Float16); } + +class TestOptimizerOverrideHandler : public ::testing::Test { + +public: + OptimizerOverridesHandler optimizerOverridesHandler; + + void SetUp() override {} + + llvm::StringMap createInputLayoutOverrides() { + + // struct InputLayoutOverrideParams { + // SmallVector operandIdxes; + // }; + + llvm::StringMap inputLayoutOverrides; + + // Create input layout overrides for 3 input overrides. + inputLayoutOverrides["input0"] = createInputLayoutOverrideParams(); + inputLayoutOverrides["input1"] = createInputLayoutOverrideParams(); + inputLayoutOverrides["input2"] = createInputLayoutOverrideParams(); + + return inputLayoutOverrides; + } + + InputLayoutOverrideParams createInputLayoutOverrideParams() { + + InputLayoutOverrideParams inputLayoutOverrideParams; + + // Create input layout override params for 2 operands. + // Their operand indexes are 0 and 1, respectively. + inputLayoutOverrideParams.operandIdxes.push_back(0); + inputLayoutOverrideParams.operandIdxes.push_back(1); + + return inputLayoutOverrideParams; + } + + llvm::StringMap createOutputLayoutOverrides() { + + llvm::StringMap outputLayoutOverrides; + + // Create output layout overrides for 3 output overrides. + outputLayoutOverrides["output0"] = createOutputLayoutOverrideParams_0(); + outputLayoutOverrides["output1"] = createOutputLayoutOverrideParams_1(); + outputLayoutOverrides["output2"] = createOutputLayoutOverrideParams_2(); + + return outputLayoutOverrides; + } + + OutputLayoutOverrideParams createOutputLayoutOverrideParams_0() { + + // struct OutputLayoutOverrideParams { + // SmallVector grid; + // BufferType; + // TensorMemoryLayout tensorMemoryLayout; // INTERLEAVED / SHARDED etc... + // Layout memoryLayout; // ROW_MAJOR / TILE + // mlir::tt::DataType dataType; + // }; + + OutputLayoutOverrideParams outputLayoutOverrideParams; + + // Output 0 has + // - grid size 2x2, + // - buffer type dram + // - tensor memory layout interleaved + // - memory layout tile + // - data type fp16. + outputLayoutOverrideParams.grid = llvm::SmallVector({2, 2}); + outputLayoutOverrideParams.bufferType = BufferType::DRAM; + outputLayoutOverrideParams.tensorMemoryLayout = + TensorMemoryLayout::Interleaved; + outputLayoutOverrideParams.memoryLayout = Layout::Tile; + outputLayoutOverrideParams.dataType = mlir::tt::DataType::Float16; + + return outputLayoutOverrideParams; + } + + OutputLayoutOverrideParams createOutputLayoutOverrideParams_1() { + + // struct OutputLayoutOverrideParams { + // SmallVector grid; + // BufferType; + // TensorMemoryLayout tensorMemoryLayout; // INTERLEAVED / SHARDED etc... + // Layout memoryLayout; // ROW_MAJOR / TILE + // mlir::tt::DataType dataType; + // }; + + OutputLayoutOverrideParams outputLayoutOverrideParams; + + // Output 1 has + // - grid size 8x4, + // - buffer type l1 + // - tensor memory layout block_sharded + // - memory layout row_major + // - data type fp16. + outputLayoutOverrideParams.grid = llvm::SmallVector({8, 4}); + outputLayoutOverrideParams.bufferType = BufferType::L1; + outputLayoutOverrideParams.tensorMemoryLayout = + TensorMemoryLayout::BlockSharded; + outputLayoutOverrideParams.memoryLayout = Layout::RowMajor; + outputLayoutOverrideParams.dataType = mlir::tt::DataType::Float16; + + return outputLayoutOverrideParams; + } + + OutputLayoutOverrideParams createOutputLayoutOverrideParams_2() { + + // struct OutputLayoutOverrideParams { + // SmallVector grid; + // BufferType; + // TensorMemoryLayout tensorMemoryLayout; // INTERLEAVED / SHARDED etc... + // Layout memoryLayout; // ROW_MAJOR / TILE + // mlir::tt::DataType dataType; + // }; + + OutputLayoutOverrideParams outputLayoutOverrideParams; + + // Output 2 has + // - grid size 3x6, + // - buffer type system + // - tensor memory layout height_sharded + // - memory layout tile + // - data type fp16. + outputLayoutOverrideParams.grid = llvm::SmallVector({3, 6}); + outputLayoutOverrideParams.bufferType = BufferType::SystemMemory; + outputLayoutOverrideParams.tensorMemoryLayout = + TensorMemoryLayout::HeightSharded; + outputLayoutOverrideParams.memoryLayout = Layout::Tile; + outputLayoutOverrideParams.dataType = mlir::tt::DataType::Float16; + + return outputLayoutOverrideParams; + } + + bool + compareInputLayoutOverrides(llvm::StringMap in1, + llvm::StringMap in2) { + // Check if the sizes of the two input layout overrides are the same. + if (in1.size() != in2.size()) { + return false; + } + llvm::StringMap::iterator it1; + for (it1 = in1.begin(); it1 != in1.end(); it1++) { + // Check if the two input layout overrides have the same keys. + llvm::StringMap::iterator it2 = + in2.find(it1->getKey()); + if (it2 == in2.end()) { + return false; + } + // Check if the two input layout overrides have the same values. + // The structure InputLayoutOverrideParams has overloaded operators for == + // and !=, so we can compare the objects in this way. + if (it1->getValue() != it2->getValue()) { + return false; + } + } + return true; + } + + bool compareOutputLayoutOverrides( + llvm::StringMap out1, + llvm::StringMap out2) { + // Check if the sizes of the two output layout overrides are the same. + if (out1.size() != out2.size()) { + return false; + } + llvm::StringMap::iterator it1; + for (it1 = out1.begin(); it1 != out1.end(); it1++) { + // Check if the two output layout overrides have the same keys. + llvm::StringMap::iterator it2 = + out2.find(it1->getKey()); + if (it2 == out2.end()) { + return false; + } + // Check if the two output layout overrides have the same values. + // The structure OutputLayoutOverrideParams has overloaded operators for + // == and !=, so we can compare the objects in this way. + if (it1->getValue() != it2->getValue()) { + return false; + } + } + return true; + } + + void TearDown() override {} +}; + +// Test the setEnableOptimizer method +TEST_F(TestOptimizerOverrideHandler, TestSetOptimizerPass) { + + optimizerOverridesHandler.setEnableOptimizer(true); + ASSERT_TRUE(optimizerOverridesHandler.getEnableOptimizer()); + + optimizerOverridesHandler.setEnableOptimizer(false); + ASSERT_FALSE(optimizerOverridesHandler.getEnableOptimizer()); +} + +// Test the setMemoryConfig method +TEST_F(TestOptimizerOverrideHandler, TestSetMemoryConfig) { + + optimizerOverridesHandler.setMemoryReconfig(true); + ASSERT_TRUE(optimizerOverridesHandler.getMemoryReconfig()); + + optimizerOverridesHandler.setMemoryReconfig(false); + ASSERT_FALSE(optimizerOverridesHandler.getMemoryReconfig()); +} + +// Test the setMemoryLayoutAnalysis method +TEST_F(TestOptimizerOverrideHandler, TestSetMemoryLayoutAnalysis) { + + optimizerOverridesHandler.setEnableMemoryLayoutAnalysis(true); + ASSERT_TRUE(optimizerOverridesHandler.getEnableMemoryLayoutAnalysis()); + + optimizerOverridesHandler.setEnableMemoryLayoutAnalysis(false); + ASSERT_FALSE(optimizerOverridesHandler.getEnableMemoryLayoutAnalysis()); +} + +// Test the setEnableMemoryLayoutAnalysisPolicy method +TEST_F(TestOptimizerOverrideHandler, TestSetEnableMemoryLayoutAnalysisPolicy) { + + optimizerOverridesHandler.setEnableMemoryLayoutAnalysisPolicy(true); + ASSERT_TRUE(optimizerOverridesHandler.getEnableMemoryLayoutAnalysisPolicy()); + + optimizerOverridesHandler.setEnableMemoryLayoutAnalysisPolicy(false); + ASSERT_FALSE(optimizerOverridesHandler.getEnableMemoryLayoutAnalysisPolicy()); +} + +// Test the setMemoryLayoutAnalysisPolicy method +TEST_F(TestOptimizerOverrideHandler, TestSetMemoryLayoutAnalysisPolicy) { + + optimizerOverridesHandler.setMemoryLayoutAnalysisPolicy( + mlir::tt::MemoryLayoutAnalysisPolicyType::DFSharding); + ASSERT_EQ(optimizerOverridesHandler.getMemoryLayoutAnalysisPolicy(), + mlir::tt::MemoryLayoutAnalysisPolicyType::DFSharding); + + optimizerOverridesHandler.setMemoryLayoutAnalysisPolicy( + mlir::tt::MemoryLayoutAnalysisPolicyType::L1Interleaved); + ASSERT_EQ(optimizerOverridesHandler.getMemoryLayoutAnalysisPolicy(), + mlir::tt::MemoryLayoutAnalysisPolicyType::L1Interleaved); +} + +// Test the setInputLayoutOverrides method +TEST_F(TestOptimizerOverrideHandler, TestSetInputLayoutOverrides) { + + llvm::StringMap inputLayoutOverrides = + createInputLayoutOverrides(); + + optimizerOverridesHandler.setInputLayoutOverrides(inputLayoutOverrides); + ASSERT_TRUE(compareInputLayoutOverrides( + optimizerOverridesHandler.getInputLayoutOverrides(), + inputLayoutOverrides)); +} + +// Test the setOutputLayoutOverrides method +TEST_F(TestOptimizerOverrideHandler, TestSetOutputLayoutOverrides) { + + llvm::StringMap outputLayoutOverrides = + createOutputLayoutOverrides(); + + optimizerOverridesHandler.setOutputLayoutOverrides(outputLayoutOverrides); + ASSERT_TRUE(compareOutputLayoutOverrides( + optimizerOverridesHandler.getOutputLayoutOverrides(), + outputLayoutOverrides)); +} + +// Test the addInputLayoutOverride method passing the whole object +TEST_F(TestOptimizerOverrideHandler, TestAddInputLayoutOverrideObject) { + + // This method is implemented across two functions in the + // OptimizerOverridesHandler class. The first function takes the whole object + // as a parameter, while the second function takes the individual parameters. + + // Here, we test the first function, which takes the whole object as a + // parameter. + + llvm::StringMap inputLayoutOverrides = + createInputLayoutOverrides(); + + optimizerOverridesHandler.addInputLayoutOverride( + "input0", createInputLayoutOverrideParams()); + optimizerOverridesHandler.addInputLayoutOverride( + "input1", createInputLayoutOverrideParams()); + optimizerOverridesHandler.addInputLayoutOverride( + "input2", createInputLayoutOverrideParams()); + + ASSERT_TRUE(compareInputLayoutOverrides( + optimizerOverridesHandler.getInputLayoutOverrides(), + inputLayoutOverrides)); +} + +// Test the addInputLayoutOverride method passing the individual parameters +TEST_F(TestOptimizerOverrideHandler, TestAddInputLayoutOverrideParams) { + + // This method is implemented across two functions in the + // OptimizerOverridesHandler class. The first function takes the whole object + // as a parameter, while the second function takes the individual parameters. + + // Here, we test the second function, which takes the individual parameters. + + llvm::StringMap inputLayoutOverrides = + createInputLayoutOverrides(); + + llvm::SmallVector operandIdxes1 = {0, 1}; + llvm::SmallVector operandIdxes2 = {0, 1}; + llvm::SmallVector operandIdxes3 = {0, 1}; + + optimizerOverridesHandler.addInputLayoutOverride("input0", operandIdxes1); + optimizerOverridesHandler.addInputLayoutOverride("input1", operandIdxes2); + optimizerOverridesHandler.addInputLayoutOverride("input2", operandIdxes3); + + ASSERT_TRUE(compareInputLayoutOverrides( + optimizerOverridesHandler.getInputLayoutOverrides(), + inputLayoutOverrides)); +} + +// Test the addOutputLayoutOverride method passing the whole object +TEST_F(TestOptimizerOverrideHandler, TestAddOutputLayoutOverrideObject) { + + // This method is implemented across two functions in the + // OptimizerOverridesHandler class. The first function takes the whole object + // as a parameter, while the second function takes the individual parameters. + + // Here, we test the first function, which takes the whole object as a + // parameter. + + llvm::StringMap outputLayoutOverrides = + createOutputLayoutOverrides(); + + optimizerOverridesHandler.addOutputLayoutOverride( + "output0", createOutputLayoutOverrideParams_0()); + optimizerOverridesHandler.addOutputLayoutOverride( + "output1", createOutputLayoutOverrideParams_1()); + optimizerOverridesHandler.addOutputLayoutOverride( + "output2", createOutputLayoutOverrideParams_2()); + + ASSERT_TRUE(compareOutputLayoutOverrides( + optimizerOverridesHandler.getOutputLayoutOverrides(), + outputLayoutOverrides)); +} + +// Test the addOutputLayoutOverride method passing the individual parameters +TEST_F(TestOptimizerOverrideHandler, TestAddOutputLayoutOverrideParams) { + + // This method is implemented across two functions in the + // OptimizerOverridesHandler class. The first function takes the whole object + // as a parameter, while the second function takes the individual parameters. + + // Here, we test the second function, which takes the individual parameters. + + llvm::StringMap outputLayoutOverrides = + createOutputLayoutOverrides(); + + llvm::SmallVector grid1 = {2, 2}; + llvm::SmallVector grid2 = {8, 4}; + llvm::SmallVector grid3 = {3, 6}; + + optimizerOverridesHandler.addOutputLayoutOverride( + "output0", grid1, BufferType::DRAM, TensorMemoryLayout::Interleaved, + Layout::Tile, mlir::tt::DataType::Float16); + optimizerOverridesHandler.addOutputLayoutOverride( + "output1", grid2, BufferType::L1, TensorMemoryLayout::BlockSharded, + Layout::RowMajor, mlir::tt::DataType::Float16); + optimizerOverridesHandler.addOutputLayoutOverride( + "output2", grid3, BufferType::SystemMemory, + TensorMemoryLayout::HeightSharded, Layout::Tile, + mlir::tt::DataType::Float16); + + ASSERT_TRUE(compareOutputLayoutOverrides( + optimizerOverridesHandler.getOutputLayoutOverrides(), + outputLayoutOverrides)); +} + +// Test the setSystemDescPath method +TEST_F(TestOptimizerOverrideHandler, TestSetSystemDescPath) { + + optimizerOverridesHandler.setSystemDescPath("system_desc_path"); + ASSERT_EQ(optimizerOverridesHandler.getSystemDescPath(), "system_desc_path"); +} + +// Test the setMaxLegalLayouts method +TEST_F(TestOptimizerOverrideHandler, TestSetMaxLegalLayouts) { + + optimizerOverridesHandler.setMaxLegalLayouts(10); + ASSERT_EQ(optimizerOverridesHandler.getMaxLegalLayouts(), 10); +} + +// Test the setMeshShape method +TEST_F(TestOptimizerOverrideHandler, TestSetMeshShape) { + + std::vector meshShape; + meshShape.push_back(1); + meshShape.push_back(2); + + optimizerOverridesHandler.setMeshShape(meshShape); + ASSERT_EQ(optimizerOverridesHandler.getMeshShape()[0], meshShape[0]); + ASSERT_EQ(optimizerOverridesHandler.getMeshShape()[1], meshShape[1]); +} + +// Test the toString method +TEST_F(TestOptimizerOverrideHandler, TestToString) { + + std::string options; + options += + "enable-optimizer=true "; // The optimizer pass is enabled by default. + options += "memreconfig-enabled=true "; + options += "memory-layout-analysis-enabled=true "; + options += "insert-memreconfig=add_0_1_2=0 "; + options += + "override-output-layout=add_1_2=1x1:dram:interleaved:row_major:f32"; + + llvm::SmallVector operandIdxes = {0}; + llvm::SmallVector grid = {1, 1}; + + optimizerOverridesHandler.setEnableOptimizer(true); + optimizerOverridesHandler.setEnableMemoryLayoutAnalysis(true); + optimizerOverridesHandler.setMemoryReconfig(true); + optimizerOverridesHandler.addInputLayoutOverride("add_0_1_2", operandIdxes); + optimizerOverridesHandler.addOutputLayoutOverride( + "add_1_2", grid, BufferType::DRAM, TensorMemoryLayout::Interleaved, + Layout::RowMajor, mlir::tt::DataType::Float32); + + ASSERT_EQ(optimizerOverridesHandler.toString(), options); +} diff --git a/test/unittests/Optimizer/TestShardSolver.cpp b/test/unittests/Optimizer/TestShardSolver.cpp index c2f73b8008..c43eacce71 100644 --- a/test/unittests/Optimizer/TestShardSolver.cpp +++ b/test/unittests/Optimizer/TestShardSolver.cpp @@ -94,13 +94,15 @@ class ShardSolverBase : public ::testing::Test { &context, getTensorRankedType().getShape(), builder.getF32Type(), memorySpace, mlir::tt::GridAttr::get(&context, {gridWidth, gridHeight}), - tensorMemoryLayout)}; + mlir::tt::ttnn::TensorMemoryLayoutAttr::get(&context, + tensorMemoryLayout))}; } else { legalLayouts[op].push_back(TTNNLayoutAttr::get( &context, getTensorRankedType().getShape(), builder.getF32Type(), memorySpace, mlir::tt::GridAttr::get(&context, {gridWidth, gridHeight}), - tensorMemoryLayout)); + mlir::tt::ttnn::TensorMemoryLayoutAttr::get(&context, + tensorMemoryLayout))); } } diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index b1c679a507..8d66b68695 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -1,6 +1,6 @@ include(ExternalProject) -set(TT_METAL_VERSION "f16cadfabebd7654baef73e4ac2c3240b12b0d1d") +set(TT_METAL_VERSION "7768e89929915f40f7fdb45e352dd4b83f335168") if ("$ENV{ARCH_NAME}" STREQUAL "grayskull") set(ARCH_NAME "grayskull") @@ -22,12 +22,13 @@ set(TTMETAL_INCLUDE_DIRS ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/tt_metal ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/tt_metal/third_party/umd - ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/tt_metal/third_party/umd/device + ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/tt_metal/third_party/umd/device/api ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/tt_metal/hw/inc ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/tt_metal/hw/inc/${ARCH_NAME} ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/tt_metal/hw/inc/${ARCH_EXTRA_DIR} ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/tt_metal/third_party/umd/src/firmware/riscv/${ARCH_NAME} ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/tt_eager + ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal-build/include ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/.cpmcache/reflect/e75434c4c5f669e4a74e4d84e0a30d7249c1e66f ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/.cpmcache/nanomsg/28cc32d5bdb6a858fe53b3ccf7e923957e53eada/include ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal/.cpmcache/fmt/73b5ec45edbd92babfd91c3777a9e1ab9cac8238/include @@ -39,6 +40,7 @@ set(TTMETAL_INCLUDE_DIRS set(TTMETAL_LIBRARY_DIR ${PROJECT_SOURCE_DIR}/third_party/tt-metal/src/tt-metal-build/lib) set(TTNN_LIBRARY_PATH ${TTMETAL_LIBRARY_DIR}/_ttnn.so) set(TTMETAL_LIBRARY_PATH ${TTMETAL_LIBRARY_DIR}/libtt_metal.so) +set(DEVICE_LIBRARY_PATH ${TTMETAL_LIBRARY_DIR}/libdevice.so) if (TT_RUNTIME_ENABLE_PERF_TRACE) set(TRACY_LIBRARY_PATH ${TTMETAL_LIBRARY_DIR}/libtracy.so) else() @@ -48,6 +50,7 @@ endif() set(TTMETAL_LIBRARY_DIR ${TTMETAL_LIBRARY_DIR} PARENT_SCOPE) set(TTNN_LIBRARY_PATH ${TTNN_LIBRARY_PATH} PARENT_SCOPE) set(TTMETAL_LIBRARY_PATH ${TTMETAL_LIBRARY_PATH} PARENT_SCOPE) +set(DEVICE_LIBRARY_PATH ${DEVICE_LIBRARY_PATH} PARENT_SCOPE) set(TRACY_LIBRARY_PATH ${TRACY_LIBRARY_PATH} PARENT_SCOPE) ExternalProject_Add( @@ -65,13 +68,17 @@ ExternalProject_Add( GIT_REPOSITORY https://github.com/tenstorrent/tt-metal.git GIT_TAG ${TT_METAL_VERSION} GIT_PROGRESS ON - BUILD_BYPRODUCTS ${TTNN_LIBRARY_PATH} ${TTMETAL_LIBRARY_PATH} ${TRACY_LIBRARY_PATH} + BUILD_BYPRODUCTS ${TTNN_LIBRARY_PATH} ${TTMETAL_LIBRARY_PATH} ${TRACY_LIBRARY_PATH} ${DEVICE_LIBRARY_PATH} ) +ExternalProject_Add_StepTargets(tt-metal download configure) +set_target_properties(tt-metal-download PROPERTIES EXCLUDE_FROM_ALL TRUE) +set_target_properties(tt-metal-configure PROPERTIES EXCLUDE_FROM_ALL TRUE) + set_target_properties(tt-metal PROPERTIES EXCLUDE_FROM_ALL TRUE) -list(APPEND library_names TTNN_LIBRARY TTMETAL_LIBRARY) -list(APPEND library_paths ${TTNN_LIBRARY_PATH} ${TTMETAL_LIBRARY_PATH}) +list(APPEND library_names TTNN_LIBRARY TTMETAL_LIBRARY DEVICE_LIBRARY) +list(APPEND library_paths ${TTNN_LIBRARY_PATH} ${TTMETAL_LIBRARY_PATH} ${DEVICE_LIBRARY_PATH}) if (TT_RUNTIME_ENABLE_PERF_TRACE) list(APPEND library_names TRACY_LIBRARY) diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index e558d0567e..e5a62f9c5a 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(ttmlir-opt) +add_subdirectory(ttmlir-lsp-server) add_subdirectory(ttmlir-translate) add_subdirectory(explorer) diff --git a/tools/explorer/CMakeLists.txt b/tools/explorer/CMakeLists.txt index 7ad0791b87..44613b2671 100644 --- a/tools/explorer/CMakeLists.txt +++ b/tools/explorer/CMakeLists.txt @@ -17,7 +17,7 @@ ExternalProject_Add( add_custom_target(explorer COMMENT "Building tt-explorer... ${TTMLIR_BIN_DIR}" - COMMAND pip install ${CMAKE_CURRENT_SOURCE_DIR}/tt_adapter + COMMAND pip install $<$:-e> ${CMAKE_CURRENT_SOURCE_DIR}/tt_adapter COMMAND pip install ${CMAKE_CURRENT_SOURCE_DIR}/model-explorer/src/model-explorer/src/server/package DEPENDS TTMLIRPythonModules model-explorer ttrt diff --git a/tools/explorer/tt_adapter/src/tt_adapter/main.py b/tools/explorer/tt_adapter/src/tt_adapter/main.py index 2bb3ece81a..3876a09112 100644 --- a/tools/explorer/tt_adapter/src/tt_adapter/main.py +++ b/tools/explorer/tt_adapter/src/tt_adapter/main.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 from typing import Dict import model_explorer -from . import ttir, runner, utils +from . import runner, utils, mlir import dataclasses import enum @@ -46,7 +46,7 @@ def convert( module = utils.parse_mlir_file(model_path) # Convert TTIR to Model Explorer Graphs and Display/Return - graph = ttir.ttir_to_graph(module) + graph = mlir.build_graph(module) return {"graphs": [graph]} def execute( @@ -70,9 +70,9 @@ def execute( memory_layout_analysis_enabled = False memory_layout_analysis_policy = None - ttnn_ir = self.model_runner.run( + perf_data = self.model_runner.run( model_path, memory_layout_analysis_enabled, memory_layout_analysis_policy ) # TODO(odjuricic, #933) Parse TTNN IR and return the post optimized graph. - return {"graphs": []} + return utils.to_adapter_format({"perf_data": perf_data}) diff --git a/tools/explorer/tt_adapter/src/tt_adapter/mlir.py b/tools/explorer/tt_adapter/src/tt_adapter/mlir.py new file mode 100644 index 0000000000..e48eca4a8d --- /dev/null +++ b/tools/explorer/tt_adapter/src/tt_adapter/mlir.py @@ -0,0 +1,599 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 +# Utility library for parsing MLIR + +from collections import defaultdict +from model_explorer import graph_builder + +from ttmlir.dialects import tt, ttnn, ttir +from ttmlir import ir + + +def get_loc_str(loc): + try: + # Constant loc( at the start of the location and ) at the end. Can just strip these characters + loc = str(loc) + if loc.startswith("loc(") and loc.endswith(")"): + res = str(loc)[4:-1] + else: + res = loc # This is a fallback to just visualize / see what the loc is if not processable. + except: + res = "unknown" + return res + + +class AttrHandler: + """ + A class that handles parsing and registering handlers for MLIR attribute types. + """ + + ATTR_HANDLERS = {} + + @staticmethod + def default_parser(attr): + return [graph_builder.KeyValue(key=attr.name, value=str(attr.attr))] + + @staticmethod + def parse_attr(attr): + if attr.name in AttrHandler.ATTR_HANDLERS: + return AttrHandler.ATTR_HANDLERS[attr.name](attr.attr) + else: + # Unknown Attr Type, return default parser + return AttrHandler.default_parser(attr) + + @staticmethod + def register_handler(attr_name): + """ + Decorator function to register a handler for a specific attribute name. + + Usage: + + @AttrHandler.register_handler("attr_name") + def parse_attr_name(attr: ir.Attribute) -> List[graph_builder.KeyValue]: + pass + + registers a handler for any NamedAttribute present in the MLIR module with the name "attr_name". + + The handler itself is the function that is decorated with this decorator. It must follow the function signature of + `parse_attr_name` as shown above. + """ + + def decorator(handler): + AttrHandler.ATTR_HANDLERS[attr_name] = handler + return handler + + return decorator + + +@AttrHandler.register_handler("tt.device") +def parse_tt_device(attr): + device = tt.ir.DeviceAttr.maybe_downcast(attr) + result = [] + result.append( + graph_builder.KeyValue( + key="device_chip_ids", value=", ".join(map(str, device.chip_ids)) + ) + ) + result.append( + graph_builder.KeyValue( + key="device_grid_shape", value=str(device.grid_attr.shape) + ) + ) + if device.mesh_shape: + result.append( + graph_builder.KeyValue( + key="device_mesh_shape", value=str(device.mesh_shape) + ) + ) + result.append(graph_builder.KeyValue(key="device_l1_map", value=str(device.l1_map))) + result.append( + graph_builder.KeyValue(key="device_dram_map", value=str(device.dram_map)) + ) + return result + + +@AttrHandler.register_handler("tt.system_desc") +def parse_tt_system_desc(attr): + system_desc = tt.ir.SystemDescAttr.maybe_downcast(attr) + result = [] + for i, chip_desc, chip_coord, chip_capability in zip( + system_desc.chip_desc_indices, + system_desc.chip_descs, + system_desc.chip_coords, + system_desc.chip_capabilities, + ): + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-arch", value=str(tt.Arch(chip_desc.arch.arch_as_int)) + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-capability", + value=str(tt.ChipCapability(chip_capability.capability_as_int)), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-coord", + value="x".join( + map( + str, + (chip_coord.rack, chip_coord.shelf, chip_coord.y, chip_coord.x), + ) + ), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-dram-channel-size", + value=str(chip_desc.dram_channel_size), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-dram-unreserved-base", + value=str(chip_desc.dram_unreserved_base), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-dram-unreserved-end", + value=str(chip_desc.dram_unreserved_end), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-erisc-l1-unreserved-size", + value=str(chip_desc.erisc_l1_unreserved_base), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-grid", value="x".join(map(str, chip_desc.grid)) + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-l1-size", value=str(chip_desc.l1_size) + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-l1-unreserved-base", + value=str(chip_desc.l1_unreserved_base), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-noc-dram-address-align-bytes", + value=str(chip_desc.noc_dram_address_align_bytes), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-noc-l1-address-align-bytes", + value=str(chip_desc.noc_l1_address_align_bytes), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-num-cbs", value=str(chip_desc.num_cbs) + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-num-dram-channels", + value=str(chip_desc.num_dram_channels), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-pcie-address-align-bytes", + value=str(chip_desc.pcie_address_align_bytes), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-usable-dram-channel-size", + value=str(chip_desc.usable_dram_channel_size), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-usable-l1-size", value=str(chip_desc.usable_l1_size) + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-supported-data-types", + value=", ".join( + [ + str(tt.DataType(dt.data_type_as_int)) + for dt in chip_desc.supported_data_types + ] + ), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-supported-tile-sizes", + value=", ".join( + [ + "x".join(map(str, (tsize.y, tsize.x))) + for tsize in chip_desc.supported_tile_sizes + ] + ), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-dram-core-coords", + value=", ".join( + [ + "x".join(map(str, (coord.y, coord.x))) + for coord in chip_desc.chip_physical_cores.dram + ] + ), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-eth-core-coords", + value=", ".join( + [ + "x".join(map(str, (coord.y, coord.x))) + for coord in chip_desc.chip_physical_cores.eth + ] + ), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-eth-inactive-core-coords", + value=", ".join( + [ + "x".join(map(str, (coord.y, coord.x))) + for coord in chip_desc.chip_physical_cores.eth_inactive + ] + ), + ) + ) + result.append( + graph_builder.KeyValue( + key=f"chip#{i}-worker-core-coords", + value=", ".join( + [ + "x".join(map(str, (coord.y, coord.x))) + for coord in chip_desc.chip_physical_cores.worker + ] + ), + ) + ) + return result + + +@AttrHandler.register_handler("mesh_shape") +def parse_mesh_shape(attr): + mesh_shape = ttnn.ir.MeshShapeAttr.maybe_downcast(attr) + return [ + graph_builder.KeyValue( + key="mesh_shape", value="x".join(map(str, (mesh_shape.y, mesh_shape.x))) + ) + ] + + +@AttrHandler.register_handler("layout") +def parse_layout(attr): + # This is for parsing TTNN Layouts (Enum) + layout = ttnn.ir.LayoutAttr.maybe_downcast(attr) + return [graph_builder.KeyValue(key="layout", value=str(ttnn.Layout(layout.value)))] + + +@AttrHandler.register_handler("memory_config") +def parse_memory_config(attr): + memory_config = ttnn.ir.MemoryConfigAttr.maybe_downcast(attr) + result = [] + result.append( + graph_builder.KeyValue( + key="buffer-type", + value=str(ttnn.BufferType(memory_config.buffer_type.value)), + ) + ) + result.append( + graph_builder.KeyValue( + key="shard-shape", + value="x".join(map(str, memory_config.shard_spec.shard_shape.shape)), + ) + ) + result.append( + graph_builder.KeyValue( + key="tensor-memory-layout", + value=str( + ttnn.TensorMemoryLayout(memory_config.tensor_memory_layout.value) + ), + ) + ) + return result + + +@AttrHandler.register_handler("force") +def parse_force(attr): + return [graph_builder.KeyValue(key="force", value=str(attr.value))] + + +@AttrHandler.register_handler("dtype") +def parse_dtype(attr): + dtype = tt.ir.DataTypeAttr.maybe_downcast(attr) + return [ + graph_builder.KeyValue( + key="dtype", value=str(tt.DataType(dtype.data_type_as_int)) + ) + ] + + +@AttrHandler.register_handler("shape") +def parse_shape(attr): + shape = ttnn.ir.ShapeAttr.maybe_downcast(attr) + if not shape: + return [graph_builder.KeyValue(key="shape", value=str(attr))] + return [graph_builder.KeyValue(key="shape", value="x".join(map(str, shape.shape)))] + + +@AttrHandler.register_handler("operandSegmentSizes") +def parse_operandSegmentSizes(attr): + return [graph_builder.KeyValue(key="operandSegmentSizes", value=str(list(attr)))] + + +@AttrHandler.register_handler("dimension") +def parse_dimension(attr): + return [graph_builder.KeyValue(key="dimension", value=str(attr.value))] + + +@AttrHandler.register_handler("tt.layout") +def parse_tt_layout(attr): + layout = tt.ir.MetalLayoutAttr.maybe_downcast(attr) + result = [] + result.append(graph_builder.KeyValue(key="linear", value=str(layout.linear))) + result.append( + graph_builder.KeyValue( + key="memory_space", value=str(tt.MemorySpace(layout.memory_space_as_int)) + ) + ) + result.append( + graph_builder.KeyValue( + key="memory_layout", + value=str(tt.TensorMemoryLayout(layout.memory_layout_as_int)), + ) + ) + result.append( + graph_builder.KeyValue( + key="grid_shape", value="x".join(map(str, layout.grid_attr.shape)) + ) + ) + result.append( + graph_builder.KeyValue(key="memref_shape", value=str(layout.memref.shape)) + ) + result.append( + graph_builder.KeyValue(key="memref_rank", value=str(layout.memref.rank)) + ) + tile_type = tt.ir.TileType.maybe_downcast(layout.memref.element_type) + if tile_type is not None: + result.append( + graph_builder.KeyValue( + key="tile_datatype", value=str(tt.DataType(tile_type.data_type_as_int)) + ) + ) + result.append( + graph_builder.KeyValue( + key="tile_shape", value="x".join(map(str, tile_type.shape)) + ) + ) + return result + + +@AttrHandler.register_handler("ttnn_layout") +def parse_ttnn_ttnn_layout(attr): + layout = ttnn.ir.TTNNLayoutAttr.maybe_downcast(attr) + result = [] + result.append(graph_builder.KeyValue(key="linear", value=str(layout.linear))) + result.append( + graph_builder.KeyValue( + key="memory_layout", + value=str(ttnn.TensorMemoryLayout(layout.memory_layout_as_int)), + ) + ) + result.append( + graph_builder.KeyValue( + key="grid_shape", value="x".join(map(str, layout.grid_attr.shape)) + ) + ) + result.append( + graph_builder.KeyValue(key="memref_shape", value=str(layout.memref.shape)) + ) + result.append( + graph_builder.KeyValue(key="memref_rank", value=str(layout.memref.rank)) + ) + buffer_attr = ttnn.ir.BufferTypeAttr.maybe_downcast(layout.memref.memory_space) + result.append( + graph_builder.KeyValue( + key="memref_memory_space", value=str(ttnn.BufferType(buffer_attr.value)) + ) + ) + return result + + +class OpHandler: + def __init__(self, op): + self.op = op + + def get_id(self, names: defaultdict): + name = get_loc_str(self.op.location) + name_num = names[name] + id = name + "__" + str(name_num) + names[name] += 1 + return id + + def get_namespace(self, parent_op=None): + op = self.op if not parent_op else parent_op + name = get_loc_str(op.location) + if op.parent and op.parent.name != "builtin.module": + return self.get_namespace(op.parent) + "/" + name + return name + + def get_attributes(self): + # Parse Op Attributes themselves + result = [] + for attr in self.op.attributes: + result.extend(AttrHandler.parse_attr(attr)) + return result + + def make_graph_node(self, name_dict): + return graph_builder.GraphNode( + id=self.get_id(name_dict), + label=self.op.name, + namespace=self.get_namespace(), + attrs=self.get_attributes(), + ) + + def make_constant_node(self, name_dict, constant_name): + return graph_builder.GraphNode( + id=self.get_id(name_dict), + label=constant_name, + namespace=self.get_namespace(), + ) + + +EMPTY_OPS = [ + "ttnn.empty", + "tensor.empty", +] + +FILTERED_OPS = [ + "ttnn.deallocate", + "ttnn.get_device", +] + + +def get_locs(module): + name_dict = defaultdict(int) + + for op in module.body.operations: + for region in op.regions: + for block in region.blocks: + for op in block.operations: + op = OpHandler(op) + _id = op.get_id(name_dict) + # This will now populate name_dict with all of the locations that are relevant + + # The keys will be all the unique locations, and the values will be the number of times that location appears + return name_dict + + +def build_graph(module): + name_dict = defaultdict(int) + output_connections = defaultdict(int) + graph = graph_builder.Graph(id="tt-graph") + + op_to_graph_node = {} + + module_op = OpHandler(module.operation) + module_attrs = module_op.get_attributes() + module_attrs = dict((attr.key, attr.value) for attr in module_attrs) + # Add module attributes to the graph as "namespace attributes" + group_node_attrs = {} + group_node_attrs[module_op.get_namespace()] = module_attrs + + for op in module.body.operations: + append_later = [] + for region in op.regions: + for block in region.blocks: + for op in block.operations: + # Create all the nodes and constants in the first pass. + operation = OpHandler(op) + graph_node = operation.make_graph_node(name_dict) + + if op.name in EMPTY_OPS: + append_later.append(graph_node) + elif op.name not in FILTERED_OPS: + graph.nodes.append(graph_node) + + op_to_graph_node[op] = graph_node + + for operand in op.operands: + if isinstance(operand, ir.Value) and not isinstance( + operand.owner, ir.Operation + ): + # If the owner is not an op, then it is a constant provided from the toplevel FuncOp. + + # This is a constant and we need to create a node for it. + operand_node = operation.make_constant_node( + name_dict, operand.get_name() + ) + graph.nodes.append(operand_node) + op_to_graph_node[operand] = operand_node + + # This puts the node at the far right when viewing which is a bit more consistant with it being the last operand. + for node in append_later: + graph.nodes.append(node) + + for op in block.operations: + # Create all the edges in the second pass. + for operand_index, operand in enumerate(op.operands): + if operand.owner == block: + source_node = op_to_graph_node[operand] + else: + source_node = op_to_graph_node[operand.owner] + + target_node = op_to_graph_node[op] + + target_node.incomingEdges.append( + graph_builder.IncomingEdge( + sourceNodeId=source_node.id, + sourceNodeOutputId=output_connections[source_node.id], + targetNodeInputId=operand_index, + ) + ) + + output_attrs = [] + if isinstance(operand.type, ir.RankedTensorType): + output_attrs = [ + graph_builder.KeyValue( + key="shape", value=str(operand.type.shape) + ), + graph_builder.KeyValue( + key="dtype", value=str(operand.type.element_type) + ), + graph_builder.KeyValue( + key="rank", value=str(operand.type.rank) + ), + ] + if hasattr(operand.type, "encoding") and operand.type.encoding: + if "ttnn_layout" in str(operand.type.encoding): + output_attrs.extend( + AttrHandler.parse_attr( + operand.type.encoding.get_named("ttnn_layout") + ) + ) + else: + # Parse as a standard layout + output_attrs.extend( + AttrHandler.parse_attr( + operand.type.encoding.get_named("tt.layout") + ) + ) + source_node.outputsMetadata.append( + graph_builder.MetadataItem( + id=str(output_connections[source_node.id]), + attrs=[ + graph_builder.KeyValue( + key="__tensor_tag", value=target_node.label + ), + ] + + output_attrs, + ) + ) + output_connections[source_node.id] += 1 + graph.groupNodeAttributes = group_node_attrs + return graph diff --git a/tools/explorer/tt_adapter/src/tt_adapter/runner.py b/tools/explorer/tt_adapter/src/tt_adapter/runner.py index 65da2ec2b7..205944acd4 100644 --- a/tools/explorer/tt_adapter/src/tt_adapter/runner.py +++ b/tools/explorer/tt_adapter/src/tt_adapter/runner.py @@ -9,8 +9,9 @@ # os.environ["TTRT_LOGGER_LEVEL"] = "ERROR" from ttrt import API as ttrt import ttmlir.passes -from . import utils +from . import utils, mlir import pandas as pd +from model_explorer import node_data_builder class ModelRunner: @@ -69,6 +70,9 @@ def run( module = utils.parse_mlir_file(model_path) + # Collect unique locations + name_dict = mlir.get_locs(module) + try: print("Running MLIR compile: TTIR to TTNN Backend Pipeline") print("With options: ", options_string) @@ -131,8 +135,34 @@ def run( "DEVICE FW DURATION [ns]", "CORE COUNT", "OUTPUT_0_MEMORY", + "LOC", ] perf = perf[columns] print(perf) - print("Total device duration: ", perf["DEVICE FW DURATION [ns]"].sum(), "ns") + print(f"Total device duration: {perf['DEVICE FW DURATION [ns]'].sum()}ns") + + # Create the node_data type here + timing_data = list(zip(perf["LOC"], perf["DEVICE FW DURATION [ns]"])) + results = {} + for loc, duration in timing_data: + loc = mlir.get_loc_str(loc).replace("'", '"') + if loc in name_dict: + for i in range(name_dict[loc]): + results[f"{loc}__{i}"] = node_data_builder.NodeDataResult( + value=duration + ) + else: + print( + f"Location {loc} not found in graph, ops data for this op was not reported." + ) + + gradient = [ + node_data_builder.GradientItem(stop=0, bgColor="yellow"), + node_data_builder.GradientItem(stop=1, bgColor="red"), + ] + + data = node_data_builder.GraphNodeData(results=results, gradient=gradient) + + res = node_data_builder.ModelNodeData(graphsData={"tt-graph": data}) + return res diff --git a/tools/explorer/tt_adapter/src/tt_adapter/ttir.py b/tools/explorer/tt_adapter/src/tt_adapter/ttir.py deleted file mode 100644 index 76cd470b0f..0000000000 --- a/tools/explorer/tt_adapter/src/tt_adapter/ttir.py +++ /dev/null @@ -1,149 +0,0 @@ -# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC -# -# SPDX-License-Identifier: Apache-2.0 -# Library to manipulate TTIR Modules - -from model_explorer import graph_builder -from ttmlir.dialects import tt, ttir, ttkernel -from collections import defaultdict - - -def get_loc_str(loc): - # TODO(odjuricic) Need to expose this in python bindings, if possible. - try: - res = str(loc).split('"')[1] - except: - res = "unknown" - return res - - -def create_id(op, name_dict): - name = get_loc_str(op.location) - name_num = name_dict[name] - id = name + "__" + str(name_num) - name_dict[name] += 1 - return id - - -def get_attrs(op): - result = [] - for attr in op.attributes: - result.append(graph_builder.KeyValue(key=attr.name, value=str(attr.attr))) - return result - - -def create_namespace(op): - name = get_loc_str(op.location) - if op.parent and op.parent.name != "builtin.module": - return create_namespace(op.parent) + "/" + name - return name - - -def get_layout_attrs(tensor): - attrs = [ - graph_builder.KeyValue(key="shape", value=str(tensor.type.shape)), - graph_builder.KeyValue( - key="element_type", - value=str(tensor.type.element_type), - ), - graph_builder.KeyValue(key="rank", value=str(tensor.type.rank)), - ] - - if hasattr(tensor.type, "encoding") and tensor.type.encoding: - layout = tt.ir.LayoutAttr.getLayout(tensor.type) - attrs.extend( - [ - graph_builder.KeyValue( - key="Memory Space", - value=str(tt.MemorySpace(layout.memory_space_as_int)), - ), - graph_builder.KeyValue( - key="Memory Layout", - value=str(tt.TensorMemoryLayout(layout.memory_layout_as_int)), - ), - graph_builder.KeyValue( - key="Grid Shape", - value=str(list(layout.grid_attr.shape)), - ), - ] - ) - - return attrs - - -def ttir_to_graph(module): - # Can assume that to-layout pass has already been run on the module. - name_dict = defaultdict(int) - output_connections = defaultdict(int) - graph = graph_builder.Graph(id="ttir-graph") - - op_to_graph_node = dict() - - for op in module.body.operations: - append_later = [] - for region in op.regions: - for block in region.blocks: - for op in block.operations: - # Create all the nodes and constants in the first pass. - graph_node = graph_builder.GraphNode( - id=create_id(op, name_dict), - label=op.name, - namespace=create_namespace(op), - attrs=get_attrs(op), - ) - - if op.name == "tensor.empty": - append_later.append(graph_node) - else: - graph.nodes.append(graph_node) - - op_to_graph_node[op] = graph_node - - for operand in op.operands: - if operand.owner == block and operand not in op_to_graph_node: - # This is a constant and we need to create a node for it. - operand_node = graph_builder.GraphNode( - id=create_id(op, name_dict), - label=operand.get_name(), - namespace=create_namespace(op), - ) - graph.nodes.append(operand_node) - op_to_graph_node[operand] = operand_node - - # This puts the node at the far right when viewing which is a bit more consistant with it being the last operand. - for node in append_later: - graph.nodes.append(node) - - for op in block.operations: - # Create all the edges in the second pass. - for operand_index, operand in enumerate(op.operands): - if operand.owner == block: - source_node = op_to_graph_node[operand] - else: - source_node = op_to_graph_node[operand.owner] - - target_node = op_to_graph_node[op] - - target_node.incomingEdges.append( - graph_builder.IncomingEdge( - sourceNodeId=source_node.id, - sourceNodeOutputId=output_connections[source_node.id], - targetNodeInputId=operand_index, - ) - ) - - output_attrs = get_layout_attrs(operand) - source_node.outputsMetadata.append( - graph_builder.MetadataItem( - id=str(output_connections[source_node.id]), - attrs=[ - graph_builder.KeyValue( - key="__tensor_tag", value=target_node.label - ), - ] - + output_attrs, - ) - ) - output_connections[source_node.id] += 1 - - return graph diff --git a/tools/explorer/tt_adapter/src/tt_adapter/utils.py b/tools/explorer/tt_adapter/src/tt_adapter/utils.py index fe68d89ac5..4b404a204b 100644 --- a/tools/explorer/tt_adapter/src/tt_adapter/utils.py +++ b/tools/explorer/tt_adapter/src/tt_adapter/utils.py @@ -2,12 +2,21 @@ # # SPDX-License-Identifier: Apache-2.0 import ttmlir +from dataclasses import make_dataclass def parse_mlir_file(model_path): with ttmlir.ir.Context() as ctx, open(model_path, "r") as model_file: - ttmlir.dialects.ttkernel.register_dialect(ctx) ttmlir.dialects.ttir.register_dialect(ctx) ttmlir.dialects.tt.register_dialect(ctx) - module = ttmlir.ir.Module.parse("".join(model_file.readlines()), ctx) + ttmlir.dialects.ttnn.register_dialect(ctx) + module = ttmlir.ir.Module.parse(model_file.read(), ctx) return module + + +def to_dataclass(obj: dict, dc_name: str = "tempClass"): + return make_dataclass(dc_name, ((k, type(v)) for k, v in obj.items()))(**obj) + + +def to_adapter_format(obj: dict): + return {"graphs": [to_dataclass(obj)]} diff --git a/tools/ttmlir-lsp-server/CMakeLists.txt b/tools/ttmlir-lsp-server/CMakeLists.txt new file mode 100644 index 0000000000..1dd058a715 --- /dev/null +++ b/tools/ttmlir-lsp-server/CMakeLists.txt @@ -0,0 +1,18 @@ +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) + +set(LIBS ${dialect_libs} ${conversion_libs} ${extension_libs} + MLIROptLib + MLIRTargetCpp + TTMLIRStatic + MLIRLspServerLib +) + +add_llvm_executable(ttmlir-lsp-server ttmlir-lsp-server.cpp DISABLE_LLVM_LINK_LLVM_DYLIB) +llvm_update_compile_flags(ttmlir-lsp-server) +target_link_libraries(ttmlir-lsp-server PRIVATE ${LIBS}) + +mlir_check_all_link_libraries(ttmlir-lsp-server) + +install(TARGETS ttmlir-lsp-server DESTINATION ${CMAKE_INSTALL_BINDIR} COMPONENT Test EXCLUDE_FROM_ALL) diff --git a/tools/ttmlir-lsp-server/ttmlir-lsp-server.cpp b/tools/ttmlir-lsp-server/ttmlir-lsp-server.cpp new file mode 100644 index 0000000000..d23425e968 --- /dev/null +++ b/tools/ttmlir-lsp-server/ttmlir-lsp-server.cpp @@ -0,0 +1,15 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "mlir/InitAllDialects.h" +#include "ttmlir/RegisterAll.h" + +#include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + mlir::tt::registerAllDialects(registry); + + return mlir::failed(mlir::MlirLspServerMain(argc, argv, registry)); +} diff --git a/tools/ttnn-standalone/CMakeLists.txt b/tools/ttnn-standalone/CMakeLists.txt index 7de585bf0a..23c78c7ca9 100644 --- a/tools/ttnn-standalone/CMakeLists.txt +++ b/tools/ttnn-standalone/CMakeLists.txt @@ -91,13 +91,6 @@ set(LINK_LIBS yaml-cpp pthread - # The below libs have been added to tt-metal repo at some point, but are not - # currently needed by the targets here - leaving them commented here for - # reference - # - # nng - # uv - # TTNN # _ttnn # Why doesn't this work? $ENV{TT_METAL_HOME}-build/lib/_ttnn.so diff --git a/tools/ttnn-standalone/README.md b/tools/ttnn-standalone/README.md index 816cfe1cf5..619e52d1c3 100644 --- a/tools/ttnn-standalone/README.md +++ b/tools/ttnn-standalone/README.md @@ -14,7 +14,7 @@ Third party ML models (PyTorch, Jax, ONNX, ...) can be compiled to a set of TTNN ```bash # Compile a model to C++ code -./build/bin/ttmlir-opt --ttir-load-system-desc --ttir-implicit-device --ttir-layout --convert-ttir-to-ttnn --ttnn-decompose-layouts --ttnn-deallocate --convert-ttnn-to-emitc test/ttmlir/Silicon/TTNN/emitc/simple_add.mlir | ./build/bin/ttmlir-translate --mlir-to-cpp -allow-unregistered-dialect +./build/bin/ttmlir-opt --ttir-to-emitc-pipeline test/ttmlir/Silicon/TTNN/emitc/simple_add.mlir | ./build/bin/ttmlir-translate --mlir-to-cpp # Copy paste the generated function into `ttnn-standalone.cpp`. diff --git a/tools/ttnn-standalone/ttnn-standalone.cpp b/tools/ttnn-standalone/ttnn-standalone.cpp index dff9afff43..0dee60f134 100644 --- a/tools/ttnn-standalone/ttnn-standalone.cpp +++ b/tools/ttnn-standalone/ttnn-standalone.cpp @@ -5,11 +5,9 @@ #include "ttnn-precompiled.hpp" // To generate forward function, run: -// ./build/bin/ttmlir-opt --ttir-load-system-desc --ttir-implicit-device -// --ttir-layout --convert-ttir-to-ttnn --ttnn-decompose-layouts -// --ttnn-deallocate --convert-ttnn-to-emitc +// ./build/bin/ttmlir-opt --ttir-to-emitc-pipeline // test/ttmlir/Silicon/TTNN/emitc/simple_add.mlir | ./build/bin/ttmlir-translate -// --mlir-to-cpp -allow-unregistered-dialect +// --mlir-to-cpp ttnn::Tensor forward(ttnn::Tensor v1, ttnn::Tensor v2) { ttnn::Device *v3 = ttnn::DeviceGetter::getInstance();