From 0b41f2e6867ca974a670c9b55079976d438204b0 Mon Sep 17 00:00:00 2001 From: zdl Date: Sat, 17 Apr 2021 22:35:07 +0800 Subject: [PATCH 1/3] gan example --- dl4j-gan-examples/pom.xml | 108 +++++++++++++++ .../ganexamples/MNISTVisualizer.java | 75 +++++++++++ .../deeplearning4j/ganexamples/SimpleGan.java | 123 ++++++++++++++++++ 3 files changed, 306 insertions(+) create mode 100644 dl4j-gan-examples/pom.xml create mode 100644 dl4j-gan-examples/src/main/java/org/deeplearning4j/ganexamples/MNISTVisualizer.java create mode 100644 dl4j-gan-examples/src/main/java/org/deeplearning4j/ganexamples/SimpleGan.java diff --git a/dl4j-gan-examples/pom.xml b/dl4j-gan-examples/pom.xml new file mode 100644 index 0000000000..a64e4f5d26 --- /dev/null +++ b/dl4j-gan-examples/pom.xml @@ -0,0 +1,108 @@ + + + 4.0.0 + + org.deeplearning4j + dl4j-gan-examples + 1.0.0-SNAPSHOT + + + + 1.0.0-SNAPSHOT + 1.2.3 + 1.8 + 2.4.3 + UTF-8 + + + + + + org.deeplearning4j + deeplearning4j-core + ${dl4j-master.version} + + + org.nd4j + nd4j-native + + 1.0.0-beta7 + + + ch.qos.logback + logback-classic + ${logback.version} + + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.5.1 + + ${java.version} + ${java.version} + + + + + + + org.apache.maven.plugins + maven-shade-plugin + ${maven-shade-plugin.version} + + true + bin + true + + + *:* + + org/datanucleus/** + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + + package + + shade + + + + + reference.conf + + + + + + + + + + + + diff --git a/dl4j-gan-examples/src/main/java/org/deeplearning4j/ganexamples/MNISTVisualizer.java b/dl4j-gan-examples/src/main/java/org/deeplearning4j/ganexamples/MNISTVisualizer.java new file mode 100644 index 0000000000..04f57c5537 --- /dev/null +++ b/dl4j-gan-examples/src/main/java/org/deeplearning4j/ganexamples/MNISTVisualizer.java @@ -0,0 +1,75 @@ +package org.deeplearning4j.ganexamples; + +import org.nd4j.linalg.api.ndarray.INDArray; + +import javax.swing.*; +import java.awt.*; +import java.awt.image.BufferedImage; +import java.util.ArrayList; +import java.util.List; + +/** + * @author zdl + */ +public class MNISTVisualizer { + private double imageScale; + private List digits; + private String title; + private int gridWidth; + private JFrame frame; + + public MNISTVisualizer(double imageScale, String title) { + this(imageScale, title, 5); + } + + public MNISTVisualizer(double imageScale, String title, int gridWidth) { + this.imageScale = imageScale; + this.title = title; + this.gridWidth = gridWidth; + } + + public void visualize() { + if (null != frame) { + frame.dispose(); + } + frame = new JFrame(); + frame.setTitle(title); + frame.setSize(800, 600); + JPanel panel = new JPanel(); + panel.setPreferredSize(new Dimension(800, 600)); + panel.setLayout(new GridLayout(0, gridWidth)); + List list = getComponents(); + for (JLabel image : list) { + panel.add(image); + } + + frame.add(panel); + frame.setVisible(true); + frame.pack(); + } + + public List getComponents() { + List images = new ArrayList(); + for (INDArray arr : digits) { + BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY); + for (int i = 0; i < 784; i++) { + bi.getRaster().setSample(i % 28, i / 28, 0, (int) (255 * arr.getDouble(i))); + } + ImageIcon orig = new ImageIcon(bi); + Image imageScaled = orig.getImage().getScaledInstance((int) (imageScale * 28), (int) (imageScale * 28), + Image.SCALE_DEFAULT); + ImageIcon scaled = new ImageIcon(imageScaled); + images.add(new JLabel(scaled)); + } + return images; + } + + public List getDigits() { + 
+        return digits;
+    }
+
+    public void setDigits(List<INDArray> digits) {
+        this.digits = digits;
+    }
+
+}
diff --git a/dl4j-gan-examples/src/main/java/org/deeplearning4j/ganexamples/SimpleGan.java b/dl4j-gan-examples/src/main/java/org/deeplearning4j/ganexamples/SimpleGan.java
new file mode 100644
index 0000000000..efb6c71715
--- /dev/null
+++ b/dl4j-gan-examples/src/main/java/org/deeplearning4j/ganexamples/SimpleGan.java
@@ -0,0 +1,123 @@
+package org.deeplearning4j.ganexamples;
+
+import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
+import org.deeplearning4j.nn.conf.BackpropType;
+import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.graph.StackVertex;
+import org.deeplearning4j.nn.conf.layers.DenseLayer;
+import org.deeplearning4j.nn.conf.layers.OutputLayer;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.rng.distribution.impl.NormalDistribution;
+import org.nd4j.linalg.dataset.api.MultiDataSet;
+import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.learning.config.RmsProp;
+import org.nd4j.linalg.lossfunctions.LossFunctions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * @author zdl
+ */
+public class SimpleGan {
+    static double lr = 0.005;
+    private static final Logger log = LoggerFactory.getLogger(SimpleGan.class);
+    private static String[] generatorLayerNames = new String[]{"g1", "g2", "g3"};
+    private static String[] discriminatorLayerNames = new String[]{"d1", "d2", "d3", "out"};
+
+    public static void main(String[] args) throws Exception {
+
+        final NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder().updater(new RmsProp(lr))
+                .weightInit(WeightInit.XAVIER);
+
+        final ComputationGraphConfiguration.GraphBuilder graphBuilder = builder.graphBuilder().backpropType(BackpropType.Standard)
+                .addInputs("input1", "input2")
+                .addLayer("g1",
+                        new DenseLayer.Builder().nIn(10).nOut(128).activation(Activation.RELU)
+                                .weightInit(WeightInit.XAVIER).build(),
+                        "input1")
+                .addLayer("g2",
+                        new DenseLayer.Builder().nIn(128).nOut(512).activation(Activation.RELU)
+                                .weightInit(WeightInit.XAVIER).build(),
+                        "g1")
+                .addLayer("g3",
+                        new DenseLayer.Builder().nIn(512).nOut(28 * 28).activation(Activation.RELU)
+                                .weightInit(WeightInit.XAVIER).build(),
+                        "g2")
+                .addVertex("stack", new StackVertex(), "input2", "g3")
+                .addLayer("d1",
+                        new DenseLayer.Builder().nIn(28 * 28).nOut(256).activation(Activation.RELU)
+                                .weightInit(WeightInit.XAVIER).build(),
+                        "stack")
+                .addLayer("d2",
+                        new DenseLayer.Builder().nIn(256).nOut(128).activation(Activation.RELU)
+                                .weightInit(WeightInit.XAVIER).build(),
+                        "d1")
+                .addLayer("d3",
+                        new DenseLayer.Builder().nIn(128).nOut(128).activation(Activation.RELU)
+                                .weightInit(WeightInit.XAVIER).build(),
+                        "d2")
+                .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(128).nOut(1)
+                        .activation(Activation.SIGMOID).build(), "d3")
+                .setOutputs("out");
+
+        ComputationGraph net = new ComputationGraph(graphBuilder.build());
+        net.init();
+        net.setListeners(new ScoreIterationListener(1));
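+        // Graph recap ("input1" = 10-dim noise, "input2" = a mini-batch of real MNIST images):
+        // g1 -> g2 -> g3 is the generator, mapping noise to a 28*28 image; the StackVertex stacks
+        // [real, generated] along the batch dimension so the discriminator (d1 -> d2 -> d3 -> "out")
+        // scores both halves in one pass. trainDiscriminator()/trainGenerator() below flip which set
+        // of layers is trainable before each fit() call, so the two halves are trained alternately.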
System.out.println(net.summary()); + DataSetIterator train = new MnistDataSetIterator(30, true, 12345); + INDArray discriminatorLabel = Nd4j.vstack(Nd4j.ones(30, 1), Nd4j.zeros(30, 1)); + INDArray generatorLabel = Nd4j.ones(60, 1); + MNISTVisualizer bestVisualizer = new MNISTVisualizer(1, "Gan"); + for (int i = 1; i <= 100000; i++) { + if (!train.hasNext()) { + train.reset(); + } + INDArray trueData = train.next().getFeatures(); + INDArray z = Nd4j.rand(new NormalDistribution(), new long[]{30, 10}); + MultiDataSet dataSetD = new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[]{z, trueData}, + new INDArray[]{discriminatorLabel}); + trainDiscriminator(net, dataSetD); + z = Nd4j.rand(new NormalDistribution(), new long[]{30, 10}); + MultiDataSet dataSetG = new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[]{z, trueData}, + new INDArray[]{generatorLabel}); + trainGenerator(net, dataSetG); + if (i % 100 == 0) { + DataSetIterator dataSetIterator = new MnistDataSetIterator(30, true, 12345); + INDArray data = dataSetIterator.next().getFeatures(); + Map map = net.feedForward( + new INDArray[]{Nd4j.rand(new NormalDistribution(), new long[]{50, 10}), data}, false); + INDArray indArray = map.get("g3"); + + List list = new ArrayList<>(); + for (int j = 0; j < indArray.size(0); j++) { + list.add(indArray.getRow(j)); + } + bestVisualizer.setDigits(list); + bestVisualizer.visualize(); + } + } + + } + + public static void trainDiscriminator(ComputationGraph net, MultiDataSet dataSet) { + net.setTrainable(discriminatorLayerNames, true); + net.setTrainable(generatorLayerNames, false); + net.fit(dataSet); + } + + public static void trainGenerator(ComputationGraph net, MultiDataSet dataSet) { + net.setTrainable(discriminatorLayerNames, false); + net.setTrainable(generatorLayerNames, true); + net.fit(dataSet); + } +} From 17fc89c09b0cd9af0c4efa4ed3fd3cfea3e5a57d Mon Sep 17 00:00:00 2001 From: zdl Date: Sat, 24 Apr 2021 20:45:30 +0800 Subject: [PATCH 2/3] gan example --- dl4j-gan-examples/pom.xml | 11 +- .../deeplearning4j/ganexamples/SimpleGan.java | 148 +++++++++--------- 2 files changed, 78 insertions(+), 81 deletions(-) diff --git a/dl4j-gan-examples/pom.xml b/dl4j-gan-examples/pom.xml index a64e4f5d26..1b3f4bb647 100644 --- a/dl4j-gan-examples/pom.xml +++ b/dl4j-gan-examples/pom.xml @@ -10,7 +10,7 @@ - 1.0.0-SNAPSHOT + 1.0.0-beta7 1.2.3 1.8 2.4.3 @@ -27,14 +27,19 @@ org.nd4j nd4j-native - - 1.0.0-beta7 + ${dl4j-master.version} ch.qos.logback logback-classic ${logback.version} + + junit + junit + 4.12 + compile + diff --git a/dl4j-gan-examples/src/main/java/org/deeplearning4j/ganexamples/SimpleGan.java b/dl4j-gan-examples/src/main/java/org/deeplearning4j/ganexamples/SimpleGan.java index efb6c71715..9452e469b1 100644 --- a/dl4j-gan-examples/src/main/java/org/deeplearning4j/ganexamples/SimpleGan.java +++ b/dl4j-gan-examples/src/main/java/org/deeplearning4j/ganexamples/SimpleGan.java @@ -1,19 +1,16 @@ package org.deeplearning4j.ganexamples; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.deeplearning4j.nn.conf.BackpropType; -import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; +import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.graph.StackVertex; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; -import org.deeplearning4j.nn.graph.ComputationGraph; +import 
org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.distribution.impl.NormalDistribution; -import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.RmsProp; @@ -23,101 +20,96 @@ import java.util.ArrayList; import java.util.List; -import java.util.Map; /** + * ***** ******** ***************** + * z ---- * G *----* G(z) * ------ * discriminator * ---- fake + * ***** ******** * * + * x ----------------------------- ***************** ---- real + * * @author zdl */ public class SimpleGan { - static double lr = 0.005; - private static final Logger log = LoggerFactory.getLogger(SimpleGan.class); - private static String[] generatorLayerNames = new String[]{"g1", "g2", "g3"}; - private static String[] discriminatorLayerNames = new String[]{"d1", "d2", "d3", "out"}; public static void main(String[] args) throws Exception { - final NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder().updater(new RmsProp(lr)) - .weightInit(WeightInit.XAVIER); + /** + *Build the discriminator + */ + MultiLayerConfiguration discriminatorConf = new NeuralNetConfiguration.Builder().seed(12345) + .weightInit(WeightInit.XAVIER).updater(new RmsProp(0.001)) + .list() + .layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(512).activation(Activation.RELU).build()) + .layer(1, new DenseLayer.Builder().activation(Activation.RELU) + .nIn(512).nOut(256).build()) + .layer(2, new DenseLayer.Builder().activation(Activation.RELU) + .nIn(256).nOut(128).build()) + .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.XENT) + .activation(Activation.SIGMOID).nIn(128).nOut(1).build()).build(); - final ComputationGraphConfiguration.GraphBuilder graphBuilder = builder.graphBuilder().backpropType(BackpropType.Standard) - .addInputs("input1", "input2") - .addLayer("g1", - new DenseLayer.Builder().nIn(10).nOut(128).activation(Activation.RELU) - .weightInit(WeightInit.XAVIER).build(), - "input1") - .addLayer("g2", - new DenseLayer.Builder().nIn(128).nOut(512).activation(Activation.RELU) - .weightInit(WeightInit.XAVIER).build(), - "g1") - .addLayer("g3", - new DenseLayer.Builder().nIn(512).nOut(28 * 28).activation(Activation.RELU) - .weightInit(WeightInit.XAVIER).build(), - "g2") - .addVertex("stack", new StackVertex(), "input2", "g3") - .addLayer("d1", - new DenseLayer.Builder().nIn(28 * 28).nOut(256).activation(Activation.RELU) - .weightInit(WeightInit.XAVIER).build(), - "stack") - .addLayer("d2", - new DenseLayer.Builder().nIn(256).nOut(128).activation(Activation.RELU) - .weightInit(WeightInit.XAVIER).build(), - "d1") - .addLayer("d3", - new DenseLayer.Builder().nIn(128).nOut(128).activation(Activation.RELU) - .weightInit(WeightInit.XAVIER).build(), - "d2") - .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(128).nOut(1) - .activation(Activation.SIGMOID).build(), "d3") - .setOutputs("out"); - ComputationGraph net = new ComputationGraph(graphBuilder.build()); - net.init(); - net.setListeners(new ScoreIterationListener(1)); - System.out.println(net.summary()); + MultiLayerConfiguration ganConf = new NeuralNetConfiguration.Builder().seed(12345) + .weightInit(WeightInit.XAVIER) + //generator + .updater(new 
RmsProp(0.001)).list() + .layer(0, new DenseLayer.Builder().nIn(20).nOut(256).activation(Activation.RELU).build()) + .layer(1, new DenseLayer.Builder().activation(Activation.RELU) + .nIn(256).nOut(512).build()) + .layer(2, new DenseLayer.Builder().activation(Activation.RELU) + .nIn(512).nOut(28 * 28).build()) + //Freeze the discriminator parameter + .layer(3, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(28 * 28).nOut(512).activation(Activation.RELU).build())) + .layer(4, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(512).nOut(256).activation(Activation.RELU).build())) + .layer(5, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(256).nOut(128).activation(Activation.RELU).build())) + .layer(6, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new OutputLayer.Builder(LossFunctions.LossFunction.XENT) + .activation(Activation.SIGMOID).nIn(128).nOut(1).build())).build(); + + + MultiLayerNetwork discriminatorNetwork = new MultiLayerNetwork(discriminatorConf); + discriminatorNetwork.init(); + System.out.println(discriminatorNetwork.summary()); + discriminatorNetwork.setListeners(new ScoreIterationListener(1)); + + MultiLayerNetwork ganNetwork = new MultiLayerNetwork(ganConf); + ganNetwork.init(); + ganNetwork.setListeners(new ScoreIterationListener(1)); + System.out.println(ganNetwork.summary()); + DataSetIterator train = new MnistDataSetIterator(30, true, 12345); - INDArray discriminatorLabel = Nd4j.vstack(Nd4j.ones(30, 1), Nd4j.zeros(30, 1)); - INDArray generatorLabel = Nd4j.ones(60, 1); - MNISTVisualizer bestVisualizer = new MNISTVisualizer(1, "Gan"); + + INDArray labelD = Nd4j.vstack(Nd4j.ones(30, 1), Nd4j.zeros(30, 1)); + INDArray labelG = Nd4j.ones(30, 1); + MNISTVisualizer mnistVisualizer = new MNISTVisualizer(1, "Gan"); for (int i = 1; i <= 100000; i++) { if (!train.hasNext()) { train.reset(); } - INDArray trueData = train.next().getFeatures(); - INDArray z = Nd4j.rand(new NormalDistribution(), new long[]{30, 10}); - MultiDataSet dataSetD = new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[]{z, trueData}, - new INDArray[]{discriminatorLabel}); - trainDiscriminator(net, dataSetD); - z = Nd4j.rand(new NormalDistribution(), new long[]{30, 10}); - MultiDataSet dataSetG = new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[]{z, trueData}, - new INDArray[]{generatorLabel}); - trainGenerator(net, dataSetG); - if (i % 100 == 0) { - DataSetIterator dataSetIterator = new MnistDataSetIterator(30, true, 12345); - INDArray data = dataSetIterator.next().getFeatures(); - Map map = net.feedForward( - new INDArray[]{Nd4j.rand(new NormalDistribution(), new long[]{50, 10}), data}, false); - INDArray indArray = map.get("g3"); - + INDArray trueImage = train.next().getFeatures(); + INDArray z = Nd4j.rand(new NormalDistribution(), new long[]{30, 20}); + List ganFeedForward = ganNetwork.feedForward(z, false); + INDArray fakeImage = ganFeedForward.get(3); + INDArray trainDiscriminatorFeatures = Nd4j.vstack(trueImage, fakeImage); + //Training discriminator + discriminatorNetwork.fit(trainDiscriminatorFeatures, labelD); + copyDiscriminatorParam(discriminatorNetwork, ganNetwork); + //Training generator + ganNetwork.fit(z, labelG); + if (i % 1000 == 0) { + List indArrays = ganNetwork.feedForward(Nd4j.rand(new NormalDistribution(), new long[]{30, 20}), false); List list = new ArrayList<>(); + INDArray indArray = indArrays.get(3); for (int j = 0; j < 
indArray.size(0); j++) {
                     list.add(indArray.getRow(j));
                 }
-                bestVisualizer.setDigits(list);
-                bestVisualizer.visualize();
+                mnistVisualizer.setDigits(list);
+                mnistVisualizer.visualize();
             }
         }
-
     }
 
-    public static void trainDiscriminator(ComputationGraph net, MultiDataSet dataSet) {
-        net.setTrainable(discriminatorLayerNames, true);
-        net.setTrainable(generatorLayerNames, false);
-        net.fit(dataSet);
-    }
-
-    public static void trainGenerator(ComputationGraph net, MultiDataSet dataSet) {
-        net.setTrainable(discriminatorLayerNames, false);
-        net.setTrainable(generatorLayerNames, true);
-        net.fit(dataSet);
+    public static void copyDiscriminatorParam(MultiLayerNetwork discriminatorNetwork, MultiLayerNetwork ganNetwork) {
+        for (int i = 0; i <= 3; i++) {
+            ganNetwork.getLayer(i + 3).setParams(discriminatorNetwork.getLayer(i).params());
+        }
     }
 }

From 63fb13904c2631275b7b87bc30f430e79503b2e2 Mon Sep 17 00:00:00 2001
From: zdl
Date: Sat, 24 Apr 2021 20:54:05 +0800
Subject: [PATCH 3/3] gan example

---
 dl4j-gan-examples/README.md | 6 ++++++
 1 file changed, 6 insertions(+)
 create mode 100644 dl4j-gan-examples/README.md

diff --git a/dl4j-gan-examples/README.md b/dl4j-gan-examples/README.md
new file mode 100644
index 0000000000..0143fef903
--- /dev/null
+++ b/dl4j-gan-examples/README.md
@@ -0,0 +1,6 @@
+An example of a simple gan implemented with DL4J
+
+        *****    ********         *****************
+ z ---- * G *----* G(z) * ------ * discriminator * ---- fake
+        *****    ********         *               *
+ x ----------------------------- ***************** ---- real
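A usage sketch (not part of the patches above): once SimpleGan's training loop finishes, the generator half
of ganNetwork can be sampled on its own and the whole model persisted with DL4J's ModelSerializer. This
assumes the layer split used in SimpleGan (layers 0-2 form the generator, layers 3-6 the frozen discriminator
copy); the class and method names below are illustrative only.

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.deeplearning4j.util.ModelSerializer;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.rng.distribution.impl.NormalDistribution;
    import org.nd4j.linalg.factory.Nd4j;

    import java.io.File;
    import java.util.List;

    public class SampleTrainedGenerator {

        /**
         * Draws `num` random 20-dim noise vectors (matching nIn(20) of layer 0 in ganConf),
         * pushes them through the generator layers and returns the 28*28 outputs, then
         * saves the full network (including updater state) so training could be resumed later.
         */
        public static INDArray sampleAndSave(MultiLayerNetwork ganNetwork, int num, File target) throws Exception {
            INDArray z = Nd4j.rand(new NormalDistribution(), new long[]{num, 20});
            // feedForward returns one activation array per layer; index 3 is the generator
            // output, exactly as SimpleGan reads it when visualizing intermediate results
            List<INDArray> activations = ganNetwork.feedForward(z, false);
            INDArray generated = activations.get(3);
            ModelSerializer.writeModel(ganNetwork, target, true);
            return generated;
        }
    }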