|
49 | 49 | import org.deeplearning4j.util.ModelSerializer;
|
50 | 50 | import org.deeplearning4j.zoo.model.TinyYOLO;
|
51 | 51 | import org.nd4j.linalg.activations.Activation;
|
| 52 | +import org.nd4j.linalg.api.memory.enums.DebugMode; |
52 | 53 | import org.nd4j.linalg.api.ndarray.INDArray;
|
53 | 54 | import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
|
54 | 55 | import org.nd4j.linalg.factory.Nd4j;
|
55 | 56 | import org.nd4j.linalg.learning.config.Adam;
|
| 57 | +import org.nd4j.linalg.profiler.ProfilerConfig; |
56 | 58 | import org.slf4j.Logger;
|
57 | 59 | import org.slf4j.LoggerFactory;
|
58 | 60 |
|
@@ -122,18 +124,17 @@ public static void main(String[] args) throws java.lang.Exception {
|
122 | 124 | File trainDir = fetcher.getDataSetPath(DataSetType.TRAIN);
|
123 | 125 | File testDir = fetcher.getDataSetPath(DataSetType.TEST);
|
124 | 126 |
|
125 |
| - |
126 | 127 | log.info("Load data...");
|
127 | 128 |
|
128 | 129 | FileSplit trainData = new FileSplit(trainDir, NativeImageLoader.ALLOWED_FORMATS, rng);
|
129 | 130 | FileSplit testData = new FileSplit(testDir, NativeImageLoader.ALLOWED_FORMATS, rng);
|
130 | 131 |
|
131 | 132 | ObjectDetectionRecordReader recordReaderTrain = new ObjectDetectionRecordReader(height, width, nChannels,
|
132 |
| - gridHeight, gridWidth, new SvhnLabelProvider(trainDir)); |
| 133 | + gridHeight, gridWidth, new SvhnLabelProvider(trainDir)); |
133 | 134 | recordReaderTrain.initialize(trainData);
|
134 | 135 |
|
135 | 136 | ObjectDetectionRecordReader recordReaderTest = new ObjectDetectionRecordReader(height, width, nChannels,
|
136 |
| - gridHeight, gridWidth, new SvhnLabelProvider(testDir)); |
| 137 | + gridHeight, gridWidth, new SvhnLabelProvider(testDir)); |
137 | 138 | recordReaderTest.initialize(testData);
|
138 | 139 |
|
139 | 140 | // ObjectDetectionRecordReader performs regression, so we need to specify it here
|
@@ -210,7 +211,7 @@ public static void main(String[] args) throws java.lang.Exception {
|
210 | 211 | CanvasFrame frame = new CanvasFrame("HouseNumberDetection");
|
211 | 212 | OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
|
212 | 213 | org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer yout =
|
213 |
| - (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer)model.getOutputLayer(0); |
| 214 | + (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer)model.getOutputLayer(0); |
214 | 215 | List<String> labels = train.getLabels();
|
215 | 216 | test.setCollectMetaData(true);
|
216 | 217 | Scalar[] colormap = {RED,BLUE,GREEN,CYAN,YELLOW,MAGENTA,ORANGE,PINK,LIGHTBLUE,VIOLET};
|
|
0 commit comments