
Commit 8e1d9ea

jeremiedb authored and hetong007 committed
Fix for RNN R API (#45)
* blog RNN bucketing mxnet R
* apply review fixes
* update RNN R
* typo RNN R
1 parent 4c2a6d0 commit 8e1d9ea

File tree

1 file changed: +16 -16 lines changed


_posts/2017-10-11-rnn-bucket-mxnet-R.md

+16 -16
@@ -7,7 +7,7 @@ categories: rstats
 comments: true
 ---
 
-This tutorial presents an example of application of RNN to text classification using padded and bucketed data to efficiently handle sequences of varying lengths. Some functionalities require running on a GPU with CUDA.
+This tutorial presents an example of application of RNN to text classification using padded and bucketed data to efficiently handle sequences of varying lengths. Some functionalities require running on a CUDA-enabled GPU.
 
 Example based on sentiment analysis on the [IMDB data](http://ai.stanford.edu/~amaas/data/sentiment/).
 
@@ -45,7 +45,7 @@ To illustrate the benefit of bucketing, two datasets are created:
 - `corpus_single_train.rds`: no bucketing, all samples are padded/trimmed to 600 words.
 - `corpus_bucketed_train.rds`: samples split into 5 buckets of length 100, 150, 250, 400 and 600.
 
-Below is the example of the assignation of the bucketed data and labels into `mx.io.bucket.iter` iterator. This iterator behaves essentially the same as the `mx.io.arrayiter` except that is pushes samples coming from the different buckets along with a bucketID to identify the appropriate network to use.
+Below is an example of the assignment of the bucketed data and labels into the `mx.io.bucket.iter` iterator. This iterator behaves essentially the same as `mx.io.arrayiter`, except that it pushes samples coming from the different buckets along with a bucketID to identify the appropriate symbolic graph to use.
 
 ``` r
 corpus_bucketed_train <- readRDS(file = "data/corpus_bucketed_train.rds")
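
For context, here is a minimal sketch of how such an iterator might then be constructed from the bucketed corpus. The `batch.size`, `data.mask.element`, and `shuffle` values are illustrative assumptions, not taken from this diff:

``` r
library(mxnet)

corpus_bucketed_train <- readRDS(file = "data/corpus_bucketed_train.rds")

# Feed the pre-bucketed data and labels to the iterator; each batch it
# emits carries a bucketID so the matching symbolic graph can be bound.
train.data.bucket <- mx.io.bucket.iter(buckets = corpus_bucketed_train$buckets,
                                       batch.size = 64,
                                       data.mask.element = 0,
                                       shuffle = TRUE)
```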
@@ -68,14 +68,14 @@ eval.data.bucket <- mx.io.bucket.iter(buckets = corpus_bucketed_test$buckets,
 Define the architecture
 -----------------------
 
-Below are the graph representations of a seq-to-one architecture with LSTM cells. Note that input data is of shape `batch.size X seq.length` while the output of the RNN operator is of shape `hidden.features X batch.size X seq.length`.
+Below are the graph representations of a seq-to-one architecture with LSTM cells. Note that input data is of shape `seq.length X batch.size` while the RNN operator requires input of shape `hidden.features X batch.size X seq.length`, which requires swapping axes.
 
 For bucketing, a list of symbols is defined, one for each bucket length. During training, at each batch the appropriate symbol is bound according to the bucketID provided by the iterator.
 
 ``` r
-symbol_single <- rnn.graph(config = "seq-to-one", cell.type = "lstm",
-                           num.rnn.layer = 1, num.embed = 2, num.hidden = 4,
-                           num.decode = 2, input.size = vocab, dropout = 0.5,
+symbol_single <- rnn.graph(config = "seq-to-one", cell_type = "lstm",
+                           num_rnn_layer = 1, num_embed = 2, num_hidden = 4,
+                           num_decode = 2, input_size = vocab, dropout = 0.5,
                            ignore_label = -1, loss_output = "softmax",
                            output_last_state = F, masking = T)
 ```
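
As a side note on the axis swap mentioned above, here is a minimal sketch of embedding a `seq.length X batch.size` input and reordering its axes with MXNet's `SwapAxis` operator. The variable names and sizes are illustrative, and this is a simplified stand-in for what `rnn.graph` does internally, not its actual code:

``` r
library(mxnet)

vocab <- 100  # illustrative vocabulary size

data <- mx.symbol.Variable("data")  # word ids, shape seq.length x batch.size

# Project word ids into a 2-dimensional embedding (num_embed = 2 above).
embed <- mx.symbol.Embedding(data = data, input_dim = vocab,
                             output_dim = 2, name = "embed")

# Exchange the first two axes so the layout matches what the RNN
# operator expects, per the paragraph above.
rnn_input <- mx.symbol.SwapAxis(data = embed, dim1 = 0, dim2 = 1)
```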
@@ -84,11 +84,11 @@ symbol_single <- rnn.graph(config = "seq-to-one", cell.type = "lstm",
 bucket_list <- unique(c(train.data.bucket$bucket.names, eval.data.bucket$bucket.names))
 
 symbol_buckets <- sapply(bucket_list, function(seq) {
-  rnn.graph(config = "seq-to-one", cell.type = "lstm",
-            num.rnn.layer = 1, num.embed = 2, num.hidden = 4,
-            num.decode = 2, input.size = vocab, dropout = 0.5,
-            ignore_label = -1, loss_output = "softmax",
-            output_last_state = F, masking = T)})
+  rnn.graph(config = "seq-to-one", cell_type = "lstm",
+            num_rnn_layer = 1, num_embed = 2, num_hidden = 4,
+            num_decode = 2, input_size = vocab, dropout = 0.5,
+            ignore_label = -1, loss_output = "softmax",
+            output_last_state = F, masking = T)})
 
 graph.viz(symbol_single, type = "graph", direction = "LR",
           graph.height.px = 50, graph.width.px = 800, shape=c(64, 5))
@@ -118,7 +118,7 @@ batch.end.callback <- mx.callback.log.train.metric(period = 50)
 system.time(
   model <- mx.model.buckets(symbol = symbol_single,
                             train.data = train.data.single, eval.data = eval.data.single,
-                            num.round = 5, ctx = devices, verbose = FALSE,
+                            num.round = 6, ctx = devices, verbose = FALSE,
                             metric = mx.metric.accuracy, optimizer = optimizer,
                             initializer = initializer,
                             batch.end.callback = NULL,
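
The `devices`, `optimizer`, and `initializer` objects referenced in these training calls are defined earlier in the post; a plausible sketch of that setup follows. The device choice, optimizer, and hyperparameter values are assumptions, not taken from this diff:

``` r
library(mxnet)

# Assumed setup, for illustration only.
devices <- mx.cpu()  # or mx.gpu(0) for the CUDA-enabled parts

initializer <- mx.init.Xavier(rnd_type = "gaussian",
                              factor_type = "avg", magnitude = 2.5)

optimizer <- mx.opt.create("adadelta", rho = 0.92,
                           epsilon = 1e-6, wd = 2e-4)

# Appears in the hunk header above: log the training metric every 50 batches.
batch.end.callback <- mx.callback.log.train.metric(period = 50)
```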
@@ -148,7 +148,7 @@ batch.end.callback <- mx.callback.log.train.metric(period = 50)
 system.time(
   model <- mx.model.buckets(symbol = symbol_buckets,
                             train.data = train.data.bucket, eval.data = eval.data.bucket,
-                            num.round = 5, ctx = devices, verbose = FALSE,
+                            num.round = 6, ctx = devices, verbose = FALSE,
                             metric = mx.metric.accuracy, optimizer = optimizer,
                             initializer = initializer,
                             batch.end.callback = NULL,
@@ -170,12 +170,12 @@ Word representation can be visualized by looking at the assigned weights in any
 
 ![](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/blog_mxnet_R_rnn_bucket/embed-1.png)
 
-Since the model attempts to predict the sentiment, it's no surprise that the 2 dimensions into which each word is projected appear correlated with words' polarity. Positive words are associated with lower X1 values ("great", "excellent"), while the most negative words appear at the far right ("terrible", "worst"). By representing words of similar meaning with features of values, embedding much facilitates the remaining classification task for the network.
+Since the model attempts to predict the sentiment, it's no surprise that the 2 dimensions into which each word is projected appear correlated with words' polarity. Positive words are associated with lower X1 values ("great", "excellent"), while the most negative words appear at the far right ("terrible", "worst"). By representing words of similar meaning with features of similar values, the embedding greatly facilitates the remaining classification task for the network.
 
 Inference on test data
 ----------------------
 
-The utility function `mx.infer.buckets` has been added to simplify inference on RNN with bucketed data.
+The utility function `mx.infer.rnn` has been added to simplify inference on RNNs with bucketed data.
 
 ``` r
 ctx <- mx.gpu(0)
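
Returning to the embedding plot above, here is a hedged sketch of how the learned weights could be pulled from the trained model for that visualization. The `embed.weight` parameter name is an assumption about how `rnn.graph` names its embedding layer:

``` r
# Assumption: the embedding layer is named "embed", so its weights sit
# in the trained model's argument parameters under "embed.weight".
embed_weights <- as.array(model$arg.params$embed.weight)

# With num_embed = 2, rows are the two embedding dimensions and columns
# are vocabulary entries (R's column-major shape convention).
plot(embed_weights[1, ], embed_weights[2, ],
     xlab = "X1", ylab = "X2", main = "Learned word embeddings")
```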
@@ -189,7 +189,7 @@ test.data <- mx.io.bucket.iter(buckets = corpus_bucketed_test$buckets,
 ```
 
 ``` r
-infer <- mx.infer.buckets(infer.data = test.data, model = model, ctx = ctx)
+infer <- mx.infer.rnn(infer.data = test.data, model = model, ctx = ctx)
 
 pred_raw <- t(as.array(infer))
 pred <- max.col(pred_raw, tie = "first") - 1
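
To close the loop, a sketch of scoring those predictions against the test labels. It assumes each bucket stores its labels under a `label` field and that the iterator preserves bucket order; both are illustrative assumptions, not confirmed by the diff:

``` r
# Gather the ground-truth labels bucket by bucket (assumed structure).
labels <- unlist(lapply(corpus_bucketed_test$buckets, function(x) x$label))

# pred holds 0-based class ids recovered from the softmax output above.
accuracy <- mean(pred == labels)
accuracy
```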
