Skip to content

Commit be10a48

Browse files
Added learning fft model example
1 parent 8b7a642 commit be10a48

File tree

4 files changed

+153
-0
lines changed

4 files changed

+153
-0
lines changed

examples/fft/data_loader.py

+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
2+
from __future__ import absolute_import
3+
from __future__ import division
4+
from __future__ import print_function
5+
import os
6+
7+
import numpy as np
8+
import matplotlib.pyplot as plt
9+
import matplotlib.image as mpimg
10+
11+
class DataLoader:
12+
13+
def __init__(self):
14+
self.load()
15+
16+
def get(self):
17+
input = np.reshape(self.input, (-1, self.WIDTH * self.HEIGHT * 2))
18+
output = np.reshape(self.output, (-1, self.WIDTH * self.HEIGHT * 2))
19+
return input, output
20+
21+
def load(self):
22+
23+
filepath = '.'
24+
filename = 'shepp256.png'
25+
img = mpimg.imread(os.path.join(filepath, filename))
26+
img = np.array(img)
27+
28+
self.WIDTH = img.shape[0]
29+
self.HEIGHT = img.shape[1]
30+
self.CHANNELS = 2
31+
32+
kSpace = np.fft.ifftshift(np.fft.fft2(img))
33+
inverse = np.fft.ifft2(kSpace)
34+
35+
self.input = np.dstack((np.abs(kSpace), np.angle(kSpace)))
36+
self.output = np.dstack((np.abs(inverse), np.angle(inverse)))
37+
38+
def show(self):
39+
40+
input = self.input[:,:,0] * np.exp(1j*self.input[:,:,1])
41+
output = self.output[:,:,0] * np.exp(1j*self.output[:,:,1])
42+
43+
plt.subplot(2, 1, 1)
44+
plt.imshow(np.abs(input), cmap='gray')
45+
46+
plt.subplot(2, 1, 2)
47+
plt.imshow(np.abs(output), cmap='gray')
48+
plt.show()
49+
50+
def info(self):
51+
print(self.input.dtype)
52+
53+
if __name__ == '__main__':
54+
data = DataLoader()
55+
data.info()
56+
data.show()
57+
x, y = data.get()
58+
print(x.shape, y.shape)

examples/fft/fft.py

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
2+
'''
3+
Tensorflow Code for a fourier transform network
4+
'''
5+
6+
from __future__ import absolute_import
7+
from __future__ import division
8+
from __future__ import print_function
9+
10+
import tensorflow as tf
11+
import numpy as np
12+
import matplotlib
13+
import matplotlib.pyplot as plt
14+
import os
15+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
16+
17+
# Import Dataset
18+
from data_loader import DataLoader
19+
data = DataLoader()
20+
print('Data Loaded')
21+
22+
# Training Parameters
23+
learning_rate = 0.0001
24+
num_steps = 10000
25+
batch_size = 32
26+
display_step = 100
27+
28+
# Network Parameters
29+
WIDTH = data.WIDTH; HEIGHT = data.HEIGHT; CHANNELS = data.CHANNELS
30+
NUM_INPUTS = WIDTH * HEIGHT * CHANNELS
31+
NUM_OUTPUTS = WIDTH * HEIGHT * CHANNELS
32+
33+
# Network Varibles and placeholders
34+
X = tf.placeholder(tf.float64, [None, NUM_INPUTS]) # Input
35+
Y = tf.placeholder(tf.float64, [None, NUM_OUTPUTS]) # Truth Data - Output
36+
37+
# Network Architecture
38+
def simple_net(x):
39+
he_init = tf.contrib.layers.variance_scaling_initializer()
40+
fc1 = tf.layers.dense(x, 128, activation=tf.nn.relu, kernel_initializer=he_init, name='fc1')
41+
fc2 = tf.layers.dense(fc1, NUM_OUTPUTS, activation=None, kernel_initializer=he_init, name='fc2')
42+
return fc2
43+
44+
# Define loss and optimizer
45+
prediction = simple_net(X) #unet(X)
46+
loss = tf.reduce_mean(tf.square(prediction - Y))
47+
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
48+
trainer = optimizer.minimize(loss)
49+
50+
# Initalize varibles, and run network
51+
init = tf.global_variables_initializer()
52+
sess = tf.Session()
53+
sess.run(init)
54+
55+
print ('Start Training: BatchSize:', batch_size,' LearningRate:', learning_rate)
56+
57+
for step in range(num_steps):
58+
x, y = data.get()
59+
sess.run(trainer, feed_dict={X:x, Y:y})
60+
61+
if(step % display_step == 0):
62+
_loss = sess.run(loss, feed_dict={ X:x, Y:y })
63+
print("Step: " + str(step) + " Loss: " + str(_loss))
64+
65+
x, y = data.get()
66+
img = sess.run(prediction, feed_dict={X:x})
67+
img = np.reshape(img, (data.WIDTH, data.HEIGHT, data.CHANNELS))
68+
plt.imshow(img[:,:,0], cmap="gray")
69+
plt.show()

examples/fft/fft_example.py

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
2+
from __future__ import absolute_import
3+
from __future__ import division
4+
from __future__ import print_function
5+
import os
6+
7+
import numpy as np
8+
import matplotlib.pyplot as plt
9+
import matplotlib.image as mpimg
10+
11+
12+
filepath = '.'
13+
filename = 'shepp256.png'
14+
img = mpimg.imread(os.path.join(filepath, filename))
15+
img = np.array(img)
16+
17+
kSpace = np.fft.ifftshift(np.fft.fft2(img))
18+
inverse = np.fft.ifft2(kSpace)
19+
20+
plt.subplot(2, 1, 1)
21+
plt.imshow(np.abs(kSpace), cmap='gray')
22+
23+
plt.subplot(2, 1, 2)
24+
plt.imshow(np.abs(inverse), cmap='gray')
25+
plt.show()
26+

examples/fft/shepp256.png

2.18 KB
Loading

0 commit comments

Comments
 (0)