1
1
package advisor .client ;
2
2
3
3
import java .io .IOException ;
4
+ import java .io .UnsupportedEncodingException ;
5
+ import java .net .URLEncoder ;
6
+ import java .nio .charset .Charset ;
4
7
import java .text .ParseException ;
5
8
import java .util .HashMap ;
6
9
import java .util .Iterator ;
@@ -66,6 +69,38 @@ public Study createStudy(String name, JSONObject studyConfiguration, String algo
66
69
return null ;
67
70
}
68
71
72
+ public Study getOrCreateStudy (String studyName , JSONObject studyConfiguration , String algorithm ) throws ClientProtocolException , IOException , JSONException , ParseException {
73
+ String studyNameURL =URLEncoder .encode (studyName ,"UTF-8" );
74
+ String url = endpoint +"/suggestion/v1/studies/" +studyNameURL +"/exist" ;
75
+
76
+ CloseableHttpClient httpclient = HttpClients .createDefault ();
77
+
78
+ HttpGet httpGet = new HttpGet (url );
79
+ httpGet .setHeader ("Accept" , "application/json" );
80
+ httpGet .setHeader ("Content-type" , "application/json" );
81
+
82
+ CloseableHttpResponse responseHttp = httpclient .execute (httpGet );
83
+ Study study =null ;
84
+ if (responseHttp .getStatusLine ().getStatusCode ()==200 ){
85
+ String response =EntityUtils .toString (responseHttp .getEntity ()) ;
86
+ JSONObject responseJSON = new JSONObject (response );
87
+ boolean responseExists =responseJSON .getBoolean ("exist" );
88
+
89
+
90
+ if (responseExists ){
91
+ study =this .getStudyByName (studyName );
92
+ }
93
+ else {
94
+ study =this .createStudy (studyName , studyConfiguration , algorithm );
95
+ }
96
+
97
+
98
+
99
+ }
100
+
101
+ return study ;
102
+ }
103
+
69
104
public List <Study > listStudies () throws ClientProtocolException , IOException , JSONException , ParseException {
70
105
71
106
LinkedList <Study > list = new LinkedList <Study >();
@@ -93,8 +128,9 @@ public List<Study> listStudies() throws ClientProtocolException, IOException, JS
93
128
94
129
}
95
130
96
- public Study getStudyById (int studyId ) throws ClientProtocolException , IOException , JSONException , ParseException {
97
- String url = endpoint +"/suggestion/v1/studies/" +studyId ;
131
+ public Study getStudyByName (String studyName ) throws ClientProtocolException , IOException , JSONException , ParseException {
132
+ String studyNameURL =URLEncoder .encode (studyName ,"UTF-8" );
133
+ String url = endpoint +"/suggestion/v1/studies/" +studyNameURL ;
98
134
CloseableHttpClient httpclient = HttpClients .createDefault ();
99
135
100
136
HttpGet httpGet = new HttpGet (url );
@@ -112,11 +148,12 @@ public Study getStudyById(int studyId) throws ClientProtocolException, IOExcepti
112
148
113
149
}
114
150
115
- public List <Trial > getSuggestions (int studyId , int trialsNumber ) throws ClientProtocolException , IOException , JSONException , ParseException {
151
+ public List <Trial > getSuggestions (String studyName , int trialsNumber ) throws ClientProtocolException , IOException , JSONException , ParseException {
152
+ String studyNameURL =URLEncoder .encode (studyName ,"UTF-8" );
116
153
LinkedList <Trial > list = new LinkedList <Trial >();
117
154
if (trialsNumber <=0 )
118
155
trialsNumber =1 ;
119
- String url = endpoint +"/suggestion/v1/studies/" +studyId +"/suggestions" ;
156
+ String url = endpoint +"/suggestion/v1/studies/" +studyNameURL +"/suggestions" ;
120
157
JSONObject requestData = new JSONObject ();
121
158
122
159
requestData .put ("trials_number" , trialsNumber );
@@ -144,8 +181,9 @@ public List<Trial> getSuggestions(int studyId, int trialsNumber) throws ClientPr
144
181
}
145
182
146
183
147
- public boolean isStudyDone (int studyId ) throws ClientProtocolException , JSONException , IOException , ParseException {
148
- Study study =getStudyById (studyId );
184
+ public boolean isStudyDone (String studyName ) throws ClientProtocolException , JSONException , IOException , ParseException {
185
+ String studyNameURL =URLEncoder .encode (studyName ,"UTF-8" );
186
+ Study study =getStudyByName (studyName );
149
187
150
188
if (study ==null )
151
189
return false ;
@@ -154,12 +192,12 @@ public boolean isStudyDone(int studyId) throws ClientProtocolException, JSONExce
154
192
if (Study .COMPLETED .equals (study .getStatus ()))
155
193
return true ;
156
194
157
- List <Trial > trials =listTrials (studyId );
195
+ List <Trial > trials =listTrials (studyName );
158
196
for (Trial trial : trials ){
159
197
if (!Trial .COMPLETED .equals (trial .getStatus ()))
160
198
return false ;
161
199
}
162
- String url = endpoint +"/suggestion/v1/studies/" +studyId ;
200
+ String url = endpoint +"/suggestion/v1/studies/" +studyNameURL ;
163
201
164
202
JSONObject putData = new JSONObject ();
165
203
putData .put ("status" , "Completed" );
@@ -179,9 +217,10 @@ public boolean isStudyDone(int studyId) throws ClientProtocolException, JSONExce
179
217
return true ;
180
218
}
181
219
182
- public List <Trial > listTrials (int studyId ) throws ClientProtocolException , IOException , JSONException , ParseException {
220
+ public List <Trial > listTrials (String studyName ) throws ClientProtocolException , IOException , JSONException , ParseException {
221
+ String studyNameURL =URLEncoder .encode (studyName ,"UTF-8" );
183
222
LinkedList <Trial > list = new LinkedList <Trial >();
184
- String url = endpoint +"/suggestion/v1/studies/" +studyId +"/trials" ;
223
+ String url = endpoint +"/suggestion/v1/studies/" +studyNameURL +"/trials" ;
185
224
CloseableHttpClient httpclient = HttpClients .createDefault ();
186
225
187
226
HttpGet httpGet = new HttpGet (url );
@@ -202,19 +241,44 @@ public List<Trial> listTrials(int studyId) throws ClientProtocolException, IOExc
202
241
return list ;
203
242
}
204
243
205
- @ Beta
206
- public List <TrialMetric > listTrialMetrics (int studyId , int trialId ){
207
- return null ;
244
+
245
+ public List <TrialMetric > listTrialMetrics (String studyName , int trialId ) throws ClientProtocolException , IOException , JSONException , ParseException {
246
+ String studyNameURL =URLEncoder .encode (studyName ,"UTF-8" );
247
+ LinkedList <TrialMetric > list = new LinkedList <TrialMetric >();
248
+ String url = endpoint +"/suggestion/v1/studies/" +studyNameURL +"/trials/" +trialId +"/metrics" ;
249
+
250
+
251
+ CloseableHttpClient httpclient = HttpClients .createDefault ();
252
+
253
+ HttpGet httpGet = new HttpGet (url );
254
+ httpGet .setHeader ("Accept" , "application/json" );
255
+ httpGet .setHeader ("Content-type" , "application/json" );
256
+
257
+ CloseableHttpResponse responseHttp = httpclient .execute (httpGet );
258
+
259
+ if (responseHttp .getStatusLine ().getStatusCode ()==200 ){
260
+ String response =EntityUtils .toString (responseHttp .getEntity ()) ;
261
+ JSONObject responseJSON = new JSONObject (response );
262
+ JSONArray responseJSONData =responseJSON .getJSONArray ("data" );
263
+ Iterator <Object > trials =responseJSONData .iterator ();
264
+ while (trials .hasNext ()){
265
+ list .add (new TrialMetric ((JSONObject ) trials .next ()));
266
+ }
267
+
268
+ }
269
+
270
+ return list ;
208
271
}
209
272
210
- public Trial getBestTrial (int studyId ) throws ClientProtocolException , JSONException , IOException , ParseException {
211
- if (!this .isStudyDone (studyId ))
273
+ public Trial getBestTrial (String studyName ) throws ClientProtocolException , JSONException , IOException , ParseException {
274
+
275
+ if (!this .isStudyDone (studyName ))
212
276
return null ;
213
277
214
- Study st = this .getStudyById ( studyId );
278
+ Study st = this .getStudyByName ( studyName );
215
279
JSONObject configuration =st .getStudy_configuration ();
216
280
String studyGoal =configuration .getString ("goal" );
217
- List <Trial > trials =this .listTrials (studyId );
281
+ List <Trial > trials =this .listTrials (studyName );
218
282
219
283
Trial bestTrial =null ;
220
284
for (Trial trial : trials ) {
@@ -239,8 +303,9 @@ else if (studyGoal.equals(Study.MINIMIZE)) {
239
303
return bestTrial ;
240
304
}
241
305
242
- public Trial getTrial (int studyId , int trialId ) throws ClientProtocolException , IOException , JSONException , ParseException {
243
- String url = endpoint +"/suggestion/v1/studies/" +studyId +"/trials/" +trialId ;
306
+ public Trial getTrial (String studyName , int trialId ) throws ClientProtocolException , IOException , JSONException , ParseException {
307
+ String studyNameURL =URLEncoder .encode (studyName ,"UTF-8" );
308
+ String url = endpoint +"/suggestion/v1/studies/" +studyNameURL +"/trials/" +trialId ;
244
309
245
310
CloseableHttpClient httpclient = HttpClients .createDefault ();
246
311
@@ -258,8 +323,9 @@ public Trial getTrial(int studyId, int trialId) throws ClientProtocolException,
258
323
return null ;
259
324
}
260
325
261
- public TrialMetric createTrialMetric (int studyId , int trialId , JSONObject trainingStep , Double objectiveValue ) throws ClientProtocolException , IOException , JSONException , ParseException {
262
- String url = endpoint +"/suggestion/v1/studies/" +studyId +"/trials/" +trialId +"/metrics" ;
326
+ public TrialMetric createTrialMetric (String studyName , int trialId , JSONObject trainingStep , Double objectiveValue ) throws ClientProtocolException , IOException , JSONException , ParseException {
327
+ String studyNameURL =URLEncoder .encode (studyName ,"UTF-8" );
328
+ String url = endpoint +"/suggestion/v1/studies/" +studyNameURL +"/trials/" +trialId +"/metrics" ;
263
329
264
330
JSONObject requestData = new JSONObject ();
265
331
requestData .put ("training_step" , trainingStep ==null ?JSONObject .NULL :trainingStep );
@@ -284,13 +350,15 @@ public TrialMetric createTrialMetric(int studyId, int trialId, JSONObject traini
284
350
}
285
351
286
352
public Trial completeTrialWithTensorboardMetrics (Trial trial , LinkedList <HashMap <String ,Double >> list ) throws ClientProtocolException , JSONException , IOException , ParseException {
353
+
287
354
double objectiveValue =0 ;
288
355
for (HashMap <String ,Double > scalarSummary : list ) {
289
356
objectiveValue =scalarSummary .get ("value" );
290
- this .createTrialMetric (trial .getStudyId (), trial .getId (), new JSONObject (scalarSummary .get ("step" )), objectiveValue );
357
+ this .createTrialMetric (trial .getName (), trial .getId (), new JSONObject (scalarSummary .get ("step" )), objectiveValue );
291
358
292
359
}
293
- String url = endpoint +"/suggestion/v1/studies/" +trial .getStudyId ()+"/trials/" +trial .getId ();
360
+ String studyNameURL =URLEncoder .encode (trial .getStudyName (),"UTF-8" );
361
+ String url = endpoint +"/suggestion/v1/studies/" +studyNameURL +"/trials/" +trial .getId ();
294
362
JSONObject putData = new JSONObject ();
295
363
putData .put ("status" , "Completed" );
296
364
putData .put ("objective_value" , objectiveValue );
@@ -313,8 +381,9 @@ public Trial completeTrialWithTensorboardMetrics(Trial trial, LinkedList<HashMap
313
381
}
314
382
315
383
public Trial completeTrialWithOneMetric (Trial trial , Double metric ) throws ClientProtocolException , JSONException , IOException , ParseException {
316
- this .createTrialMetric (trial .getStudyId (), trial .getId (), null , metric );
317
- String url = endpoint +"/suggestion/v1/studies/" +trial .getStudyId ()+"/trials/" +trial .getId ();
384
+ this .createTrialMetric (trial .getStudyName (), trial .getId (), null , metric );
385
+ String studyNameURL =URLEncoder .encode (trial .getStudyName (),"UTF-8" );
386
+ String url = endpoint +"/suggestion/v1/studies/" +studyNameURL +"/trials/" +trial .getId ();
318
387
JSONObject putData = new JSONObject ();
319
388
putData .put ("status" , "Completed" );
320
389
putData .put ("objective_value" , metric );
@@ -341,7 +410,7 @@ public static void main(String[] args) {
341
410
AdvisorClient cl = new AdvisorClient ();
342
411
try {
343
412
JSONObject studyConfiguration = new JSONObject ("{\" goal\" : \" MAXIMIZE\" , \" randomInitTrials\" : 3, \" maxTrials\" : 5, \" maxParallelTrials\" : 1, \" params\" : [ { \" parameterName\" : \" hidden1\" , \" type\" : \" INTEGER\" , \" minValue\" : 1, \" maxValue\" : 10, \" scallingType\" : \" LINEAR\" }, {\" parameterName\" : \" learning_rate\" ,\" type\" : \" DOUBLE\" , \" minValue\" : 0.01, \" maxValue\" : 0.5, \" scallingType\" : \" LINEAR\" } ]}" );
344
- Study st =cl .createStudy ("StudyPruebaTest" , studyConfiguration , Study .BAYESIAN_OPTIMIZATION );
413
+ Study st =cl .getOrCreateStudy ("StudyPruebaTest" , studyConfiguration , Study .BAYESIAN_OPTIMIZATION );
345
414
System .out .println (st .getAlgorithm ());
346
415
347
416
} catch (Exception e ) {
0 commit comments