
Commit 8aa2667

add bf16 for Tile CUDA executor (#20854)
### Description
Add bf16 support to the Tile CUDA executor.

### Motivation and Context
Required to support the phimm model in ORT training.
1 parent 0babc33 commit 8aa2667

2 files changed, +3 -2 lines changed


docs/OperatorKernels.md

+1 -1
@@ -827,7 +827,7 @@ Do not modify directly.*
 |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
 |ThresholdedRelu|*in* X:**T**<br> *out* Y:**T**|10+|**T** = tensor(double), tensor(float), tensor(float16)|
 |||1+|**T** = tensor(double), tensor(float), tensor(float16)|
-|Tile|*in* input:**T**<br> *in* repeats:**T1**<br> *out* output:**T**<br><br>or<br><br>*in* input:**T**<br> *in* tiles:**T**<br> *in* axis:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br/> **T1** = tensor(int64)|
+|Tile|*in* input:**T**<br> *in* repeats:**T1**<br> *out* output:**T**<br><br>or<br><br>*in* input:**T**<br> *in* tiles:**T**<br> *in* axis:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br/> **T1** = tensor(int64)|
 |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br/> **T1** = tensor(int64)|
 |TopK|*in* X:**T**<br> *in* K:**tensor(int64)**<br> *out* Values:**T**<br> *out* Indices:**I**<br><br>or<br><br>*in* X:**T**<br> *out* Values:**T**<br> *out* Indices:**I**|11+|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
 |||10|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|

onnxruntime/core/providers/cuda/tensor/tile.cc

+2 -1
@@ -36,7 +36,8 @@ ONNX_OPERATOR_KERNEL_EX(
 DataTypeImpl::GetTensorType<double>(),
 DataTypeImpl::GetTensorType<int32_t>(),
 DataTypeImpl::GetTensorType<int64_t>(),
-DataTypeImpl::GetTensorType<MLFloat16>()})
+DataTypeImpl::GetTensorType<MLFloat16>(),
+DataTypeImpl::GetTensorType<BFloat16>()})
 .TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()),
 Tile);
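The diff above only widens the `T` type constraint of the Tile kernel. Tile repeats elements and performs no arithmetic on them, so a 16-bit type can be handled as an opaque payload. The following is a minimal CPU reference sketch of that idea, not the ONNX Runtime CUDA kernel: `Bf16Bits`, `FloatToBf16`, `Bf16ToFloat`, and `TileReference` are hypothetical names introduced here for illustration, and the float conversion assumes simple truncation rather than whatever rounding the real `BFloat16` type uses.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical stand-in for a bfloat16 value: the upper 16 bits of an
// IEEE-754 binary32 float. Not ONNX Runtime's BFloat16 type.
struct Bf16Bits {
  std::uint16_t bits;
};

// Assumption: convert by truncation (drops the low 16 bits, no rounding).
inline Bf16Bits FloatToBf16(float f) {
  std::uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return Bf16Bits{static_cast<std::uint16_t>(u >> 16)};
}

inline float Bf16ToFloat(Bf16Bits b) {
  const std::uint32_t u = static_cast<std::uint32_t>(b.bits) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

// Reference Tile for a row-major tensor: output dim i is shape[i] * repeats[i].
// Elements are only copied, so the same template works for any element type.
template <typename T>
std::vector<T> TileReference(const std::vector<T>& input,
                             const std::vector<std::int64_t>& shape,
                             const std::vector<std::int64_t>& repeats) {
  assert(shape.size() == repeats.size());
  const std::size_t rank = shape.size();

  std::vector<std::int64_t> out_shape(rank);
  std::int64_t out_size = 1;
  for (std::size_t i = 0; i < rank; ++i) {
    out_shape[i] = shape[i] * repeats[i];
    out_size *= out_shape[i];
  }

  std::vector<T> output(static_cast<std::size_t>(out_size));
  for (std::int64_t out_idx = 0; out_idx < out_size; ++out_idx) {
    // Decode the output coordinate from innermost to outermost dimension,
    // wrapping each coordinate back into the input extent.
    std::int64_t remaining = out_idx;
    std::int64_t in_idx = 0;
    std::int64_t in_stride = 1;
    for (std::size_t d = rank; d-- > 0;) {
      const std::int64_t out_coord = remaining % out_shape[d];
      remaining /= out_shape[d];
      in_idx += (out_coord % shape[d]) * in_stride;
      in_stride *= shape[d];
    }
    output[static_cast<std::size_t>(out_idx)] =
        input[static_cast<std::size_t>(in_idx)];
  }
  return output;
}
```

For example, tiling a `{1, 2}` tensor holding `FloatToBf16(1.0f)` and `FloatToBf16(2.0f)` with repeats `{2, 2}` produces a `{2, 4}` tensor of eight elements that alternate between the two input bit patterns.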
