Skip to content

Commit b825ad5

Browse files
author
Andrea
authored
PointNet 1.0 implementation. (#248)
1 parent bacee9a commit b825ad5

File tree

4 files changed

+363
-0
lines changed

4 files changed

+363
-0
lines changed

tensorflow_graphics/nn/layer/BUILD

+20
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,23 @@ py_test(
6767
"//tensorflow_graphics/util:test_case",
6868
],
6969
)
70+
71+
py_library(
72+
name = "pointnet",
73+
srcs = ["pointnet.py"],
74+
srcs_version = "PY2AND3",
75+
deps = [
76+
# "//tensorflow_graphics/util:export_api",
77+
],
78+
)
79+
80+
py_test(
81+
name = "pointnet_test",
82+
srcs = ["tests/pointnet_test.py"],
83+
srcs_version = "PY2AND3",
84+
size = "small",
85+
deps = [
86+
":pointnet",
87+
"//tensorflow_graphics/util:test_case",
88+
],
89+
)
+247
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
#Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
Implementation of the PointNet networks from:
16+
17+
@inproceedings{qi2017pointnet,
18+
title={Pointnet: Deep learning on point sets
19+
for3d classification and segmentation},
20+
author={Qi, Charles R and Su, Hao and Mo, Kaichun and Guibas, Leonidas J},
21+
booktitle={Proceedings of the IEEE conference on computer vision and pattern
22+
recognition},
23+
pages={652--660},
24+
year={2017}}
25+
26+
NOTE: scheduling of batchnorm momentum currently not available in keras. However
27+
experimentally, using the batch norm from Keras resulted in better test accuracy
28+
(+1.5%) than the author's [custom batch norm
29+
version](https://github.com/charlesq34/pointnet/blob/master/utils/tf_util.py)
30+
even when coupled with batchnorm momentum decay. Further, note the author's
31+
version is actually performing a "global normalization", as mentioned in the
32+
[tf.nn.moments documentation]
33+
(https://www.tensorflow.org/api_docs/python/tf/nn/moments).
34+
35+
This shorthand notation is used throughout this module:
36+
`B`: Number of elements in a batch.
37+
`N`: The number of points in the point set.
38+
`D`: Number of dimensions (e.g. 2 for 2D, 3 for 3D).
39+
`C`: The number of feature channels.
40+
"""
41+
42+
import tensorflow as tf
43+
from tensorflow.keras.layers import Layer
44+
from tensorflow.keras.layers import Dense
45+
from tensorflow.keras.layers import Conv2D
46+
from tensorflow.keras.layers import Dropout
47+
from tensorflow.keras.layers import BatchNormalization
48+
49+
50+
class PointNetConv2Layer(Layer):
51+
"""The 2D convolution layer used by the feature encoder in PointNet."""
52+
53+
def __init__(self, channels, momentum):
54+
"""Constructs a Conv2 layer.
55+
56+
Note:
57+
Differently from the standard Keras Conv2 layer, the order of ops is:
58+
1. fully connected layer
59+
2. batch normalization layer
60+
3. ReLU activation unit
61+
62+
Args:
63+
channels: the number of generated feature.
64+
momentum: the momentum of the batch normalization layer.
65+
"""
66+
super(PointNetConv2Layer, self).__init__()
67+
self.channels = channels
68+
self.momentum = momentum
69+
70+
def build(self, input_shape):
71+
"""Builds the layer with a specified input_shape."""
72+
self.conv = Conv2D(self.channels, (1, 1), input_shape=input_shape)
73+
self.bn = BatchNormalization(momentum=self.momentum)
74+
75+
def call(self, inputs, training=None):
76+
"""Executes the convolution.
77+
78+
Args:
79+
inputs: a dense tensor of size `[B, N, 1, D]`.
80+
training: flag to control batch normalization update statistics.
81+
82+
Returns:
83+
Tensor with shape `[B, N, 1, C]`.
84+
"""
85+
return tf.nn.relu(self.bn(self.conv(inputs), training))
86+
87+
88+
class PointNetDenseLayer(Layer):
89+
"""The fully connected layer used by the classification head in pointnet.
90+
91+
Note:
92+
Differently from the standard Keras Conv2 layer, the order of ops is:
93+
1. fully connected layer
94+
2. batch normalization layer
95+
3. ReLU activation unit
96+
"""
97+
98+
def __init__(self, channels, momentum):
99+
super(PointNetDenseLayer, self).__init__()
100+
self.momentum = momentum
101+
self.channels = channels
102+
103+
def build(self, input_shape):
104+
"""Builds the layer with a specified input_shape."""
105+
self.dense = Dense(self.channels, input_shape=input_shape)
106+
self.bn = BatchNormalization(momentum=self.momentum)
107+
108+
def call(self, inputs, training=None):
109+
"""Executes the convolution.
110+
111+
Args:
112+
inputs: a dense tensor of size `[B, D]`.
113+
training: flag to control batch normalization update statistics.
114+
115+
Returns:
116+
Tensor with shape `[B, C]`.
117+
"""
118+
return tf.nn.relu(self.bn(self.dense(inputs), training))
119+
120+
121+
class VanillaEncoder(Layer):
122+
"""The Vanilla PointNet feature encoder.
123+
124+
Consists of five conv2 layers with (64,64,64,128,1024) output channels.
125+
126+
Note:
127+
PointNetConv2Layer are used instead of tf.keras.layers.Conv2D.
128+
129+
https://github.com/charlesq34/pointnet/blob/master/models/pointnet_cls_basic.py
130+
"""
131+
132+
def __init__(self, momentum=.5):
133+
"""Constructs a VanillaEncoder keras layer.
134+
135+
Args:
136+
momentum: the momentum used for the batch normalization layer.
137+
"""
138+
super(VanillaEncoder, self).__init__()
139+
self.conv1 = PointNetConv2Layer(64, momentum)
140+
self.conv2 = PointNetConv2Layer(64, momentum)
141+
self.conv3 = PointNetConv2Layer(64, momentum)
142+
self.conv4 = PointNetConv2Layer(128, momentum)
143+
self.conv5 = PointNetConv2Layer(1024, momentum)
144+
145+
def call(self, inputs, training=None):
146+
"""Computes the PointNet features.
147+
148+
Args:
149+
inputs: a dense tensor of size `[B,N,D]`.
150+
training: flag to control batch normalization update statistics.
151+
152+
Returns:
153+
Tensor with shape `[B, N, C=1024]`
154+
"""
155+
x = tf.expand_dims(inputs, axis=2) # [B,N,1,D]
156+
x = self.conv1(x, training) # [B,N,1,64]
157+
x = self.conv2(x, training) # [B,N,1,64]
158+
x = self.conv3(x, training) # [B,N,1,64]
159+
x = self.conv4(x, training) # [B,N,1,128]
160+
x = self.conv5(x, training) # [B,N,1,1024]
161+
x = tf.math.reduce_max(x, axis=1) # [B,1,1024]
162+
return tf.squeeze(x) # [B,1024]
163+
164+
165+
class ClassificationHead(Layer):
166+
"""The PointNet classification head.
167+
168+
The head consists of 2x PointNetDenseLayer layers (512 and 256 channels)
169+
followed by a dropout layer (drop rate=30%) a dense linear layer producing the
170+
logits of the num_classes classes.
171+
"""
172+
173+
def __init__(self, num_classes=40, momentum=0.5, dropout_rate=0.3):
174+
"""Constructor.
175+
176+
Args:
177+
num_classes: the number of classes to classify.
178+
momentum: the momentum used for the batch normalization layer.
179+
dropout_rate: the dropout rate for fully connected layer
180+
"""
181+
super(ClassificationHead, self).__init__()
182+
self.dense1 = PointNetDenseLayer(512, momentum)
183+
self.dense2 = PointNetDenseLayer(256, momentum)
184+
self.dropout = Dropout(dropout_rate)
185+
self.dense3 = Dense(num_classes, activation="linear")
186+
187+
def call(self, inputs, training=None):
188+
"""Computes the classifiation logits given features (note: without softmax).
189+
190+
Args:
191+
inputs: tensor of points with shape `[B,D]`.
192+
training: flag for batch normalization and dropout training.
193+
194+
Returns:
195+
Tensor with shape `[B,num_classes]`
196+
"""
197+
x = self.dense1(inputs, training) # [B,512]
198+
x = self.dense2(x, training) # [B,256]
199+
x = self.dropout(x, training) # [B,256]
200+
return self.dense3(x) # [B,num_classes)
201+
202+
203+
class PointNetVanillaClassifier(Layer):
204+
"""The PointNet 'Vanilla' classifier (i.e. without spatial transformer)."""
205+
206+
def __init__(self, num_classes=40, momentum=.5, dropout_rate=.3):
207+
"""Constructor.
208+
209+
Args:
210+
num_classes: the number of classes to classify.
211+
momentum: the momentum used for the batch normalization layer.
212+
dropout_rate: the dropout rate for the classification head.
213+
"""
214+
super(PointNetVanillaClassifier, self).__init__()
215+
self.encoder = VanillaEncoder(momentum)
216+
self.classifier = ClassificationHead(num_classes=num_classes,
217+
momentum=momentum,
218+
dropout_rate=dropout_rate)
219+
220+
def call(self, points, training=None):
221+
"""Computes the classifiation logits of a point set.
222+
223+
Args:
224+
points: a tensor of points with shape `[B, D]`
225+
training: for batch normalization and dropout training.
226+
227+
Returns:
228+
Tensor with shape `[B,num_classes]`
229+
"""
230+
features = self.encoder(points, training) # (B,1024)
231+
logits = self.classifier(features, training) # (B,num_classes)
232+
return logits
233+
234+
@staticmethod
235+
def loss(labels, logits):
236+
"""The classification model training loss.
237+
238+
Note:
239+
see tf.nn.sparse_softmax_cross_entropy_with_logits
240+
241+
Args:
242+
labels: a tensor with shape `[B,]`
243+
logits: a tensor with shape `[B,num_classes]`
244+
"""
245+
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits
246+
residual = cross_entropy(labels, logits)
247+
return tf.reduce_mean(residual)

tensorflow_graphics/nn/layer/tests/graph_convolution_test.py

+1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def _dummy_data(batch_size, num_vertices, num_channels):
4545
return data, neighbors
4646

4747

48+
# pylint: disable=missing-class-docstring
4849
class GraphConvolutionTestFeatureSteeredConvolutionLayerTests(
4950
test_case.TestCase):
5051

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
#Copyright 2019 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Tests for pointnet layers."""
15+
16+
17+
import tensorflow as tf
18+
from absl.testing import parameterized
19+
from tensorflow_graphics.util import test_case
20+
from tensorflow_graphics.nn.layer.pointnet import ClassificationHead
21+
from tensorflow_graphics.nn.layer.pointnet import PointNetConv2Layer
22+
from tensorflow_graphics.nn.layer.pointnet import PointNetDenseLayer
23+
from tensorflow_graphics.nn.layer.pointnet import VanillaEncoder
24+
from tensorflow_graphics.nn.layer.pointnet import PointNetVanillaClassifier
25+
26+
27+
class RandomForwardExecutionTest(test_case.TestCase):
28+
29+
@parameterized.parameters(
30+
((32, 2048, 1, 3), (32), (.5), True),
31+
((32, 2048, 1, 3), (32), (.5), False),
32+
((32, 2048, 1, 2), (16), (.99), True),
33+
)
34+
def test_conv2(self, input_shape, channels, momentum, training):
35+
B, N, X, _ = input_shape
36+
inputs = tf.random.uniform(input_shape)
37+
layer = PointNetConv2Layer(channels, momentum)
38+
outputs = layer(inputs, training=training)
39+
assert outputs.shape == (B, N, X, channels)
40+
41+
@parameterized.parameters(
42+
((32, 1024), (40), (.5), True),
43+
((32, 2048), (20), (.5), False),
44+
((32, 512), (10), (.99), True),
45+
)
46+
def test_dense(self, input_shape, channels, momentum, training):
47+
B, _ = input_shape
48+
inputs = tf.random.uniform(input_shape)
49+
layer = PointNetDenseLayer(channels, momentum)
50+
outputs = layer(inputs, training=training)
51+
assert outputs.shape == (B, channels)
52+
53+
@parameterized.parameters(
54+
((32, 2048, 3), (.9), True),
55+
((32, 2048, 2), (.5), False),
56+
((32, 2048, 3), (.99), True),
57+
)
58+
def test_vanilla_encoder(self, input_shape, momentum, training):
59+
B, N, D = input_shape
60+
inputs = tf.random.uniform(input_shape)
61+
encoder = VanillaEncoder(momentum)
62+
outputs = encoder(inputs, training=training)
63+
assert outputs.shape == (B, 1024)
64+
65+
@parameterized.parameters(
66+
((16, 1024), (20), (.9), True),
67+
((8, 2048), (40), (.5), False),
68+
((32, 512), (10), (.99), True),
69+
)
70+
def test_classification_head(self, input_shape, num_classes, momentum, training):
71+
B, C = input_shape
72+
inputs = tf.random.uniform(input_shape)
73+
head = ClassificationHead(num_classes, momentum)
74+
outputs = head(inputs, training=training)
75+
assert outputs.shape == (B, num_classes)
76+
77+
@parameterized.parameters(
78+
((32, 1024, 3), 40, True),
79+
((32, 1024, 2), 40, False),
80+
((16, 2048, 3), 20, True),
81+
((16, 2048, 2), 20, False),
82+
)
83+
def test_vanilla_classifier(self, input_shape, num_classes, training):
84+
B, N, D = input_shape
85+
C = num_classes
86+
inputs = tf.random.uniform(input_shape)
87+
model = PointNetVanillaClassifier(num_classes, momentum=.5)
88+
logits = model(inputs, training)
89+
assert logits.shape == (B, C)
90+
labels = tf.random.uniform((B, ), minval=0, maxval=C, dtype=tf.int64)
91+
PointNetVanillaClassifier.loss(labels, logits)
92+
93+
94+
if __name__ == "__main__":
95+
test_case.main()

0 commit comments

Comments
 (0)