1
+
2
+ '''
3
+ Tensorflow Code for a fourier transform network
4
+ '''
5
+
6
+ from __future__ import absolute_import
7
+ from __future__ import division
8
+ from __future__ import print_function
9
+
10
+ import tensorflow as tf
11
+ import numpy as np
12
+ import matplotlib
13
+ import matplotlib .pyplot as plt
14
+ import os
15
+ os .environ ['TF_CPP_MIN_LOG_LEVEL' ] = '2'
16
+
17
+ # Import Dataset
18
+ from data_loader import DataLoader
19
+ data = DataLoader ()
20
+ print ('Data Loaded' )
21
+
22
+ # Training Parameters
23
+ learning_rate = 0.0001
24
+ num_steps = 10000
25
+ batch_size = 32
26
+ display_step = 100
27
+
28
+ # Network Parameters
29
+ WIDTH = data .WIDTH ; HEIGHT = data .HEIGHT ; CHANNELS = data .CHANNELS
30
+ NUM_INPUTS = WIDTH * HEIGHT * CHANNELS
31
+ NUM_OUTPUTS = WIDTH * HEIGHT * CHANNELS
32
+
33
+ # Network Varibles and placeholders
34
+ X = tf .placeholder (tf .float64 , [None , NUM_INPUTS ]) # Input
35
+ Y = tf .placeholder (tf .float64 , [None , NUM_OUTPUTS ]) # Truth Data - Output
36
+
37
+ # Network Architecture
38
+ def simple_net (x ):
39
+ he_init = tf .contrib .layers .variance_scaling_initializer ()
40
+ fc1 = tf .layers .dense (x , 128 , activation = tf .nn .relu , kernel_initializer = he_init , name = 'fc1' )
41
+ fc2 = tf .layers .dense (fc1 , NUM_OUTPUTS , activation = None , kernel_initializer = he_init , name = 'fc2' )
42
+ return fc2
43
+
44
+ # Define loss and optimizer
45
+ prediction = simple_net (X ) #unet(X)
46
+ loss = tf .reduce_mean (tf .square (prediction - Y ))
47
+ optimizer = tf .train .AdamOptimizer (learning_rate = learning_rate )
48
+ trainer = optimizer .minimize (loss )
49
+
50
+ # Initalize varibles, and run network
51
+ init = tf .global_variables_initializer ()
52
+ sess = tf .Session ()
53
+ sess .run (init )
54
+
55
+ print ('Start Training: BatchSize:' , batch_size ,' LearningRate:' , learning_rate )
56
+
57
+ for step in range (num_steps ):
58
+ x , y = data .get ()
59
+ sess .run (trainer , feed_dict = {X :x , Y :y })
60
+
61
+ if (step % display_step == 0 ):
62
+ _loss = sess .run (loss , feed_dict = { X :x , Y :y })
63
+ print ("Step: " + str (step ) + " Loss: " + str (_loss ))
64
+
65
+ x , y = data .get ()
66
+ img = sess .run (prediction , feed_dict = {X :x })
67
+ img = np .reshape (img , (data .WIDTH , data .HEIGHT , data .CHANNELS ))
68
+ plt .imshow (img [:,:,0 ], cmap = "gray" )
69
+ plt .show ()
0 commit comments