Skip to content

Commit bcd0df8

Browse files
committed
new code
1 parent 98903ca commit bcd0df8

20 files changed

+4254
-0
lines changed

Real/configs/config_cub_18.py

+108
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
from easydict import EasyDict
2+
import numpy as np
3+
import pickle
4+
5+
6+
def read_pickle(name):
7+
with open(name, "rb") as f:
8+
data = pickle.load(f)
9+
return data
10+
11+
12+
# load/output dir
13+
opt = EasyDict()
14+
opt.loadf = "./dump"
15+
opt.outf = "./dump"
16+
opt.outr = './data/test result'
17+
18+
# normalize each data domain
19+
# opt.normalize_domain = False
20+
opt.print_switch = True
21+
opt.adj_default = True
22+
# now it is half circle
23+
opt.num_domain = 18
24+
opt.full_domain = list(range(opt.num_domain))
25+
opt.src_domain = [0,1,2,3,4,5,6,7,8] # [0,1,2,3,10,11,12,13]
26+
opt.tgt_domain = list(set(opt.full_domain) - set(opt.src_domain))
27+
opt.num_source = len(opt.src_domain)
28+
opt.num_target = opt.num_domain - opt.num_source
29+
opt.src_dmn_num = opt.num_source
30+
opt.tgt_dmn_num = opt.num_target
31+
opt.all_domain = opt.src_domain + opt.tgt_domain
32+
33+
opt.test_on_all_dmn = True
34+
35+
opt.sample_neighbour = False
36+
37+
38+
#opt.model = "DANN"
39+
#opt.model = "CDANN"
40+
#opt.model = "ADDA"
41+
opt.model = 'TSDA'
42+
#opt.model = "GDA"
43+
opt.cond_disc = (
44+
False # whether use conditional discriminator or not (for CDANN)
45+
)
46+
47+
48+
opt.use_visdom = False
49+
opt.visdom_port = 2000
50+
51+
opt.use_g_encode = False # False # True
52+
if opt.use_g_encode:
53+
opt.g_encode = read_pickle("g_encode_l7l40.pkl")
54+
55+
56+
opt.device = "cuda"
57+
opt.seed = 2333 # 1# 101 # 1 # 233 # 1
58+
59+
# opt.lambda_gan = 0.5 # 0.5 # 0.3125 # 0.5 # 0.5
60+
opt.lambda_gan = 0
61+
62+
# for MDD use only
63+
opt.lambda_src = 0.5
64+
opt.lambda_tgt = 0.5
65+
66+
# for TDDA use only
67+
opt.lambda_r = 0.594795294351678
68+
opt.lambda_d = 0.5
69+
opt.lambda_e = 0.7786359421472335
70+
opt.lambda_c = 1
71+
opt.num_epoch = 300
72+
opt.batch_size = 5
73+
opt.lr_d = 1e-5 # 3e-5 # 1e-4 # 2.9 * 1e-5 #3e-5 # 1e-4
74+
opt.lr_e = 1e-5 # 3e-5 # 1e-4 # 2.9 * 4e-6
75+
opt.lr_g = 1e-4
76+
opt.lr_r = 1e-4
77+
opt.gamma = 100
78+
opt.beta1 = 0.9
79+
opt.weight_decay = 5e-4
80+
opt.wgan = False # do not use wgan to train
81+
opt.no_bn = True # do not use batch normalization # True
82+
83+
# model size configs, used for D, E, F
84+
opt.nt = 2 # dimension of the vertex embedding
85+
opt.nc = 2 # number of label class
86+
opt.nd_out = 2 # dimension of D's output
87+
opt.nr_out = 2
88+
opt.num_input = 4096 # the x data dimension
89+
opt.nh = 4096 # TODO: the hidden states for many modules, be careful
90+
opt.nv_embed = 2 # the vertex embedding dimension
91+
# sample how many vertices for training R
92+
opt.sample_v = opt.num_domain
93+
94+
# # sample how many vertices for training G
95+
opt.sample_v_g = opt.num_domain
96+
97+
opt.test_interval = 20
98+
opt.save_interval = 100
99+
# drop out rate
100+
opt.p = 0.2
101+
opt.shuffle = True
102+
103+
104+
# dataset
105+
opt.data_src = "data/"
106+
opt.data_path = opt.data_src + "feature_upperparts_black.pkl"
107+
opt.dataset = opt.data_path
108+
opt.A = read_pickle(opt.data_src + "A_cub_18.pkl")

Real/configs/config_imagenet_11.py

