From ab46e58db25bfe467bc115b9fdde19818276d93d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rados=C5=82aw=20Wa=C5=9Bko?= Date: Thu, 9 Jan 2025 15:46:12 +0100 Subject: [PATCH 01/17] PoC pt 1 --- .../Snowflake/0.0.0-dev/src/OAuth_Test.enso | 28 ++++++++++++++++ .../org/enso/snowflake/OAuthCallback.java | 32 +++++++++++++++++++ test-snowflake-oauth.enso | 8 +++++ 3 files changed, 68 insertions(+) create mode 100644 distribution/lib/Standard/Snowflake/0.0.0-dev/src/OAuth_Test.enso create mode 100644 std-bits/snowflake/src/main/java/org/enso/snowflake/OAuthCallback.java create mode 100644 test-snowflake-oauth.enso diff --git a/distribution/lib/Standard/Snowflake/0.0.0-dev/src/OAuth_Test.enso b/distribution/lib/Standard/Snowflake/0.0.0-dev/src/OAuth_Test.enso new file mode 100644 index 000000000000..8f870bbb947a --- /dev/null +++ b/distribution/lib/Standard/Snowflake/0.0.0-dev/src/OAuth_Test.enso @@ -0,0 +1,28 @@ +from Standard.Base import all + +import project.Connection.Snowflake_Details.Snowflake_Details + +polyglot java import org.enso.snowflake.OAuthCallback + +initiate_oauth -> Snowflake_Details = + Error.throw "TODO" + +perform_oauth account:Text role:Text = + uri = create_oauth_uri account role refresh_token=False + IO.println "Please open "+uri.to_text+" in your browser and follow the instructions." + # Wait for callback. + result = OAuthCallback.waitForCallback + IO.println "Got callback: "+result.to_text + +create_oauth_uri account:Text role:Text refresh_token:Boolean -> URI = + # TODO add code_challenge for PKCE + # TODO check that account does not contain any unexpected characters + base_uri = URI.from "https://"+account+".snowflakecomputing.com/oauth/authorize" + scope = (if refresh_token then "refresh_token " else "")+"session:role:"+role + base_uri + . add_query_argument "response_type" "code" + . add_query_argument "client_id" client_id + . add_query_argument "redirect_uri" "http://localhost:51234/snowflake" + . add_query_argument "scope" scope + +client_id = "hBjdrf6WyFIyLMMt9Ojbk0c8eto=" diff --git a/std-bits/snowflake/src/main/java/org/enso/snowflake/OAuthCallback.java b/std-bits/snowflake/src/main/java/org/enso/snowflake/OAuthCallback.java new file mode 100644 index 000000000000..e203781e4bc2 --- /dev/null +++ b/std-bits/snowflake/src/main/java/org/enso/snowflake/OAuthCallback.java @@ -0,0 +1,32 @@ +package org.enso.snowflake; + +import com.sun.net.httpserver.HttpServer; + +import java.io.IOException; +import java.net.InetSocketAddress; + +public final class OAuthCallback { + private OAuthCallback() {} + + public static CallbackServer waitForCallback(int port) throws IOException { + var callbackServer = new CallbackServerImplementation(port); + callbackServer.start(); + } + + public interface CallbackServer extends AutoCloseable { + String waitForCallback(); + } + + private static final class CallbackServerImplementation implements CallbackServer { + private final HttpServer server; + + private CallbackServerImplementation(int port) throws IOException { + InetSocketAddress address = new InetSocketAddress("localhost", port); + server = HttpServer.create(address, 0); + } + + private void start() { + server.start(); + } + } +} diff --git a/test-snowflake-oauth.enso b/test-snowflake-oauth.enso new file mode 100644 index 000000000000..1225989a773b --- /dev/null +++ b/test-snowflake-oauth.enso @@ -0,0 +1,8 @@ +from Standard.Base import all + +import Standard.Snowflake.OAuth_Test + +main = + account = "PUSMBUI-DP01445" + OAuth_Test.perform_oauth account role="CI" + From 0fc1480a92924d6095c5f94844421266172f1f46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rados=C5=82aw=20Wa=C5=9Bko?= Date: Thu, 9 Jan 2025 16:23:30 +0100 Subject: [PATCH 02/17] server --- build.sbt | 3 +- .../Snowflake/0.0.0-dev/src/OAuth_Test.enso | 13 +++-- .../org/enso/snowflake/OAuthCallback.java | 50 ++++++++++++++++++- 3 files changed, 60 insertions(+), 6 deletions(-) diff --git a/build.sbt b/build.sbt index 814e197d7fc7..acfd3e548f41 100644 --- a/build.sbt +++ b/build.sbt @@ -4890,7 +4890,8 @@ lazy val `std-snowflake` = project `std-snowflake-polyglot-root` / "std-snowflake.jar", libraryDependencies ++= Seq( "org.netbeans.api" % "org-openide-util-lookup" % netbeansApiVersion % "provided", - "net.snowflake" % "snowflake-jdbc" % snowflakeJDBCVersion + "net.snowflake" % "snowflake-jdbc" % snowflakeJDBCVersion, + "com.sun.net.httpserver" % "http" % "20070405" ), Compile / packageBin := Def.task { val result = (Compile / packageBin).value diff --git a/distribution/lib/Standard/Snowflake/0.0.0-dev/src/OAuth_Test.enso b/distribution/lib/Standard/Snowflake/0.0.0-dev/src/OAuth_Test.enso index 8f870bbb947a..b52e6b8084cc 100644 --- a/distribution/lib/Standard/Snowflake/0.0.0-dev/src/OAuth_Test.enso +++ b/distribution/lib/Standard/Snowflake/0.0.0-dev/src/OAuth_Test.enso @@ -1,4 +1,5 @@ from Standard.Base import all +import Standard.Base.Runtime.Managed_Resource.Managed_Resource import project.Connection.Snowflake_Details.Snowflake_Details @@ -9,10 +10,14 @@ initiate_oauth -> Snowflake_Details = perform_oauth account:Text role:Text = uri = create_oauth_uri account role refresh_token=False - IO.println "Please open "+uri.to_text+" in your browser and follow the instructions." - # Wait for callback. - result = OAuthCallback.waitForCallback - IO.println "Got callback: "+result.to_text + print_panic caugh_panic = + IO.println "Panic caught: "+caugh_panic.payload.to_text + Panic.catch Any handler=print_panic <| + result = Managed_Resource.bracket (OAuthCallback.createCallbackServer 51234) (.close) server-> + # We start the server first to ensure we can 'reserve' the port before opening the browser + IO.println "Please open "+uri.to_text+" in your browser and follow the instructions." + server.waitForCallback + IO.println "Got callback: "+result.to_text create_oauth_uri account:Text role:Text refresh_token:Boolean -> URI = # TODO add code_challenge for PKCE diff --git a/std-bits/snowflake/src/main/java/org/enso/snowflake/OAuthCallback.java b/std-bits/snowflake/src/main/java/org/enso/snowflake/OAuthCallback.java index e203781e4bc2..2d0a00eb3161 100644 --- a/std-bits/snowflake/src/main/java/org/enso/snowflake/OAuthCallback.java +++ b/std-bits/snowflake/src/main/java/org/enso/snowflake/OAuthCallback.java @@ -4,13 +4,16 @@ import java.io.IOException; import java.net.InetSocketAddress; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; public final class OAuthCallback { private OAuthCallback() {} - public static CallbackServer waitForCallback(int port) throws IOException { + public static CallbackServer createCallbackServer(int port) throws IOException { var callbackServer = new CallbackServerImplementation(port); callbackServer.start(); + return callbackServer; } public interface CallbackServer extends AutoCloseable { @@ -19,14 +22,59 @@ public interface CallbackServer extends AutoCloseable { private static final class CallbackServerImplementation implements CallbackServer { private final HttpServer server; + private final CompletableFuture callbackResult = new CompletableFuture<>(); private CallbackServerImplementation(int port) throws IOException { InetSocketAddress address = new InetSocketAddress("localhost", port); server = HttpServer.create(address, 0); + server.createContext("snowflake", exchange -> { + var query = exchange.getRequestURI().getQuery(); + byte[] response = OK_RESPONSE.getBytes(); + exchange.sendResponseHeaders(200, response.length); + exchange.getResponseBody().write(response); + exchange.close(); + callbackResult.complete(query); + }); } private void start() { server.start(); } + + @Override + public String waitForCallback() { + try { + return callbackResult.get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + } + + @Override + public void close() throws Exception { + server.stop(1); + } + + private static final String OK_RESPONSE = + """ + + + Enso - Snowflake integration + + + +

