-
Notifications
You must be signed in to change notification settings - Fork 539
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[tosa] : Handle CHW input for avgpool2d. (#4042)
This PR fixes two issues: 1. Avg(Max)Pool2d allows (C, H, W) as input which causes a failure when creating `tosa.avg_pool2d` or `tosa.max_pool2d` as those ops expects 4D (N, H, W, C) tensor. Fix is to add a N=1 dimension before creating the tosa ops. 2. Avg(Max)Pool2d also allows kernel/stride to be specified as a tuple of single int, in which case the value is repeated for both H and W dims. This is currently not handled as well causing a segv when trying to access `kernel[1]/stride[1]`. Fix is to expand `kernel/stride` to be size 2 by repeating the first element.
- Loading branch information
Showing
4 changed files
with
189 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters