-
Notifications
You must be signed in to change notification settings - Fork 63
/
Copy pathdecompositions.py
102 lines (82 loc) · 4.14 KB
/
decompositions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import tensorly as tl
from tensorly.decomposition import parafac, partial_tucker
import numpy as np
import torch
import torch.nn as nn
from VBMF import VBMF
def cp_decomposition_conv_layer(layer, rank):
    """Gets a conv layer and a target rank,
    returns a nn.Sequential object with the CP decomposition.

    The returned sequence holds four convolutions, in order:
      1. a 1x1 conv from ``first.shape[0]`` to ``first.shape[1]`` channels,
      2. a depthwise conv with a ``(vertical.shape[0], 1)`` kernel,
      3. a depthwise conv with a ``(1, horizontal.shape[0])`` kernel, which
         also carries the original layer's stride,
      4. a 1x1 conv from ``last.shape[1]`` to ``last.shape[0]`` channels,
         which carries the original bias (if the layer has one).

    Args:
        layer: the conv layer to decompose. Its ``weight``, ``bias``,
            ``stride``, ``padding`` and ``dilation`` attributes are read.
        rank: the CP rank handed to ``parafac``.

    Returns:
        nn.Sequential wrapping the four new nn.Conv2d layers.
    """
    # Perform CP decomposition on the layer weight tensor.
    # NOTE(review): this unpacking assumes parafac returns the four factor
    # matrices directly, one per weight mode — confirm against the installed
    # tensorly version.
    last, first, vertical, horizontal = \
        parafac(layer.weight.data, rank=rank, init='svd')

    # A layer built with bias=False has layer.bias set to None; in that case
    # the final pointwise conv must not get a bias either.
    has_bias = layer.bias is not None

    pointwise_s_to_r_layer = torch.nn.Conv2d(
        in_channels=first.shape[0],
        out_channels=first.shape[1], kernel_size=1, stride=1, padding=0,
        dilation=layer.dilation, bias=False)

    # Only the height axis is padded here; the width axis is handled by the
    # horizontal layer below.
    depthwise_vertical_layer = torch.nn.Conv2d(
        in_channels=vertical.shape[1],
        out_channels=vertical.shape[1], kernel_size=(vertical.shape[0], 1),
        stride=1, padding=(layer.padding[0], 0), dilation=layer.dilation,
        groups=vertical.shape[1], bias=False)

    # The width axis must use the width padding (index 1), not the height
    # padding, otherwise asymmetric paddings give the wrong output width.
    depthwise_horizontal_layer = torch.nn.Conv2d(
        in_channels=horizontal.shape[1],
        out_channels=horizontal.shape[1],
        kernel_size=(1, horizontal.shape[0]), stride=layer.stride,
        padding=(0, layer.padding[1]),
        dilation=layer.dilation, groups=horizontal.shape[1], bias=False)

    pointwise_r_to_t_layer = torch.nn.Conv2d(
        in_channels=last.shape[1],
        out_channels=last.shape[0], kernel_size=1, stride=1,
        padding=0, dilation=layer.dilation, bias=has_bias)

    if has_bias:
        pointwise_r_to_t_layer.bias.data = layer.bias.data

    # Reshape every factor matrix into a 4-D conv weight by transposing
    # (where needed) and inserting singleton axes.
    depthwise_horizontal_layer.weight.data = \
        torch.transpose(horizontal, 1, 0).unsqueeze(1).unsqueeze(1)
    depthwise_vertical_layer.weight.data = \
        torch.transpose(vertical, 1, 0).unsqueeze(1).unsqueeze(-1)
    pointwise_s_to_r_layer.weight.data = \
        torch.transpose(first, 1, 0).unsqueeze(-1).unsqueeze(-1)
    pointwise_r_to_t_layer.weight.data = last.unsqueeze(-1).unsqueeze(-1)

    new_layers = [pointwise_s_to_r_layer, depthwise_vertical_layer,
                  depthwise_horizontal_layer, pointwise_r_to_t_layer]

    return nn.Sequential(*new_layers)
def estimate_ranks(layer):
    """Estimate two ranks for a layer's weight tensor using VBMF.

    The weight tensor is unfolded along mode 0 and along mode 1, and
    ``VBMF.EVBMF`` is run on each unfolding. Only the shape of the second
    value EVBMF returns is used.

    Args:
        layer: a layer whose ``weight.data`` is the tensor to analyse.

    Returns:
        A two-element list ``[rank_mode_0, rank_mode_1]``.
    """
    kernel = layer.weight.data

    _, diag_mode0, _, _ = VBMF.EVBMF(tl.base.unfold(kernel, 0))
    _, diag_mode1, _, _ = VBMF.EVBMF(tl.base.unfold(kernel, 1))

    # NOTE(review): mode 0 reads shape[0] while mode 1 reads shape[1];
    # presumably the EVBMF result is square so both agree — confirm.
    return [diag_mode0.shape[0], diag_mode1.shape[1]]
def tucker_decomposition_conv_layer(layer):
    """Gets a conv layer,
    returns a nn.Sequential object with the Tucker decomposition.

    The ranks are estimated with a Python implementation of VBMF
    https://github.com/CasvandenBogaard/VBMF

    The returned sequence holds three convolutions, in order: a 1x1 conv,
    a conv with the original kernel size / stride / padding / dilation, and
    a final 1x1 conv which carries the original bias (if the layer has one).

    Args:
        layer: the conv layer to decompose. Its ``weight``, ``bias``,
            ``kernel_size``, ``stride``, ``padding`` and ``dilation``
            attributes are read.

    Returns:
        nn.Sequential wrapping the three new nn.Conv2d layers.
    """
    ranks = estimate_ranks(layer)
    print(layer, "VBMF Estimated ranks", ranks)

    # NOTE(review): this unpacking assumes partial_tucker returns
    # (core, [factor_mode_0, factor_mode_1]) — confirm against the installed
    # tensorly version.
    core, [last, first] = \
        partial_tucker(layer.weight.data,
                       modes=[0, 1], rank=ranks, init='svd')

    # A layer built with bias=False has layer.bias set to None; in that case
    # the final pointwise conv must not get a bias either.
    has_bias = layer.bias is not None

    # A pointwise convolution that maps first.shape[0] input channels
    # to first.shape[1] channels.
    first_layer = torch.nn.Conv2d(
        in_channels=first.shape[0],
        out_channels=first.shape[1], kernel_size=1,
        stride=1, padding=0, dilation=layer.dilation, bias=False)

    # A regular 2D convolution layer with core.shape[1] input channels
    # and core.shape[0] output channels; it keeps the original spatial
    # hyper-parameters.
    core_layer = torch.nn.Conv2d(
        in_channels=core.shape[1],
        out_channels=core.shape[0], kernel_size=layer.kernel_size,
        stride=layer.stride, padding=layer.padding, dilation=layer.dilation,
        bias=False)

    # A pointwise convolution that maps last.shape[1] channels
    # to last.shape[0] output channels.
    last_layer = torch.nn.Conv2d(
        in_channels=last.shape[1],
        out_channels=last.shape[0], kernel_size=1, stride=1,
        padding=0, dilation=layer.dilation, bias=has_bias)

    if has_bias:
        last_layer.bias.data = layer.bias.data

    # Turn the factor matrices into 1x1 conv weights by appending two
    # singleton spatial axes.
    first_layer.weight.data = \
        torch.transpose(first, 1, 0).unsqueeze(-1).unsqueeze(-1)
    last_layer.weight.data = last.unsqueeze(-1).unsqueeze(-1)
    core_layer.weight.data = core

    new_layers = [first_layer, core_layer, last_layer]
    return nn.Sequential(*new_layers)