
Commit 26fd8ec

chyomin06 authored and fracape committed
Add ScaleSpaceFlow network
1 parent 9eb2aeb commit 26fd8ec

19 files changed: +365 −33 lines changed

.clang-format

+4 −5 (the removed and re-added lines below look identical; the changes appear to be trailing-whitespace cleanups)
@@ -23,7 +23,7 @@ AlwaysBreakBeforeMultilineStrings: false
 AlwaysBreakTemplateDeclarations: MultiLine
 BinPackArguments: true
 BinPackParameters: true
-BraceWrapping:
+BraceWrapping:
   AfterCaseLabel: false
   AfterClass: false
   AfterControlStatement: false
@@ -60,12 +60,12 @@ DerivePointerAlignment: false
 DisableFormat: false
 ExperimentalAutoDetectBinPacking: false
 FixNamespaceComments: true
-ForEachMacros:
+ForEachMacros:
   - foreach
   - Q_FOREACH
   - BOOST_FOREACH
 IncludeBlocks: Preserve
-IncludeCategories:
+IncludeCategories:
   - Regex: '^"(llvm|llvm-c|clang|clang-c)/'
     Priority: 2
   - Regex: '^(<|"(gtest|gmock|isl|json)/)'
@@ -117,10 +117,9 @@ SpacesInCStyleCastParentheses: false
 SpacesInParentheses: false
 SpacesInSquareBrackets: false
 Standard: Cpp11
-StatementMacros:
+StatementMacros:
   - Q_UNUSED
   - QT_REQUIRE_VERSION
 TabWidth: 8
 UseTab: Never
 ...
-

.flake8

+1 −1
@@ -11,7 +11,7 @@ per-file-ignores =
 
 max-line-length = 88
 
-# maximum McCabe complexity
+# maximum McCabe complexity
 max-complexity = 12
 
 exclude =

.gitignore

+3 −3
@@ -96,11 +96,11 @@ ipython_config.py
 .python-version
 
 # pipenv
-# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in
 # version control.
-# However, in case of collaboration, if having platform-specific dependencies
+# However, in case of collaboration, if having platform-specific dependencies
 # or dependencies
-# having no cross-platform support, pipenv may install dependencies that don’t
+# having no cross-platform support, pipenv may install dependencies that don’t
 # work, or not
 # install all needed dependencies.
 #Pipfile.lock

compressai/layers/layers.py

+49
@@ -18,6 +18,7 @@
 import torch.nn as nn
 
 from torch import Tensor
+from torch.autograd import Function
 
 from .gdn import GDN
 
@@ -29,6 +30,7 @@
     "ResidualBlockWithStride",
     "conv3x3",
     "subpel_conv3x3",
+    "qrelu"
 ]
 
 
@@ -225,3 +227,50 @@ def forward(self, x: Tensor) -> Tensor:
         out = a * torch.sigmoid(b)
         out += identity
         return out
+
+class qrelu(Function):
+    """QReLU
+
+    Clamps the input to the range allowed by the given bit depth.
+    The input is assumed to be integer-valued when used with integer networks;
+    for floating-point networks, fractional values pass through without rounding.
+
+    A pre-computed scale based on the gamma function is used for the backward pass.
+
+    More details can be found in
+    `"Integer networks for data compression with latent-variable models" <...>`_,
+    Balle et al., 2019.
+
+    Args:
+        input: input tensor
+        bit_depth: bit depth used to clamp the input
+        beta: exponent of the gradient replacement applied to out-of-range values
+    """
+
+    @staticmethod
+    def forward(ctx, input, bit_depth, beta):
+        ctx.alpha = 0.9943258522851727
+        ctx.beta = beta
+        ctx.max_value = 2 ** bit_depth - 1
+        ctx.save_for_backward(input)
+
+        return input.clamp(min=0, max=ctx.max_value)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_input = None
+        (input,) = ctx.saved_tensors
+
+        grad_input = grad_output.clone()
+        grad_sub = (
+            torch.exp(
+                (-ctx.alpha ** ctx.beta)
+                * torch.abs(2.0 * input / ctx.max_value - 1) ** ctx.beta
+            )
+            * grad_output.clone()
+        )
+
+        grad_input[input < 0] = grad_sub[input < 0]
+        grad_input[input > ctx.max_value] = grad_sub[input > ctx.max_value]
+
+        return grad_input, None, None
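
As context for the new layer, here is a minimal, hypothetical usage sketch of the qrelu Function added above. It assumes CompressAI is importable and that qrelu is reachable via compressai.layers.layers (the module this hunk modifies); the tensor shape and the beta value are illustrative choices, not values taken from this commit.

import torch

# Hypothetical usage example (not part of the commit). Assumes `qrelu` is
# importable from the module changed above.
from compressai.layers.layers import qrelu

# Toy input pushed well outside an 8-bit range so the clamp is visible.
x = (300.0 * torch.randn(4, 16)).requires_grad_()

# Custom autograd Functions are invoked through .apply(); bit_depth=8 clamps
# the output to [0, 255], and beta=100 is an illustrative choice.
y = qrelu.apply(x, 8, 100)

y.sum().backward()

print(y.min().item(), y.max().item())  # outputs lie in [0, 255]
print(x.grad[x < 0][:3])               # negative inputs receive the replacement gradient from backward()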

compressai/models/__init__.py

+1 −1
@@ -13,4 +13,4 @@
 # limitations under the License.
 
 from .google import *
-from .waseda import *
+from .waseda import *