+110
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
from easydict import EasyDict
2+
import numpy as np
3+
import pickle
4+
5+
6+
def read_pickle(name):
7+
with open(name, "rb") as f:
8+
data = pickle.load(f)
9+
return data
10+
11+
12+
# load/output dir
13+
opt = EasyDict()
14+
opt.loadf = "./dump"
15+
opt.outf = "./dump"
16+
opt.outr = './data/test result'
17+
18+
# normalize each data domain
19+
# opt.normalize_domain = False
20+
opt.print_switch = True
21+
opt.adj_default = True
22+
# now it is half circle
23+
opt.num_domain = 11
24+
opt.full_domain = list(range(opt.num_domain))
25+
opt.src_domain = [6,7,8,9,10]
26+
opt.tgt_domain = list(set(opt.full_domain) - set(opt.src_domain))
27+
opt.num_source = len(opt.src_domain)
28+
opt.num_target = opt.num_domain - opt.num_source
29+
opt.src_dmn_num = opt.num_source
30+
opt.tgt_dmn_num = opt.num_target
31+
opt.all_domain = opt.src_domain + opt.tgt_domain
32+
33+
opt.test_on_all_dmn = True
34+
35+
opt.sample_neighbour = False
36+
37+
38+
#opt.model = "DANN"
39+
#opt.model = "CDANN"
40+
#opt.model = "ADDA"
41+
opt.model = 'TSDA'
42+
#opt.model = "GDA"
43+
opt.cond_disc = (
44+
False # whether use conditional discriminator or not (for CDANN)
45+
)
46+
47+
48+
opt.use_visdom = False
49+
opt.visdom_port = 2000
50+
51+
opt.use_g_encode = False # False # True
52+
if opt.use_g_encode:
53+
opt.g_encode = read_pickle("g_encode_l7l40.pkl")
54+
55+
56+
opt.device = "cuda"
57+
opt.seed = 2333 # 1# 101 # 1 # 233 # 1
58+
59+
# opt.lambda_gan = 0.5 # 0.5 # 0.3125 # 0.5 # 0.5
60+
opt.lambda_gan = 0
61+
62+
# for MDD use only
63+
opt.lambda_src = 0.5
64+
opt.lambda_tgt = 0.5
65+
66+
# for TDDA use only
67+
opt.lambda_r = 0.594795294351678
68+
opt.lambda_d = 0.2488505602149487
69+
opt.lambda_e = 0.7786359421472335
70+
opt.lambda_c = 1
71+
72+
opt.num_epoch = 300
73+
opt.batch_size = 5
74+
opt.lr_d = 1e-5 # 3e-5 # 1e-4 # 2.9 * 1e-5 #3e-5 # 1e-4
75+
opt.lr_e = 1e-5 # 3e-5 # 1e-4 # 2.9 * 4e-6
76+
opt.lr_g = 1e-4
77+
opt.lr_r = 1e-4
78+
opt.gamma = 100
79+
opt.beta1 = 0.9
80+
opt.weight_decay = 5e-4
81+
opt.wgan = False # do not use wgan to train
82+
opt.no_bn = True # do not use batch normalization # True
83+
84+
# model size configs, used for D, E, F
85+
opt.nt = 2 # dimension of the vertex embedding
86+
opt.nc = 2 # number of label class
87+
opt.nd_out = 2 # dimension of D's output
88+
opt.nr_out = 2
89+
opt.num_input = 4096 # the x data dimension
90+
opt.nh = 4096 # TODO: the hidden states for many modules, be careful
91+
opt.nv_embed = 2 # the vertex embedding dimension
92+
# sample how many vertices for training R
93+
opt.sample_v = opt.num_domain
94+
95+
# # sample how many vertices for training G
96+
opt.sample_v_g = opt.num_domain
97+
98+
opt.test_interval = 20
99+
opt.save_interval = 100
100+
# drop out rate
101+
opt.p = 0.2
102+
opt.shuffle = True
103+
104+
105+
# dataset
106+
opt.data_src = "data/"
107+
opt.data_path = opt.data_src + "feature_brown.pkl"
108+
opt.dataset = opt.data_path
109+
opt.A_root = read_pickle(opt.data_src + "A_brown_root.pkl")
110+
opt.A = read_pickle(opt.data_src + "A_brown.pkl")

Real/data/A_brown.pkl

1.09 KB
Binary file not shown.

Real/data/A_brown_grda.pkl

1.09 KB
Binary file not shown.

Real/data/A_brown_root.pkl

1.09 KB
Binary file not shown.

Real/data/A_cub_18.pkl

2.68 KB
Binary file not shown.

Real/data/A_cub_grda.pkl

2.68 KB
Binary file not shown.

Real/data/feature_brown.pkl

3.66 MB
Binary file not shown.
16.2 MB
Binary file not shown.

Real/data/feature_wing_black.pkl

16.2 MB
Binary file not shown.

