diff --git a/.github/workflows/ci-sharktank.yml b/.github/workflows/ci-sharktank.yml
index 1d3960b43..6e8cee3bb 100644
--- a/.github/workflows/ci-sharktank.yml
+++ b/.github/workflows/ci-sharktank.yml
@@ -22,6 +22,45 @@ concurrency:
   cancel-in-progress: true

 jobs:
+  test_punet:
+    name: "Integration Tests - punet"
+    runs-on: nodai-amdgpu-mi250-x86-64
+    env:
+      PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
+    steps:
+      - name: "Setting up Python"
+        id: setup_python
+        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+        with:
+          python-version: 3.11
+
+      - name: "Checkout Code"
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+
+      - name: Cache Pip Packages
+        uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
+        id: cache-pip
+        with:
+          path: ${{ env.PIP_CACHE_DIR }}
+          key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }}
+
+      - name: Install pip deps
+        run: |
+          python -m pip install --no-compile --upgrade pip
+          # Note: We install in three steps in order to satisfy requirements
+          # from non-default locations first. Installing the PyTorch CPU
+          # wheels saves multiple minutes and a lot of bandwidth on runner setup.
+          pip install --no-compile -r pytorch-cpu-requirements.txt
+          pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/
+
+          # Update to the latest iree packages.
+          pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \
+            iree-base-compiler iree-base-runtime --src deps \
+            -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
+      - name: Run punet tests
+        run: |
+          pytest -v sharktank/ -m model_punet
+
   test:
     name: "Unit Tests and Type Checking"
     strategy:
diff --git a/sharktank/integration/models/punet/integration_test.py b/sharktank/integration/models/punet/integration_test.py
index 182b37a50..45af24004 100644
--- a/sharktank/integration/models/punet/integration_test.py
+++ b/sharktank/integration/models/punet/integration_test.py
@@ -89,12 +89,13 @@ def sdxl_fp16_dataset(sdxl_fp16_base_files, temp_dir):
 def sdxl_int8_base_files():
     from huggingface_hub import hf_hub_download

-    REPO_ID = "amd-shark/sdxl-quant-models"
-    REVISION = "942e771bf0c2657a8b33380103d04747a75dfa4a"
+    REPO_ID = "amd-shark/sdxl-quant-int8"
+    SUBFOLDER = "mi300_all_sym_8_step14_fp32"
+    REVISION = "efda8afb35fd72c1769e02370b320b1011622958"

     def download(filename):
         return hf_hub_download(
-            repo_id=REPO_ID, subfolder="unet/int8", filename=filename, revision=REVISION
+            repo_id=REPO_ID, subfolder=SUBFOLDER, filename=filename, revision=REVISION
         )

     return {
diff --git a/sharktank/sharktank/layers/linear.py b/sharktank/sharktank/layers/linear.py
index b679dccde..acd9b8a37 100644
--- a/sharktank/sharktank/layers/linear.py
+++ b/sharktank/sharktank/layers/linear.py
@@ -31,9 +31,8 @@ class LinearLayer(ThetaLayer):
         x = x * premul_input
     matmul(x, weight.T) + bias

-    fake_quant exists to allow export without adding dequant ops.
-    when fake_quant is True, the op will in quant dequant fashion.
-    When false, it will keep quantized types.
+    fake_quant exists only to allow q_input to act as QDQ (quantize-dequantize).
+    When fake_quant is False, q_input will quantize normally.
     ```
     """

@@ -43,7 +42,7 @@ def __init__(
         *,
         weight_name: str = "weight",
         bias_name: str = "bias",
-        fake_quant: bool = True,
+        fake_quant: bool = False,
     ):
         super().__init__(theta)
         self._simulate_native_quant = True
@@ -74,21 +73,23 @@ def forward(self, x):
             x = q_input.quantize(x)
             if self.fake_quant:
                 x = x.unpack().dequant()
-        elif qdq_input is not None and self.fake_quant:
+
+        elif qdq_input is not None:
             x = qdq_input.quantize(x).unpack().dequant()

         y = ops.linear(x, weight, bias)

         # Unconditionally dequantize.
-        if isinstance(y, QuantizedTensor) and not self.fake_quant:
+        if isinstance(y, QuantizedTensor):
             y = y.unpack().dequant()
         # Note that f8_e4m3fnuz types on AMD GPUs accumulate to fp32.
         # We can truncate to fp16 in iree, so we do a cast here
         # to account for this in the IR. This is may not be the right
         # level to do this, but for now its here.
-        if not self.fake_quant and y.dtype == torch.float8_e4m3fnuz:
-            y = ops.to(y, torch.float16)
-            return y
-        if qdq_output is not None and self.fake_quant:
-            y = qdq_output.quantize(y).unpack().dequant()
+        if not isinstance(y, QuantizedTensor):
+            if y.dtype == torch.float8_e4m3fnuz:
+                y = ops.to(y, torch.float16)
+            return y
+        if qdq_output is not None:
+            y = qdq_output.quantize(y).unpack().dequant()
         return y
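
To make the linear.py change easier to follow: with fake_quant enabled, the input is round-tripped through quantize-dequantize (QDQ), so every tensor stays in floating point and no dequant ops need to be added at export time; with fake_quant disabled, real quantized types are carried through the matmul and the accumulator is dequantized once at the end. Below is a minimal, self-contained sketch of those two paths, using plain symmetric int8 scaling as a stand-in for sharktank's quantizer/QuantizedTensor machinery; the helper names and the scale value are illustrative, not the sharktank API.

```
import torch


def fake_quant_int8(x: torch.Tensor, scale: float) -> torch.Tensor:
    """Quantize-dequantize (QDQ): round-trip through an int8 grid, return float.

    This mirrors the fake_quant behaviour: values pick up quantization error,
    but the graph stays in floating point, so export needs no dequant ops.
    """
    q = torch.clamp(torch.round(x / scale), -128, 127)
    return q * scale


def real_quant_int8(x: torch.Tensor, scale: float) -> torch.Tensor:
    """Keep the integer representation (a stand-in for a real quantized tensor)."""
    return torch.clamp(torch.round(x / scale), -128, 127).to(torch.int64)


def linear_forward(x, weight, bias, *, scale=0.05, fake_quant=False):
    if fake_quant:
        # QDQ path: quantize and immediately dequantize, then do a float matmul.
        return fake_quant_int8(x, scale) @ weight.T + bias
    # Real-quantization path: integer matmul, then dequantize the accumulator once.
    acc = real_quant_int8(x, scale) @ real_quant_int8(weight, scale).T
    return acc.to(torch.float32) * (scale * scale) + bias


if __name__ == "__main__":
    torch.manual_seed(0)
    x, w, b = torch.randn(2, 4), torch.randn(3, 4), torch.zeros(3)
    print("fake_quant (QDQ) :", linear_forward(x, w, b, fake_quant=True))
    print("real quantization:", linear_forward(x, w, b, fake_quant=False))
```

The f8_e4m3fnuz branch in the diff follows the same idea: per the comment in the code, those types accumulate to fp32 on AMD GPUs, so the non-quantized result is cast to fp16 to account for the truncation IREE performs.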
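
The new CI job selects tests with `pytest -v sharktank/ -m model_punet`, which collects only tests carrying the model_punet marker. A small sketch of how such a marker is registered and applied, assuming conftest.py-based registration (the real suite may declare its markers elsewhere, e.g. in pyproject.toml):

```
# conftest.py -- register the marker so `pytest -m model_punet` can filter on it.
# (Illustrative only; shown for context, not copied from the sharktank tree.)
def pytest_configure(config):
    config.addinivalue_line(
        "markers", "model_punet: integration tests exercising the punet model"
    )


# integration_test.py (or similar) -- tag the tests the CI job should collect.
import pytest


@pytest.mark.model_punet
def test_punet_smoke():
    # Collected by `pytest -v sharktank/ -m model_punet`; unmarked tests are skipped.
    assert True
```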