-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdream_net.py
71 lines (42 loc) · 2.09 KB
/
dream_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import numpy as np
from PIL import Image
import tensorflow as tf
import tensorflow_hub as hub
# Ask TF Hub for compressed model downloads. This is set before the hub.load()
# calls below so it is already in place when they run.
os.environ.update({"TFHUB_MODEL_LOAD_FORMAT": "COMPRESSED"})

# Versions 1 and 2 of the Magenta arbitrary-image-stylization-v1-256 model.
# Both are loaded (version 1 first) so the script can produce one stylized
# result per version.
hub_model_one, hub_model_two = (
    hub.load("https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/1"),
    hub.load("https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2"),
)
def _prompt_for_existing_file(prompt):
    """Print *prompt*, then read stdin lines until one names an existing file.

    Returns the path exactly as typed. EOFError / KeyboardInterrupt raised by
    input() are not caught, so the user can always abort.
    """
    print(prompt)
    path = input()
    while not os.path.isfile(path):
        print("\nThe path you provided does not exist, please try again.\n")
        path = input()
    return path


def _prompt_for_positive_int(prompt):
    """Print *prompt*, then read stdin lines until one parses as an int > 0.

    Returns the parsed integer. Only parse failures and non-positive values
    are retried; EOFError / KeyboardInterrupt propagate to the caller.
    """
    print(prompt)
    while True:
        try:
            value = int(input())
        except ValueError:
            # Catch only the parse error: a bare `except:` would also trap
            # Ctrl+C and EOF, turning them into an endless retry loop.
            print("\nThe input you provided is not an integer, please try again.\n")
            continue
        if value > 0:
            return value
        # max_size is used as a target pixel dimension for resizing, so zero
        # or negative values cannot yield an image.
        print("\nThe size must be a positive integer, please try again.\n")


content_image = _prompt_for_existing_file(
    "\nPlease provide the path to the desired content image (including the extension).\n"
)
style_image = _prompt_for_existing_file(
    "\nPlease provide the path to the desired style image (including the extension).\n"
)
max_size = _prompt_for_positive_int(
    "\nPlease provide the desired size of the output image (about 500px is recommended).\n"
)
def load_image(path_to_image, max_dim=None):
    """Read an image file and return it as a batched float tensor.

    The file is decoded to 3 channels, converted to float32 (scaled to [0, 1]
    by convert_image_dtype), and resized with its aspect ratio preserved so
    that its longer side is ``max_dim`` pixels.

    Args:
        path_to_image: Path of the image file to read.
        max_dim: Target length in pixels of the longer side. When omitted,
            the module-level ``max_size`` is used (the original behavior).

    Returns:
        A float32 tensor of shape (1, new_height, new_width, 3).
    """
    if max_dim is None:
        max_dim = max_size

    raw = tf.io.read_file(path_to_image)
    # expand_animations=False decodes GIFs to their first frame as a 3-D
    # (height, width, 3) tensor; otherwise a 4-D frame stack comes back and
    # the [:-1] shape slice below would include the frame axis.
    image = tf.image.decode_image(raw, channels=3, expand_animations=False)
    image = tf.image.convert_image_dtype(image, tf.float32)

    # Scale both spatial dims by one factor so the longer one becomes max_dim.
    shape = tf.cast(tf.shape(image)[:-1], tf.float32)
    scale = max_dim / tf.reduce_max(shape)
    new_shape = tf.cast(shape * scale, tf.int32)
    resized = tf.image.resize(image, new_shape)

    # Prepend a batch axis: (H, W, 3) -> (1, H, W, 3).
    return resized[tf.newaxis, :]
def _save_stylized(image_batch, filename):
    """Write the first image of *image_batch* to *filename* as 8-bit RGB.

    Values are presumed to be floats in roughly [0, 1]; they are clipped to
    the valid byte range before the uint8 cast, because casting floats above
    255 to uint8 wraps around instead of saturating.
    """
    pixels = np.clip(np.asarray(image_batch) * 255.0, 0.0, 255.0).astype(np.uint8)
    # Drop the batch axis: (1, H, W, 3) -> (H, W, 3).
    Image.fromarray(pixels[0]).save(filename)


# Decode each input once and reuse the tensors for both model versions
# instead of re-reading and re-resizing the files per model.
content_tensor = load_image(content_image)
style_tensor = load_image(style_image)

# Element 0 of each model's output is the stylized image batch.
stylized_image_one = hub_model_one(content_tensor, style_tensor)[0]
stylized_image_two = hub_model_two(content_tensor, style_tensor)[0]

_save_stylized(stylized_image_one, "stylized_image_one.jpg")
_save_stylized(stylized_image_two, "stylized_image_two.jpg")
print("\nDone!\n")