Enso - Snowflake integration

+

OAuth callback received. You can close this window now and go back to the application.

+ + + """; } } From 9dbbfe131860f180d58cea2290983ecb43e64482 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rados=C5=82aw=20Wa=C5=9Bko?= Date: Thu, 9 Jan 2025 20:27:04 +0100 Subject: [PATCH 03/17] followup --- build.sbt | 3 +- .../Base/0.0.0-dev/src/Network/HTTP.enso | 1 + .../Snowflake/0.0.0-dev/src/OAuth_Test.enso | 38 ++++++++++++++++--- .../org/enso/snowflake/OAuthCallback.java | 22 +++++++++-- 4 files changed, 53 insertions(+), 11 deletions(-) diff --git a/build.sbt b/build.sbt index acfd3e548f41..814e197d7fc7 100644 --- a/build.sbt +++ b/build.sbt @@ -4890,8 +4890,7 @@ lazy val `std-snowflake` = project `std-snowflake-polyglot-root` / "std-snowflake.jar", libraryDependencies ++= Seq( "org.netbeans.api" % "org-openide-util-lookup" % netbeansApiVersion % "provided", - "net.snowflake" % "snowflake-jdbc" % snowflakeJDBCVersion, - "com.sun.net.httpserver" % "http" % "20070405" + "net.snowflake" % "snowflake-jdbc" % snowflakeJDBCVersion ), Compile / packageBin := Def.task { val result = (Compile / packageBin).value diff --git a/distribution/lib/Standard/Base/0.0.0-dev/src/Network/HTTP.enso b/distribution/lib/Standard/Base/0.0.0-dev/src/Network/HTTP.enso index bee4ce27cc8b..3a597abc8c9a 100644 --- a/distribution/lib/Standard/Base/0.0.0-dev/src/Network/HTTP.enso +++ b/distribution/lib/Standard/Base/0.0.0-dev/src/Network/HTTP.enso @@ -154,6 +154,7 @@ type HTTP handle_request_error <| Illegal_Argument.handle_java_exception <| check_output_context <| check_cache_policy <| Response_Too_Large.handle_java_exception <| headers = _resolve_headers req + Standard.Base.IO.println headers headers.if_not_error <| resolved_body = _resolve_body req.body self.hash_method resolved_body.if_not_error <| diff --git a/distribution/lib/Standard/Snowflake/0.0.0-dev/src/OAuth_Test.enso b/distribution/lib/Standard/Snowflake/0.0.0-dev/src/OAuth_Test.enso index b52e6b8084cc..4888753730a6 100644 --- a/distribution/lib/Standard/Snowflake/0.0.0-dev/src/OAuth_Test.enso +++ b/distribution/lib/Standard/Snowflake/0.0.0-dev/src/OAuth_Test.enso @@ -1,4 +1,6 @@ from Standard.Base import all +import Standard.Base.Data.Base_64.Base_64 +import Standard.Base.Network.HTTP.Request_Body.Request_Body import Standard.Base.Runtime.Managed_Resource.Managed_Resource import project.Connection.Snowflake_Details.Snowflake_Details @@ -10,14 +12,23 @@ initiate_oauth -> Snowflake_Details = perform_oauth account:Text role:Text = uri = create_oauth_uri account role refresh_token=False - print_panic caugh_panic = - IO.println "Panic caught: "+caugh_panic.payload.to_text + print_panic caught_panic = + IO.println "Panic caught: "+caught_panic.payload.to_text + IO.println caught_panic.convert_to_dataflow_error.get_stack_trace_text Panic.catch Any handler=print_panic <| result = Managed_Resource.bracket (OAuthCallback.createCallbackServer 51234) (.close) server-> # We start the server first to ensure we can 'reserve' the port before opening the browser IO.println "Please open "+uri.to_text+" in your browser and follow the instructions." server.waitForCallback - IO.println "Got callback: "+result.to_text + code_prefix = "code=" + code = result + . split "&" + . find (e-> e.starts_with code_prefix) + . drop code_prefix.length + access_token = exchange_code_for_access_token account code + IO.println access_token + + create_oauth_uri account:Text role:Text refresh_token:Boolean -> URI = # TODO add code_challenge for PKCE @@ -27,7 +38,24 @@ create_oauth_uri account:Text role:Text refresh_token:Boolean -> URI = base_uri . add_query_argument "response_type" "code" . add_query_argument "client_id" client_id - . add_query_argument "redirect_uri" "http://localhost:51234/snowflake" + . add_query_argument "redirect_uri" redirect_uri . add_query_argument "scope" scope + . add_query_argument "state" "foobar" + +type Access_Token + Value token:Text expiry:Date_Time + +exchange_code_for_access_token account:Text code:Text -> Access_Token = + now = Date_Time.now + uri = URI.from "https://"+account+".snowflakecomputing.com/oauth/token-request" + params = Dictionary.from_vector [["grant_type", "authorization_code"], ["code", code], ["redirect_uri", redirect_uri]] + request_body = Request_Body.Form_Data params url_encoded=True + headers = [Header.authorization_basic client_id client_secret] + response = Data.post uri body=request_body headers=headers response_format=JSON_Format + IO.println response + Access_Token.Value (response.get "access_token") (now + Duration.new seconds=(response.get "expires_in")) + -client_id = "hBjdrf6WyFIyLMMt9Ojbk0c8eto=" +client_id = Environment.get "SNOWFLAKE_APP_CLIENT_ID" +client_secret = Environment.get "SNOWFLAKE_APP_CLIENT_SECRET" +redirect_uri = "http://localhost:51234/snowflake" diff --git a/std-bits/snowflake/src/main/java/org/enso/snowflake/OAuthCallback.java b/std-bits/snowflake/src/main/java/org/enso/snowflake/OAuthCallback.java index 2d0a00eb3161..061d04fbbf92 100644 --- a/std-bits/snowflake/src/main/java/org/enso/snowflake/OAuthCallback.java +++ b/std-bits/snowflake/src/main/java/org/enso/snowflake/OAuthCallback.java @@ -11,9 +11,14 @@ public final class OAuthCallback { private OAuthCallback() {} public static CallbackServer createCallbackServer(int port) throws IOException { - var callbackServer = new CallbackServerImplementation(port); - callbackServer.start(); - return callbackServer; + try { + var callbackServer = new CallbackServerImplementation(port); + callbackServer.start(); + return callbackServer; + } catch (Exception e) { + e.printStackTrace(); + throw e; + } } public interface CallbackServer extends AutoCloseable { @@ -27,8 +32,15 @@ private static final class CallbackServerImplementation implements CallbackServe private CallbackServerImplementation(int port) throws IOException { InetSocketAddress address = new InetSocketAddress("localhost", port); server = HttpServer.create(address, 0); - server.createContext("snowflake", exchange -> { + server.createContext("/snowflake", exchange -> { var query = exchange.getRequestURI().getQuery(); +// System.out.println("method = " + exchange.getRequestMethod()); +// System.out.println("query = " + query); +// System.out.println("headers = " + exchange.getRequestHeaders()); +// byte[] body = exchange.getRequestBody().readAllBytes(); +// System.out.println("body = " + new String(body)); + + byte[] response = OK_RESPONSE.getBytes(); exchange.sendResponseHeaders(200, response.length); exchange.getResponseBody().write(response); @@ -63,6 +75,7 @@ public void close() throws Exception {