Real/dataset_utils/dataset.py

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import numpy as np
2+
from torch.utils.data import Dataset
3+
import pickle
4+
5+
6+
def read_pickle(name):
7+
with open(name, "rb") as f:
8+
data = pickle.load(f)
9+
return data
10+
11+
12+
def write_pickle(data, name):
13+
with open(name, "wb") as f:
14+
pickle.dump(data, f)
15+
16+
17+
class ToyDataset(Dataset):
18+
def __init__(self, pkl, domain_id, opt=None):
19+
idx = pkl["domain"] == domain_id
20+
self.data = pkl["data"][idx].astype(np.float32)
21+
self.label = pkl["label"][idx].astype(np.int64)
22+
self.domain = domain_id
23+
24+
# if opt.normalize_domain:
25+
# print('===> Normalize in every domain')
26+
# self.data_m, self.data_s = self.data.mean(0, keepdims=True), self.data.std(0, keepdims=True)
27+
# self.data = (self.data - self.data_m) / self.data_s
28+
29+
def __getitem__(self, idx):
30+
return self.data[idx], self.label[idx], self.domain
31+
32+
def __len__(self):
33+
return len(self.data)
34+
35+
36+
class SeqToyDataset(Dataset):
37+
def __init__(self, datasets, size=3 * 200):
38+
self.datasets = datasets
39+
self.size = size
40+
print(
41+
"SeqDataset Size {} Sub Size {}".format(
42+
size, [len(ds) for ds in datasets]
43+
)
44+
)
45+
46+
def __len__(self):
47+
return self.size
48+
49+
def __getitem__(self, i):
50+
return [ds[i] for ds in self.datasets]

Real/dataset_utils/feature_dataset.py

+122
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
import numpy as np
2+
from torch.utils.data import DataLoader, Dataset
3+
import pickle
4+
5+
6+
def read_pickle(name):
7+
with open(name, 'rb') as f:
8+
data = pickle.load(f)
9+
return data
10+
11+
12+
class FeatureDataset(Dataset):
13+
def __init__(self, pkl, domain_id, sudo_len, opt=None):
14+
idx = pkl['domain'] == domain_id
15+
self.data = pkl['data'][idx].astype(np.float32)
16+
self.label = pkl['label'][idx].astype(np.int64)
17+
self.domain = domain_id
18+
self.real_len = len(self.data)
19+
self.sudo_len = sudo_len
20+
21+
# if opt.normalize_domain:
22+
# print('===> Normalize in every domain')
23+
# self.data_m, self.data_s = self.data.mean(0, keepdims=True), self.data.std(0, keepdims=True)
24+
# self.data = (self.data - self.data_m) / self.data_s
25+
26+
def __getitem__(self, idx):
27+
idx %= self.real_len
28+
return self.data[idx], self.label[idx], self.domain
29+
30+
def __len__(self):
31+
# return len(self.data)
32+
return self.sudo_len
33+
34+
35+
class FeatureDataloader(DataLoader):
36+
def __init__(self, opt):
37+
self.opt = opt
38+
self.src_domain = opt.src_domain
39+
self.tgt_domain = opt.tgt_domain
40+
self.all_domain = opt.all_domain
41+
42+
self.pkl = read_pickle(opt.data_path)
43+
sudo_len = 0
44+
for i in self.all_domain:
45+
idx = self.pkl['domain'] == i
46+
sudo_len = max(sudo_len, idx.sum())
47+
self.sudo_len = sudo_len
48+
49+
print("sudo len: {}".format(sudo_len))
50+
51+
self.train_datasets = [
52+
FeatureDataset(
53+
self.pkl,
54+
domain_id=i,
55+
opt=opt,
56+
sudo_len=self.sudo_len,
57+
) for i in self.all_domain
58+
]
59+
60+
if self.opt.test_on_all_dmn:
61+
self.test_datasets = [
62+
FeatureDataset(
63+
self.pkl,
64+
domain_id=i,
65+
opt=opt,
66+
sudo_len=self.sudo_len,
67+
) for i in self.all_domain
68+
]
69+
else:
70+
self.test_datasets = [
71+
FeatureDataset(
72+
self.pkl,
73+
domain_id=i,
74+
opt=opt,
75+
sudo_len=self.sudo_len,
76+
) for i in self.tgt_domain
77+
]
78+
79+
self.train_data_loader = [
80+
DataLoader(dataset,
81+
batch_size=opt.batch_size,
82+
shuffle=opt.shuffle,
83+
# consider if necessary
84+
num_workers=0,
85+
pin_memory=True,
86+
#persistent_workers=True
87+
)
88+
for dataset in self.train_datasets
89+
]
90+
91+
self.test_data_loader = [
92+
DataLoader(dataset,
93+
batch_size=opt.batch_size,
94+
shuffle=opt.shuffle,
95+
num_workers=0,
96+
pin_memory=True,
97+
#persistent_workers=True
98+
)
99+
for dataset in self.test_datasets
100+
]
101+
102+
def get_train_data(self):
103+
# this is return a iterator for the whole dataset
104+
return zip(*self.train_data_loader)
105+
106+
def get_test_data(self):
107+
return zip(*self.test_data_loader)
108+
109+
# class SeqToyDataset(Dataset):
110+
# # the size may change because of the toy dataset!!
111+
# def __init__(self, datasets, size=3 * 200):
112+
# self.datasets = datasets
113+
# self.size = size
114+
# print('SeqDataset Size {} Sub Size {}'.format(
115+
# size, [len(ds) for ds in datasets]
116+
# ))
117+
118+
# def __len__(self):
119+
# return self.size
120+
121+
# def __getitem__(self, i):
122+
# return [ds[i] for ds in self.datasets]

0 commit comments

Comments
 (0)