diff --git a/src/ring/middleware/oauth2.clj b/src/ring/middleware/oauth2.clj index f4a52f9..8539353 100644 --- a/src/ring/middleware/oauth2.clj +++ b/src/ring/middleware/oauth2.clj @@ -114,26 +114,29 @@ (defn- access-token-http-options [{:keys [access-token-uri client-id client-secret basic-auth?] - :or {basic-auth? false} :as profile} - request] + :or {basic-auth? false}} + form-params] (let [opts {:method :post :url access-token-uri :accept :json :as :json - :form-params (request-params profile request)}] + :form-params form-params}] (if basic-auth? (add-header-credentials opts client-id client-secret) (add-form-credentials opts client-id client-secret)))) (defn- get-access-token ([profile request] - (-> (http/request (access-token-http-options profile request)) + (-> (access-token-http-options profile (request-params profile request)) + http/request (format-access-token))) ([profile request respond raise] - (http/request (-> (access-token-http-options profile request) - (assoc :async? true)) - (comp respond format-access-token) - raise))) + (http/request + (-> (access-token-http-options profile + (request-params profile request)) + (assoc :async? true)) + (comp respond format-access-token) + raise))) (defn state-mismatch-handler ([_] @@ -188,33 +191,142 @@ (respond (redirect-response profile session token))) raise))))) -(defn- assoc-access-tokens [request] - (if-let [tokens (-> request :session ::access-tokens)] +(defn- get-expired + "Returns expired profile keys and refresh tokens in `access-tokens`." + [access-tokens] + (let [now (new Date)] + (for [[profile-key {:keys [expires refresh-token]}] access-tokens + :when (and expires refresh-token (.before expires now))] + {:profile-key profile-key :refresh-token refresh-token}))) + +(defn- update-tokens + "If `maybe-grant` is nil, removes `profile-key` from `access-token; otherwise + merges `profile-key` with `maybe-grant`." + [access-tokens [profile-key maybe-grant]] + (if maybe-grant + ;; `update ... merge` to properly handle case where authorization server + ;; does not update the refresh token after use and we should re-use the + ;; existing refresh token + (update access-tokens profile-key merge maybe-grant) + (dissoc access-tokens profile-key))) + +(def socket-timeout 60000) + +(defn- refresh-one-token + ([profile refresh-token] + (-> (access-token-http-options + profile + {:grant_type "refresh_token" :refresh_token refresh-token}) + (assoc :socket-timeout socket-timeout) + http/request + format-access-token)) + ([profile refresh-token respond raise] + (-> (access-token-http-options + profile + {:grant_type "refresh_token" + :refresh_token refresh-token}) + (assoc :async? true + :socket-timeout socket-timeout) + (http/request (comp respond format-access-token) raise)))) + +(defn- valid-token? [token] + (and token (string? token) (not (str/blank? token)))) + +(defn- refresh-all-tokens + "Refreshes all expired tokens, yielding an updated map of tokens" + ([profiles access-tokens] + (let [refresh-results + (for [{:keys [profile-key refresh-token]} (get-expired access-tokens) + :let [profile (profile-key profiles)] + :when (and profile (valid-token? refresh-token))] + [profile-key + (try (refresh-one-token profile refresh-token) + (catch clojure.lang.ExceptionInfo _ + nil))])] + (reduce update-tokens access-tokens refresh-results))) + ([profiles access-tokens respond] + ;; strategy: launch all requests concurrently, keeping track of completed + ;; requests in `results`. When all requests have finished, respond. + (let [expired (get-expired access-tokens) + total (count expired) + results (atom {}) ;; map from profile-key to result + respond-when-done! #(when (= (count @results) total) + (respond (reduce update-tokens access-tokens @results)))] + (if (zero? total) + (respond access-tokens) + (doseq [{:keys [profile-key refresh-token]} expired + :let [profile (profile-key profiles)] + :when (and profile (valid-token? refresh-token))] + (refresh-one-token profile refresh-token + (fn [refresh-result] + (swap! results assoc profile-key refresh-result) + (respond-when-done!)) + (fn [_] + (swap! results assoc profile-key nil) + (respond-when-done!)))))))) + +(defn- assoc-access-tokens-in-request [request tokens] + (if tokens (assoc request :oauth2/access-tokens tokens) request)) +(defn- assoc-access-tokens-in-response + "If any tokens are present, adds to them the `:session` key of `response`." + [response tokens] + (if tokens + (assoc-in response [:session ::access-tokens] tokens) + response)) + (defn- parse-redirect-url [{:keys [redirect-uri]}] (.getPath (java.net.URI. redirect-uri))) (defn- valid-profile? [{:keys [client-id client-secret]}] (and (some? client-id) (some? client-secret))) -(defn wrap-oauth2 [handler profiles] +(defn wrap-oauth2 + "Middleware that handles OAuth2 authentication flows. + + Parameters: + * `handler`: The downstream ring handler + * `profiles`: A map of profiles + + Each request URI is matched against the profiles to determine the appropriate + OAuth2 flow handler. If no match is found, the request is passed to the + downstream handler with existing access tokens added to the request under the + `:oauth2/access-tokens` key. + + Expired tokens are refreshed using their refresh-token if possible. If refresh + fails, the access token is removed." + [handler profiles] {:pre [(every? valid-profile? (vals profiles))]} - (let [profiles (for [[k v] profiles] (assoc v :id k)) - launches (into {} (map (juxt :launch-uri identity)) profiles) - redirects (into {} (map (juxt parse-redirect-url identity)) profiles)] + (let [id-profiles (for [[k v] profiles] (assoc v :id k)) + launches (into {} (map (juxt :launch-uri identity)) id-profiles) + redirects (into {} (map (juxt parse-redirect-url identity)) id-profiles)] (fn ([{:keys [uri] :as request}] (if-let [profile (launches uri)] ((make-launch-handler profile) request) (if-let [profile (redirects uri)] ((:redirect-handler profile (make-redirect-handler profile)) request) - (handler (assoc-access-tokens request))))) + (let [access-tokens (get-in request [:session ::access-tokens]) + refreshed-tokens (refresh-all-tokens profiles access-tokens)] + (-> request + (assoc-access-tokens-in-request refreshed-tokens) + handler + (assoc-access-tokens-in-response refreshed-tokens)))))) ([{:keys [uri] :as request} respond raise] (if-let [profile (launches uri)] ((make-launch-handler profile) request respond raise) (if-let [profile (redirects uri)] ((:redirect-handler profile (make-redirect-handler profile)) request respond raise) - (handler (assoc-access-tokens request) respond raise))))))) + (let [access-tokens (get-in request [:session ::access-tokens]) + respond (fn [refreshed-tokens] + (handler + (assoc-access-tokens-in-request + request refreshed-tokens) + (comp respond + #(assoc-access-tokens-in-response + % refreshed-tokens)) + raise))] + (refresh-all-tokens profiles access-tokens respond)))))))) diff --git a/test/ring/middleware/oauth2_test.clj b/test/ring/middleware/oauth2_test.clj index 92151f1..652e72c 100644 --- a/test/ring/middleware/oauth2_test.clj +++ b/test/ring/middleware/oauth2_test.clj @@ -109,8 +109,9 @@ b-ms (.getTime b)] (< (- a-ms 1000) b-ms (+ a-ms 1000)))) -(defn- seconds-from-now-to-date [secs] - (-> (Instant/now) (.plusSeconds secs) (Date/from))) +(defn- seconds-from-now-to-date + ([now secs] (-> now (.plusSeconds secs) (Date/from))) + ([secs] (seconds-from-now-to-date (Instant/now) secs))) (deftest test-redirect-uri (fake/with-fake-routes @@ -203,7 +204,10 @@ (deftest test-access-tokens-key (let [tokens {:test {:token "defdef", :expires 3600}}] - (is (= {:status 200, :headers {}, :body tokens} + (is (= {:status 200, + :headers {}, + :body tokens, + :session {::oauth2/access-tokens tokens}} (-> (mock/request :get "/") (assoc :session {::oauth2/access-tokens tokens}) (test-handler)))))) @@ -376,10 +380,11 @@ (deftest test-handler-passthrough (let [tokens {:test "tttkkkk"} + session {::oauth2/access-tokens tokens} request (-> (mock/request :get "/example") - (assoc :session {::oauth2/access-tokens tokens}))] + (assoc :session session))] (testing "sync handler" - (is (= {:status 200, :headers {}, :body tokens} + (is (= {:status 200, :headers {}, :body tokens :session session} (test-handler request)))) (testing "async handler" @@ -388,5 +393,93 @@ (test-handler request respond raise) (is (= :empty (deref raise 100 :empty))) - (is (= {:status 200, :headers {}, :body tokens} + (is (= {:status 200, :headers {}, :body tokens :session session} (deref respond 100 :empty))))))) + +(def refresh-token-response + {:status 200 + :headers {"Content-Type" "application/json"} + :body "{\"access_token\":\"newtoken\",\"expires_in\":3600, + \"refresh_token\":\"newrefresh\",\"foo\":\"bar\"}"}) + +(deftest test-token-refresh-success + (fake/with-fake-routes + {"https://example.com/oauth2/access-token" + (fn [req] + (let [params (codec/form-decode (slurp (:body req)))] + (is (= "refresh_token" (get params "grant_type"))) + (is (= "oldrefresh" (get params "refresh_token"))) + refresh-token-response))} + + (let [now (Instant/now) + old-expires (seconds-from-now-to-date now -60) + new-expires (seconds-from-now-to-date now 3600) + new-token {:token "newtoken" + :refresh-token "newrefresh" + :extra-data {:foo "bar"}} + request (-> (mock/request :get "/") + (assoc :session + {::oauth2/access-tokens + {:test {:token "oldtoken" + :refresh-token "oldrefresh" + :expires old-expires}}}))] + (testing "sync refresh" + (let [response (test-handler request)] + (is (= 200 (:status response))) + ;; then handler has new token + (is (= new-token (dissoc (get-in response [:body :test]) :expires))) + (is (approx-eq new-expires (get-in response [:body :test :expires]))) + ;; and the user's session is updated + (is (= new-token (dissoc (get-in response [:session ::oauth2/access-tokens :test]) + :expires))))) + (testing "async refresh" + (let [respond (promise) + raise (promise)] + (test-handler request respond raise) + (is (= :empty (deref raise 100 :empty))) + (let [response (deref respond 100 :empty)] + ;; then handler has new token + (is (not= response :empty)) + (is (= new-token (dissoc (get-in response [:body :test]) :expires))) + ;; user session is updated + (is (= new-token + (dissoc (get-in response [:session ::oauth2/access-tokens + :test]) + :expires))))))))) + +(def refresh-token-error-response + {:headers {"content-type" "application/json"}, + :status 400, + :body "{\"error\": \"invalid_grant\"}"}) + +(deftest test-token-refresh-failure + (fake/with-fake-routes + {"https://example.com/oauth2/access-token" + (constantly refresh-token-error-response)} + + ;; setup a session with two grants, where one grant is expired and which + ;; will error on refresh + (let [profiles {:test-0 test-profile :test-1 test-profile} + handler (wrap-oauth2 token-handler profiles) + good-grant {:token "good-token" + :refresh-token "refresh-token" + :expires (seconds-from-now-to-date 3600)} + expired-grant {:token "expired-token" + :refresh-token "invalid" + :expires (seconds-from-now-to-date -60)} + request (-> (mock/request :get "/") + (assoc :session + {::oauth2/access-tokens + {:test-0 expired-grant + :test-1 good-grant}}))] + (testing "sync handler" + (let [response (handler request)] + (is (= {:test-1 good-grant} + (:body response))))) + (testing "async refresh" + (let [respond (promise) + raise (promise)] + (handler request respond raise) + (is (= :empty (deref raise 100 :empty))) + (let [response (deref respond 100 :empty)] + (is (= {:test-1 good-grant} (:body response)))))))))