-
Notifications
You must be signed in to change notification settings - Fork 521
RFC: Specialize for non-mixed-dtype in elementwise_util #9388
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
swolchok
wants to merge
20
commits into
gh/swolchok/382/head
Choose a base branch
from
gh/swolchok/383/head
base: gh/swolchok/382/head
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+127
−19
Open
Changes from all commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
fd62a07
Update
swolchok edd45fb
Update
swolchok 4c4add0
Update
swolchok 8782a90
Update
swolchok 75f8970
Update
swolchok 2d19e75
Update
swolchok 5348a92
Update
swolchok 001d72c
Update
swolchok e49080d
Update
swolchok 44ee51a
Update
swolchok f934bc0
Update
swolchok 3a74f25
Update
swolchok 2242f1e
Update
swolchok 42623bb
Update
swolchok 39610ad
Update
swolchok ff2c358
Update
swolchok 754dba4
Update
swolchok 946f2e0
Update
swolchok de9d52f
Update
swolchok 85451ea
Update
swolchok File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -51,6 +51,44 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) { | |
} | ||
|
||
namespace internal {
// Fast path for elementwise ops used when every input tensor's storage is
// already CTYPE_COMPUTE and the output's is CTYPE_OUT: data is read and
// written through typed pointers, with no per-element dtype conversion.
//
// Each element of `inputs` is a pair whose `.first` is a `const Tensor*`
// (judging from the usage here; the pair's `.second` is not read in this
// function). NOTE(review): `ctx` is accepted but unused in this
// implementation — presumably kept for signature parity with the generic
// path; confirm.
template <
    typename CTYPE_COMPUTE,
    typename CTYPE_OUT,
    typename Op,
    typename... Args>
inline void dtype_specialized_elementwise_fn_impl(
    const Op& compute_fun,
    KernelRuntimeContext& ctx,
    const Tensor& out,
    Args... inputs) {
  constexpr auto kNumInputs = sizeof...(inputs);
  // Debug-only sanity check: each input's element size must match the
  // compute type, since its storage is reinterpreted as CTYPE_COMPUTE below.
  ET_DCHECK(((inputs.first->element_size() == sizeof(CTYPE_COMPUTE)) && ...));

  // Split the output range into chunks; each chunk is processed by the
  // lambda below with its [begin, end) bounds in output-element indexes.
  ::executorch::extension::parallel_for(
      0,
      out.numel(),
      ::executorch::extension::internal::GRAIN_SIZE,
      [&](const auto begin, const auto end) {
        // Typed data pointers for each input, in pack order.
        std::array<const CTYPE_COMPUTE*, kNumInputs> inputs_data_ptrs = {
            inputs.first->template const_data_ptr<CTYPE_COMPUTE>()...};

        CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();

        // Per the indexing below, each item yielded by the range is an
        // array of linear indexes: [0] into `out`, [i + 1] into the i-th
        // input (which may be broadcast against `out`).
        const auto range =
            BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...);
        auto begin_it = range.begin();
        begin_it += begin;
        // Stop once the output index reaches this chunk's `end`.
        for (; (*begin_it)[0] < end; ++begin_it) {
          const auto& indexes = *begin_it;
          // Gather one element from each input, then apply the op with the
          // gathered values expanded as separate arguments.
          std::array<CTYPE_COMPUTE, kNumInputs> loaded_inputs;
          for (const auto idx : c10::irange(kNumInputs)) {
            loaded_inputs[idx] = inputs_data_ptrs[idx][indexes[idx + 1]];
          }
          data_out[indexes[0]] = std::apply(compute_fun, loaded_inputs);
        }
      });
}
|
||
template <typename CTYPE_COMPUTE, typename Op, typename... Args> | ||
inline bool validate_elementwise_fn_inputs( | ||
const Op& compute_fun, | ||
|
@@ -81,18 +119,12 @@ template < | |
const char* op_name, | ||
typename Op, | ||
typename... Args> | ||
inline void apply_elementwise_fn( | ||
inline void apply_elementwise_fn_generic_impl( | ||
const Op& compute_fun, | ||
KernelRuntimeContext& ctx, | ||
const Tensor& out, | ||
SupportedTensorDtypes out_dtypes, | ||
Args... inputs) { | ||
const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>( | ||
compute_fun, ctx, out, out_dtypes, inputs...); | ||
if (!inputs_valid) { | ||
return; | ||
} | ||
|
||
constexpr auto kNumInputs = sizeof...(inputs); | ||
|
||
struct InputInfo { | ||
|
@@ -138,6 +170,63 @@ inline void apply_elementwise_fn( | |
}); | ||
} | ||
|
||
template < | ||
typename CTYPE_COMPUTE, | ||
const char* op_name, | ||
typename Op, | ||
typename... Args> | ||
inline void apply_elementwise_fn_runtime_out_dtypes( | ||
const Op& compute_fun, | ||
KernelRuntimeContext& ctx, | ||
const Tensor& out, | ||
SupportedTensorDtypes out_dtypes, | ||
Args... inputs) { | ||
const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>( | ||
compute_fun, ctx, out, out_dtypes, inputs...); | ||
if (!inputs_valid) { | ||
return; | ||
} | ||
|
||
apply_elementwise_fn_generic_impl<CTYPE_COMPUTE, op_name>( | ||
compute_fun, ctx, out, out_dtypes, inputs...); | ||
} | ||
|
||
// Entry point taking the output dtypes as a compile-time template argument.
// After validating inputs, it checks at runtime whether the "all dtypes
// already match the compute type" fast path applies; if so it dispatches to
// dtype_specialized_elementwise_fn_impl, otherwise it falls back to the
// generic implementation.
template <
    typename CTYPE_COMPUTE,
    const char* op_name,
    SupportedTensorDtypes out_dtypes,
    typename Op,
    typename... Args>
inline void apply_elementwise_fn(
    const Op& compute_fun,
    KernelRuntimeContext& ctx,
    const Tensor& out,
    Args... inputs) {
  const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
      compute_fun, ctx, out, out_dtypes, inputs...);
  if (!inputs_valid) {
    return;
  }

  // Fast path requires every input tensor to already be stored as the
  // compute type (fold over the input pack).
  constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
  const bool all_inputs_compute_dtype =
      ((inputs.first->scalar_type() == compute_type) && ...);

  // NOTE(review): specialized_output_scalar_type's semantics are not visible
  // here — presumably it maps (CTYPE_COMPUTE, out_dtypes) to the single
  // output ScalarType the fast path can handle; an open review question asks
  // whether its argument should be CTYPE_COMPUTE. Confirm before relying on
  // this comment.
  constexpr ScalarType out_specialized_scalar_type =
      specialized_output_scalar_type<CTYPE_COMPUTE>(out_dtypes);

  if (all_inputs_compute_dtype &&
      out.scalar_type() == out_specialized_scalar_type) {
    using CTYPE_OUT =
        typename ScalarTypeToCppType<out_specialized_scalar_type>::type;
    dtype_specialized_elementwise_fn_impl<CTYPE_COMPUTE, CTYPE_OUT>(
        compute_fun, ctx, out, inputs...);
    return;
  }

  // Mixed-dtype (or non-specialized output) case: generic implementation.
  apply_elementwise_fn_generic_impl<CTYPE_COMPUTE, op_name>(
      compute_fun, ctx, out, out_dtypes, inputs...);
}
|
||
/// DEPRECATED: prefer the variant with out_dtypes in the template argument. | ||
template <typename CTYPE_COMPUTE, const char* op_name, typename Op> | ||
inline void apply_unitensor_elementwise_fn( | ||
|
@@ -147,7 +236,7 @@ inline void apply_unitensor_elementwise_fn( | |
SupportedTensorDtypes a_dtypes, | ||
const Tensor& out, | ||
SupportedTensorDtypes out_dtypes) { | ||
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>( | ||
internal::apply_elementwise_fn_runtime_out_dtypes<CTYPE_COMPUTE, op_name>( | ||
compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes)); | ||
} | ||
|
||
|
@@ -162,8 +251,8 @@ inline void apply_unitensor_elementwise_fn( | |
const Tensor& a, | ||
SupportedTensorDtypes a_dtypes, | ||
const Tensor& out) { | ||
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>( | ||
compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes)); | ||
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes>( | ||
compute_fun, ctx, out, std::make_pair(&a, a_dtypes)); | ||
} | ||
|
||
/** | ||
|
@@ -179,7 +268,7 @@ inline void apply_bitensor_elementwise_fn( | |
SupportedTensorDtypes b_dtypes, | ||
const Tensor& out, | ||
SupportedTensorDtypes out_dtypes) { | ||
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>( | ||
internal::apply_elementwise_fn_runtime_out_dtypes<CTYPE_COMPUTE, op_name>( | ||
compute_fun, | ||
ctx, | ||
out, | ||
|
@@ -206,11 +295,10 @@ inline void apply_bitensor_elementwise_fn( | |
const Tensor& b, | ||
SupportedTensorDtypes b_dtypes, | ||
const Tensor& out) { | ||
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>( | ||
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes>( | ||
compute_fun, | ||
ctx, | ||
out, | ||
out_dtypes, | ||
std::make_pair(&a, a_dtypes), | ||
std::make_pair(&b, b_dtypes)); | ||
} | ||
|
@@ -230,7 +318,7 @@ inline void apply_tritensor_elementwise_fn( | |
SupportedTensorDtypes c_dtypes, | ||
const Tensor& out, | ||
SupportedTensorDtypes out_dtypes) { | ||
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>( | ||
internal::apply_elementwise_fn_runtime_out_dtypes<CTYPE_COMPUTE, op_name>( | ||
compute_fun, | ||
ctx, | ||
out, | ||
|
@@ -275,11 +363,10 @@ inline void apply_tritensor_elementwise_fn( | |
const Tensor& c, | ||
SupportedTensorDtypes c_dtypes, | ||
const Tensor& out) { | ||
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>( | ||
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes>( | ||
compute_fun, | ||
ctx, | ||
out, | ||
out_dtypes, | ||
std::make_pair(&a, a_dtypes), | ||
std::make_pair(&b, b_dtypes), | ||
std::make_pair(&c, c_dtypes)); | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this should be `CTYPE_COMPUTE`, right?