11package advisor .client ;
22
33import java .io .IOException ;
4+ import java .io .UnsupportedEncodingException ;
5+ import java .net .URLEncoder ;
6+ import java .nio .charset .Charset ;
47import java .text .ParseException ;
58import java .util .HashMap ;
69import java .util .Iterator ;
@@ -66,6 +69,38 @@ public Study createStudy(String name, JSONObject studyConfiguration, String algo
6669 return null ;
6770 }
6871
72+ public Study getOrCreateStudy (String studyName , JSONObject studyConfiguration , String algorithm ) throws ClientProtocolException , IOException , JSONException , ParseException {
73+ String studyNameURL =URLEncoder .encode (studyName ,"UTF-8" );
74+ String url = endpoint +"/suggestion/v1/studies/" +studyNameURL +"/exist" ;
75+
76+ CloseableHttpClient httpclient = HttpClients .createDefault ();
77+
78+ HttpGet httpGet = new HttpGet (url );
79+ httpGet .setHeader ("Accept" , "application/json" );
80+ httpGet .setHeader ("Content-type" , "application/json" );
81+
82+ CloseableHttpResponse responseHttp = httpclient .execute (httpGet );
83+ Study study =null ;
84+ if (responseHttp .getStatusLine ().getStatusCode ()==200 ){
85+ String response =EntityUtils .toString (responseHttp .getEntity ()) ;
86+ JSONObject responseJSON = new JSONObject (response );
87+ boolean responseExists =responseJSON .getBoolean ("exist" );
88+
89+
90+ if (responseExists ){
91+ study =this .getStudyByName (studyName );
92+ }
93+ else {
94+ study =this .createStudy (studyName , studyConfiguration , algorithm );
95+ }
96+
97+
98+
99+ }
100+
101+ return study ;
102+ }
103+
69104 public List <Study > listStudies () throws ClientProtocolException , IOException , JSONException , ParseException {
70105
71106 LinkedList <Study > list = new LinkedList <Study >();
@@ -93,8 +128,9 @@ public List<Study> listStudies() throws ClientProtocolException, IOException, JS
93128
94129 }
95130
96- public Study getStudyById (int studyId ) throws ClientProtocolException , IOException , JSONException , ParseException {
97- String url = endpoint +"/suggestion/v1/studies/" +studyId ;
131+ public Study getStudyByName (String studyName ) throws ClientProtocolException , IOException , JSONException , ParseException {
132+ String studyNameURL =URLEncoder .encode (studyName ,"UTF-8" );
133+ String url = endpoint +"/suggestion/v1/studies/" +studyNameURL ;
98134 CloseableHttpClient httpclient = HttpClients .createDefault ();
99135
100136 HttpGet httpGet = new HttpGet (url );
@@ -112,11 +148,12 @@ public Study getStudyById(int studyId) throws ClientProtocolException, IOExcepti
112148
113149 }
114150
115- public List <Trial > getSuggestions (int studyId , int trialsNumber ) throws ClientProtocolException , IOException , JSONException , ParseException {
151+ public List <Trial > getSuggestions (String studyName , int trialsNumber ) throws ClientProtocolException , IOException , JSONException , ParseException {
152+ String studyNameURL =URLEncoder .encode (studyName ,"UTF-8" );
116153 LinkedList <Trial > list = new LinkedList <Trial >();
117154 if (trialsNumber <=0 )
118155 trialsNumber =1 ;
119- String url = endpoint +"/suggestion/v1/studies/" +studyId +"/suggestions" ;
156+ String url = endpoint +"/suggestion/v1/studies/" +studyNameURL +"/suggestions" ;
120157 JSONObject requestData = new JSONObject ();
121158
122159 requestData .put ("trials_number" , trialsNumber );
@@ -144,8 +181,9 @@ public List<Trial> getSuggestions(int studyId, int trialsNumber) throws ClientPr
144181 }
145182
146183
147- public boolean isStudyDone (int studyId ) throws ClientProtocolException , JSONException , IOException , ParseException {
148- Study study =getStudyById (studyId );
184+ public boolean isStudyDone (String studyName ) throws ClientProtocolException , JSONException , IOException , ParseException {
185+ String studyNameURL =URLEncoder .encode (studyName ,"UTF-8" );
186+ Study study =getStudyByName (studyName );
149187
150188 if (study ==null )
151189 return false ;
@@ -154,12 +192,12 @@ public boolean isStudyDone(int studyId) throws ClientProtocolException, JSONExce
154192 if (Study .COMPLETED .equals (study .getStatus ()))
155193 return true ;
156194
157- List <Trial > trials =listTrials (studyId );
195+ List <Trial > trials =listTrials (studyName );
158196 for (Trial trial : trials ){
159197 if (!Trial .COMPLETED .equals (trial .getStatus ()))
160198 return false ;
161199 }
162- String url = endpoint +"/suggestion/v1/studies/" +studyId ;
200+ String url = endpoint +"/suggestion/v1/studies/" +studyNameURL ;
163201
164202 JSONObject putData = new JSONObject ();
165203 putData .put ("status" , "Completed" );
@@ -179,9 +217,10 @@ public boolean isStudyDone(int studyId) throws ClientProtocolException, JSONExce
179217 return true ;
180218 }
181219
182- public List <Trial > listTrials (int studyId ) throws ClientProtocolException , IOException , JSONException , ParseException {
220+ public List <Trial > listTrials (String studyName ) throws ClientProtocolException , IOException , JSONException , ParseException {
221+ String studyNameURL =URLEncoder .encode (studyName ,"UTF-8" );
183222 LinkedList <Trial > list = new LinkedList <Trial >();
184- String url = endpoint +"/suggestion/v1/studies/" +studyId +"/trials" ;
223+ String url = endpoint +"/suggestion/v1/studies/" +studyNameURL +"/trials" ;
185224 CloseableHttpClient httpclient = HttpClients .createDefault ();
186225
187226 HttpGet httpGet = new HttpGet (url );
@@ -202,19 +241,44 @@ public List<Trial> listTrials(int studyId) throws ClientProtocolException, IOExc
202241 return list ;
203242 }
204243
205- @ Beta
206- public List <TrialMetric > listTrialMetrics (int studyId , int trialId ){
207- return null ;
244+
245+ public List <TrialMetric > listTrialMetrics (String studyName , int trialId ) throws ClientProtocolException , IOException , JSONException , ParseException {
246+ String studyNameURL =URLEncoder .encode (studyName ,"UTF-8" );
247+ LinkedList <TrialMetric > list = new LinkedList <TrialMetric >();
248+ String url = endpoint +"/suggestion/v1/studies/" +studyNameURL +"/trials/" +trialId +"/metrics" ;
249+
250+
251+ CloseableHttpClient httpclient = HttpClients .createDefault ();
252+
253+ HttpGet httpGet = new HttpGet (url );
254+ httpGet .setHeader ("Accept" , "application/json" );
255+ httpGet .setHeader ("Content-type" , "application/json" );
256+
257+ CloseableHttpResponse responseHttp = httpclient .execute (httpGet );
258+
259+ if (responseHttp .getStatusLine ().getStatusCode ()==200 ){
260+ String response =EntityUtils .toString (responseHttp .getEntity ()) ;
261+ JSONObject responseJSON = new JSONObject (response );
262+ JSONArray responseJSONData =responseJSON .getJSONArray ("data" );
263+ Iterator <Object > trials =responseJSONData .iterator ();
264+ while (trials .hasNext ()){
265+ list .add (new TrialMetric ((JSONObject ) trials .next ()));
266+ }
267+
268+ }
269+
270+ return list ;
208271 }
209272
210- public Trial getBestTrial (int studyId ) throws ClientProtocolException , JSONException , IOException , ParseException {
211- if (!this .isStudyDone (studyId ))
273+ public Trial getBestTrial (String studyName ) throws ClientProtocolException , JSONException , IOException , ParseException {
274+
275+ if (!this .isStudyDone (studyName ))
212276 return null ;
213277
214- Study st = this .getStudyById ( studyId );
278+ Study st = this .getStudyByName ( studyName );
215279 JSONObject configuration =st .getStudy_configuration ();
216280 String studyGoal =configuration .getString ("goal" );
217- List <Trial > trials =this .listTrials (studyId );
281+ List <Trial > trials =this .listTrials (studyName );
218282
219283 Trial bestTrial =null ;
220284 for (Trial trial : trials ) {
@@ -239,8 +303,9 @@ else if (studyGoal.equals(Study.MINIMIZE)) {
239303 return bestTrial ;
240304 }
241305
242- public Trial getTrial (int studyId , int trialId ) throws ClientProtocolException , IOException , JSONException , ParseException {
243- String url = endpoint +"/suggestion/v1/studies/" +studyId +"/trials/" +trialId ;
306+ public Trial getTrial (String studyName , int trialId ) throws ClientProtocolException , IOException , JSONException , ParseException {
307+ String studyNameURL =URLEncoder .encode (studyName ,"UTF-8" );
308+ String url = endpoint +"/suggestion/v1/studies/" +studyNameURL +"/trials/" +trialId ;
244309
245310 CloseableHttpClient httpclient = HttpClients .createDefault ();
246311
@@ -258,8 +323,9 @@ public Trial getTrial(int studyId, int trialId) throws ClientProtocolException,
258323 return null ;
259324 }
260325
261- public TrialMetric createTrialMetric (int studyId , int trialId , JSONObject trainingStep , Double objectiveValue ) throws ClientProtocolException , IOException , JSONException , ParseException {
262- String url = endpoint +"/suggestion/v1/studies/" +studyId +"/trials/" +trialId +"/metrics" ;
326+ public TrialMetric createTrialMetric (String studyName , int trialId , JSONObject trainingStep , Double objectiveValue ) throws ClientProtocolException , IOException , JSONException , ParseException {
327+ String studyNameURL =URLEncoder .encode (studyName ,"UTF-8" );
328+ String url = endpoint +"/suggestion/v1/studies/" +studyNameURL +"/trials/" +trialId +"/metrics" ;
263329
264330 JSONObject requestData = new JSONObject ();
265331 requestData .put ("training_step" , trainingStep ==null ?JSONObject .NULL :trainingStep );
@@ -284,13 +350,15 @@ public TrialMetric createTrialMetric(int studyId, int trialId, JSONObject traini
284350 }
285351
286352 public Trial completeTrialWithTensorboardMetrics (Trial trial , LinkedList <HashMap <String ,Double >> list ) throws ClientProtocolException , JSONException , IOException , ParseException {
353+
287354 double objectiveValue =0 ;
288355 for (HashMap <String ,Double > scalarSummary : list ) {
289356 objectiveValue =scalarSummary .get ("value" );
290- this .createTrialMetric (trial .getStudyId (), trial .getId (), new JSONObject (scalarSummary .get ("step" )), objectiveValue );
357+ this .createTrialMetric (trial .getName (), trial .getId (), new JSONObject (scalarSummary .get ("step" )), objectiveValue );
291358
292359 }
293- String url = endpoint +"/suggestion/v1/studies/" +trial .getStudyId ()+"/trials/" +trial .getId ();
360+ String studyNameURL =URLEncoder .encode (trial .getStudyName (),"UTF-8" );
361+ String url = endpoint +"/suggestion/v1/studies/" +studyNameURL +"/trials/" +trial .getId ();
294362 JSONObject putData = new JSONObject ();
295363 putData .put ("status" , "Completed" );
296364 putData .put ("objective_value" , objectiveValue );
@@ -313,8 +381,9 @@ public Trial completeTrialWithTensorboardMetrics(Trial trial, LinkedList<HashMap
313381 }
314382
315383 public Trial completeTrialWithOneMetric (Trial trial , Double metric ) throws ClientProtocolException , JSONException , IOException , ParseException {
316- this .createTrialMetric (trial .getStudyId (), trial .getId (), null , metric );
317- String url = endpoint +"/suggestion/v1/studies/" +trial .getStudyId ()+"/trials/" +trial .getId ();
384+ this .createTrialMetric (trial .getStudyName (), trial .getId (), null , metric );
385+ String studyNameURL =URLEncoder .encode (trial .getStudyName (),"UTF-8" );
386+ String url = endpoint +"/suggestion/v1/studies/" +studyNameURL +"/trials/" +trial .getId ();
318387 JSONObject putData = new JSONObject ();
319388 putData .put ("status" , "Completed" );
320389 putData .put ("objective_value" , metric );
@@ -341,7 +410,7 @@ public static void main(String[] args) {
341410 AdvisorClient cl = new AdvisorClient ();
342411 try {
343412 JSONObject studyConfiguration = new JSONObject ("{\" goal\" : \" MAXIMIZE\" , \" randomInitTrials\" : 3, \" maxTrials\" : 5, \" maxParallelTrials\" : 1, \" params\" : [ { \" parameterName\" : \" hidden1\" , \" type\" : \" INTEGER\" , \" minValue\" : 1, \" maxValue\" : 10, \" scallingType\" : \" LINEAR\" }, {\" parameterName\" : \" learning_rate\" ,\" type\" : \" DOUBLE\" , \" minValue\" : 0.01, \" maxValue\" : 0.5, \" scallingType\" : \" LINEAR\" } ]}" );
344- Study st =cl .createStudy ("StudyPruebaTest" , studyConfiguration , Study .BAYESIAN_OPTIMIZATION );
413+ Study st =cl .getOrCreateStudy ("StudyPruebaTest" , studyConfiguration , Study .BAYESIAN_OPTIMIZATION );
345414 System .out .println (st .getAlgorithm ());
346415
347416 } catch (Exception e ) {
0 commit comments