Skip to content

Commit 0c5daa3

Browse files
authored
Merge pull request #659 from prashantsail/dcgan_fashiongen_example
2 parents 411ab09 + 6903c29 commit 0c5daa3

File tree

8 files changed

+280
-1
lines changed

8 files changed

+280
-1
lines changed

examples/README.md

+8-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
* [Serving object detection model](#serving-object-detection-model)
88
* [Serving image segmentation model](#serving-image-segmentation-model)
99
* [Serving huggingface transformers model](#serving-huggingface-transformers)
10+
* [Serving image generator model](#example-to-serve-GAN-model)
1011
* [Serving machine translation model](#serving-neural-machine-translation)
1112
* [Serving waveglow text to speech synthesizer model](#serving-wavegolw-text-to-speech-synthesizer)
1213
* [Serving multi modal framework model](#Serving-Multi-modal-model)
@@ -85,6 +86,12 @@ The following example demonstrates how to create and serve a pretrained transfor
8586
8687
* [Hugging Face Transformers](Huggingface_Transformers)
8788
89+
## Example to serve GAN model
90+
91+
The following example demonstrates how to create and serve a pretrained DCGAN model from [facebookresearch/pytorch_GAN_zoo](https://github.com/facebookresearch/pytorch_GAN_zoo)
92+
93+
* [GAN Image Generator](dcgan_fashiongen)
94+
8895
## Serving Neural Machine Translation
8996
9097
The following example demonstrates how to create and serve a neural translation model using fairseq
@@ -113,4 +120,4 @@ The following example demonstrates how to create and serve a complex image class
113120
114121
The following example demonstrates how to create and serve a complex neural machine translation workflow
115122
116-
* [Neural machine Translation workflow](Workflows/nmt_transformers_pipeline)
123+
* [Neural machine Translation workflow](Workflows/nmt_transformers_pipeline)

examples/dcgan_fashiongen/Readme.md

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# GAN(Generative Adversarial Networks) models using TorchServe
2+
- In this example we will demonstrate how to serve a GAN model using TorchServe.
3+
- We have used a pretrained DCGAN model from [facebookresearch/pytorch_GAN_zoo](https://github.com/facebookresearch/pytorch_GAN_zoo)
4+
(Introduction to [DCGAN on FashionGen](https://pytorch.org/hub/facebookresearch_pytorch-gan-zoo_dcgan/))
5+
6+
### 1. Create a Torch Model Archive
7+
Execute the following command to create _dcgan_fashiongen.mar_ :
8+
```
9+
./create_mar.sh
10+
```
11+
The [create_mar.sh](create_mar.sh) script does the following :
12+
- Download the model's [source code](https://github.com/facebookresearch/pytorch_GAN_zoo/tree/b75dee40918caabb4fe7ec561522717bf096a8cb/models), extract the relevant directory and zip it. (`--extra-files`)
13+
- Download a checkpoint file [DCGAN_fashionGen-1d67302.pth](https://dl.fbaipublicfiles.com/gan_zoo/DCGAN_fashionGen-1d67302.pth). (`--serialized-file`)
14+
- Provide a custom handler - [dcgan_fashiongen_handler.py](dcgan_fashiongen_handler.py). (`--handler`)
15+
16+
Alterantively, you can directly [download the dcgan_fashiongen.mar](https://torchserve.s3.amazonaws.com/mar_files/dcgan_fashiongen.mar)
17+
18+
### 2. Start TorchServe and Register Model
19+
```
20+
mkdir modelstore
21+
mv dcgan_fashiongen.mar modelstore/
22+
torchserve --start --ncs --model-store ./modelstore --models dcgan_fashiongen.mar
23+
```
24+
25+
### 3. Generate Images
26+
Invoke the predictions API and pass following payload(JSON)
27+
- **number_of_images** : Number of images to generate
28+
- **input_gender** : OPTIONAL; If specified, needs to be one of - `Men`, `Women`
29+
- **input_category** : OPTIONAL; If specified, needs to be one of - One of - `SHIRTS`, `SWEATERS`, `JEANS`, `PANTS`, `TOPS`, `SUITS & BLAZERS`, `SHORTS`, `JACKETS & COATS`, `SKIRTS`, `JUMPSUITS`, `SWIMWEAR`, `DRESSES`
30+
- **input_pose** : OPTIONAL; If specified, needs to be one of - `id_gridfs_1`, `id_gridfs_2`, `id_gridfs_3`, `id_gridfs_4`
31+
32+
#### Example
33+
1. **Create a single image (random gender, category, pose)**
34+
```
35+
curl -X POST -d '{"number_of_images":1}' -H "Content-Type: application/json" http://localhost:8080/predictions/dcgan_fashiongen -o img1.jpg
36+
```
37+
> Result image should be similar to the one below -
38+
> ![Sample Image 1](sample-output/img1.jpg)
39+
40+
2. **Create '64' images of 'Men' wearing 'Shirts' in 'id_gridfs_1' pose**
41+
```
42+
curl -X POST -d '{"number_of_images":64, "input_gender":"Men", "input_category":"SHIRTS", "input_pose":"id_gridfs_1"}' -H "Content-Type: application/json" http://localhost:8080/predictions/dcgan_fashiongen -o img2.jpg
43+
```
44+
> Result image should be similar to the one below -
45+
> ![Sample Image 2](sample-output/img2.jpg)
46+
47+
3. **Create '32' images of 'Women' wearing 'Dresses' in 'id_gridfs_3' pose**
48+
```
49+
curl -X POST -d '{"number_of_images":32, "input_gender":"Women", "input_category":"DRESSES", "input_pose":"id_gridfs_3"}' -H "Content-Type: application/json" http://localhost:8080/predictions/dcgan_fashiongen -o img3.jpg
50+
```
51+
> Result image should be similar to the one below -
52+
> ![Sample Image 3](sample-output/img3.jpg)
+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#!/bin/bash
2+
3+
REPONAME="pytorch_GAN_zoo"
4+
SHA="b75dee40918caabb4fe7ec561522717bf096a8cb" #master branch as of 27th Aug 2020
5+
SRCZIP="$SHA.zip"
6+
SRCDIR="$REPONAME-$SHA"
7+
MODELSZIP="models.zip"
8+
MODELSDIR="models"
9+
CHECKPOINT="DCGAN_fashionGen-1d67302.pth" #The DCGAN pretrained model as of 27th Aug 2020
10+
CHECKPOINT_RENAMED="DCGAN_fashionGen.pth"
11+
12+
# Clean Up before exit
13+
function cleanup {
14+
rm -rf $SRCZIP $SRCDIR $MODELSZIP $CHECKPOINT_RENAMED $MODELSDIR
15+
}
16+
trap cleanup EXIT
17+
18+
# Download and Extract model's source code
19+
sudo apt-get install zip unzip
20+
21+
wget https://github.com/facebookresearch/pytorch_GAN_zoo/archive/$SRCZIP
22+
unzip $SRCZIP
23+
# Get the models directory from the source code and zip it up
24+
# This will later be used by torchserve for loading the model
25+
mv $SRCDIR/models .
26+
zip -r $MODELSZIP $MODELSDIR
27+
28+
# Download checkpoint
29+
wget https://dl.fbaipublicfiles.com/gan_zoo/$CHECKPOINT -O $CHECKPOINT_RENAMED
30+
31+
# Create *.mar
32+
torch-model-archiver --model-name dcgan_fashiongen \
33+
--version 1.0 \
34+
--serialized-file $CHECKPOINT_RENAMED \
35+
--handler dcgan_fashiongen_handler.py \
36+
--extra-files $MODELSZIP \
37+
--force
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import os
2+
import zipfile
3+
import torch
4+
from io import BytesIO
5+
from torchvision.utils import save_image
6+
from ts.torch_handler.base_handler import BaseHandler
7+
8+
MODELSZIP = "models.zip"
9+
CHECKPOINT = "DCGAN_fashionGen.pth"
10+
11+
12+
class ModelHandler(BaseHandler):
13+
14+
def __init__(self):
15+
self.initialized = False
16+
self.map_location = None
17+
self.device = None
18+
self.use_gpu = True
19+
self.store_avg = True
20+
self.dcgan_model = None
21+
self.default_number_of_images = 1
22+
23+
def initialize(self, context):
24+
"""
25+
Extract the models zip; Take the serialized file and load the model
26+
"""
27+
properties = context.system_properties
28+
model_dir = properties.get("model_dir")
29+
gpu_id = properties.get("gpu_id")
30+
31+
self.map_location, self.device, self.use_gpu = \
32+
("cuda", torch.device("cuda:"+str(gpu_id)), True) if torch.cuda.is_available() else \
33+
("cpu", torch.device("cpu"), False)
34+
35+
# If not already extracted, Extract model source code
36+
if not os.path.exists(model_dir + "/models"):
37+
with zipfile.ZipFile(model_dir + "/" + MODELSZIP, "r") as zip_ref:
38+
zip_ref.extractall(model_dir)
39+
40+
# Load Model
41+
from models.DCGAN import DCGAN
42+
self.dcgan_model = DCGAN(useGPU=self.use_gpu, storeAVG=self.store_avg)
43+
state_dict = torch.load(model_dir + "/" + CHECKPOINT, map_location=self.map_location)
44+
self.dcgan_model.load_state_dict(state_dict)
45+
46+
self.initialized = True
47+
48+
def preprocess(self, requests):
49+
"""
50+
Build noise data by using "number of images" and other "constraints" provided by the end user.
51+
"""
52+
preprocessed_data = []
53+
for req in requests:
54+
data = req.get("data") if req.get("data") is not None else req.get("body", {})
55+
56+
number_of_images = data.get("number_of_images", self.default_number_of_images)
57+
labels = {ky: "b'{}'".format(vl) for ky, vl in data.items() if ky not in ["number_of_images"]}
58+
59+
noise = self.dcgan_model.buildNoiseDataWithConstraints(number_of_images, labels)
60+
preprocessed_data.append({
61+
"number_of_images": number_of_images,
62+
"input": noise
63+
})
64+
return preprocessed_data
65+
66+
def inference(self, preprocessed_data, *args, **kwargs):
67+
"""
68+
Take the noise data as an input tensor, pass it to the model and collect the output tensor.
69+
"""
70+
input_batch = torch.cat(tuple(map(lambda d: d["input"], preprocessed_data)), 0)
71+
with torch.no_grad():
72+
image_tensor = self.dcgan_model.test(input_batch, getAvG=True, toCPU=True)
73+
output_batch = torch.split(image_tensor, tuple(map(lambda d: d["number_of_images"], preprocessed_data)))
74+
return output_batch
75+
76+
def postprocess(self, output_batch):
77+
"""
78+
Create an image(jpeg) using the output tensor.
79+
"""
80+
postprocessed_data = []
81+
for op in output_batch:
82+
fp = BytesIO()
83+
save_image(op, fp, format="JPEG")
84+
postprocessed_data.append(fp.getvalue())
85+
fp.close()
86+
return postprocessed_data
1.13 KB
Loading
58.2 KB
Loading
25.2 KB
Loading

test/pytest/test_example_dcgan.py

+97
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import os
2+
import subprocess
3+
import requests
4+
import test_utils
5+
6+
from shutil import copy
7+
from PIL import Image
8+
from io import BytesIO
9+
10+
CURR_FILE_PATH = os.path.dirname(os.path.realpath(__file__))
11+
REPO_ROOT_DIR = os.path.normpath(os.path.join(CURR_FILE_PATH, "..", ".."))
12+
MODELSTORE_DIR = os.path.join(REPO_ROOT_DIR, "modelstore")
13+
DCGAN_EXAMPLE_DIR = os.path.join(REPO_ROOT_DIR, "examples", "dcgan_fashiongen")
14+
DCGAN_MAR_FILE = os.path.join(DCGAN_EXAMPLE_DIR, "dcgan_fashiongen.mar")
15+
16+
17+
os.environ['MKL_THREADING_LAYER'] = 'GNU'
18+
# Work around for issue - https://github.com/pytorch/pytorch/issues/37377
19+
20+
def setup_module():
21+
test_utils.torchserve_cleanup()
22+
create_example_mar()
23+
24+
os.makedirs(MODELSTORE_DIR, exist_ok=True) # Create modelstore directory
25+
copy(DCGAN_MAR_FILE, MODELSTORE_DIR) # Copy *.mar to modelstore
26+
27+
test_utils.start_torchserve(model_store=MODELSTORE_DIR)
28+
pass
29+
30+
31+
def teardown_module():
32+
test_utils.torchserve_cleanup()
33+
34+
# Empty and Remove modelstore directory
35+
test_utils.delete_model_store(MODELSTORE_DIR)
36+
os.rmdir(MODELSTORE_DIR)
37+
38+
delete_example_mar()
39+
pass
40+
41+
42+
def create_example_mar():
43+
# Create only if not already present
44+
if not os.path.exists(DCGAN_MAR_FILE):
45+
create_mar_cmd = "cd " + DCGAN_EXAMPLE_DIR + ";./create_mar.sh"
46+
subprocess.check_call(create_mar_cmd, shell=True)
47+
48+
49+
def delete_example_mar():
50+
try:
51+
os.remove(DCGAN_MAR_FILE)
52+
except OSError:
53+
pass
54+
55+
56+
def test_model_archive_creation():
57+
# *.mar created in setup phase
58+
assert os.path.exists(DCGAN_MAR_FILE), "Failed to create dcgan mar file"
59+
60+
61+
def test_model_register_unregister():
62+
reg_resp = test_utils.register_model("dcgan_fashiongen", "dcgan_fashiongen.mar")
63+
assert reg_resp.status_code == 200, "Model Registration Failed"
64+
65+
unreg_resp = test_utils.unregister_model("dcgan_fashiongen")
66+
assert unreg_resp.status_code == 200, "Model Unregistration Failed"
67+
68+
69+
def test_image_generation_without_any_input_constraints():
70+
test_utils.register_model("dcgan_fashiongen", "dcgan_fashiongen.mar")
71+
input_json = {}
72+
response = requests.post(url="http://localhost:8080/predictions/dcgan_fashiongen", json=input_json)
73+
fp = BytesIO(response.content)
74+
img = Image.open(fp)
75+
# Expect a jpeg of dimension 64 x 64, it contains only 1 image by default
76+
assert response.status_code == 200, "Image generation failed"
77+
assert img.get_format_mimetype() == "image/jpeg", "Generated image is not a jpeg"
78+
assert img.size == (64, 64), "Generated image is not of correct dimensions"
79+
test_utils.unregister_model("dcgan_fashiongen")
80+
81+
82+
def test_image_generation_with_input_constraints():
83+
test_utils.register_model("dcgan_fashiongen", "dcgan_fashiongen.mar")
84+
input_json = {
85+
"number_of_images": 64,
86+
"input_gender": "Men",
87+
"input_category": "SHIRTS",
88+
"input_pose": "id_gridfs_1"
89+
}
90+
response = requests.post(url="http://localhost:8080/predictions/dcgan_fashiongen", json=input_json)
91+
fp = BytesIO(response.content)
92+
img = Image.open(fp)
93+
# Expect a jpeg of dimension 530 x 530, it contains 64 images
94+
assert response.status_code == 200, "Image generation failed"
95+
assert img.get_format_mimetype() == "image/jpeg", "Generated image is not a jpeg"
96+
assert img.size == (530, 530), "Generated image is not of correct dimensions"
97+
test_utils.unregister_model("dcgan_fashiongen")

0 commit comments

Comments
 (0)