diff --git a/client/client.go b/client/client.go index f41fd6a41..052ba0d40 100644 --- a/client/client.go +++ b/client/client.go @@ -117,11 +117,14 @@ type ApiClient struct { http.Client } -// Return a new ApiRequest sharing this ApiClient helper -func (a *ApiClient) Request(code AuthToken) *ApiRequest { +type ClientReauthorizeFunc func() (AuthToken, error) + +// Return a new ApiRequest +func (a *ApiClient) Request(code AuthToken, req ClientReauthorizeFunc) *ApiRequest { return &ApiRequest{ - api: a, - auth: code, + api: a, + auth: code, + revoke: req, } } @@ -131,14 +134,36 @@ func (a *ApiClient) Request(code AuthToken) *ApiRequest { type ApiRequest struct { api *ApiClient // authorization code to use for requests - auth AuthToken + auth AuthToken + revoke ClientReauthorizeFunc } func (ar *ApiRequest) Do(req *http.Request) (*http.Response, error) { if req.Header.Get("Authorization") == "" { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ar.auth)) } - return ar.api.Do(req) + r, err := ar.api.Do(req) + if r != nil && r.StatusCode == http.StatusUnauthorized { + // invalid JWT; most likely the token is expired: + // Try to refresh it and reattempt sending the request + log.Info("Device unauthorized; attempting reauthorization") + if jwt, e := ar.revoke(); e == nil { + // retry API request with new JWT token + ar.auth = jwt + // check if request had a body + // (GetBody is optional, and nil if body is empty) + if req.GetBody != nil { + if body, e := req.GetBody(); e == nil { + req.Body = body + } + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ar.auth)) + r, err = ar.api.Do(req) + } else { + log.Warnf("Reauthorization failed with error: %s", e.Error()) + } + } + return r, err } func NewApiClient(conf Config) (*ApiClient, error) { diff --git a/client/client_test.go b/client/client_test.go index 51778bdb9..cf8c6bd60 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -16,6 +16,7 @@ package client import ( "bytes" "encoding/json" + "errors" "io/ioutil" "net" "net/http" @@ -30,6 +31,9 @@ import ( "github.com/stretchr/testify/require" ) +func dummy() (AuthToken, error) { + return AuthToken(""), errors.New("") +} func TestHttpClient(t *testing.T) { cl, err := NewApiClient( Config{"server.crt", true, false}, @@ -54,7 +58,7 @@ func TestApiClientRequest(t *testing.T) { ) assert.NotNil(t, cl) - req := cl.Request("foobar") + req := cl.Request("foobar", dummy) assert.NotNil(t, req) responder := &struct { @@ -111,7 +115,7 @@ func TestClientConnectionTimeout(t *testing.T) { assert.NotNil(t, cl) assert.NoError(t, err) - req := cl.Request("foobar") + req := cl.Request("foobar", dummy) assert.NotNil(t, req) hreq, err := http.NewRequest(http.MethodGet, ts.URL, nil) diff --git a/mender.go b/mender.go index d5f269aeb..e53bbdd16 100644 --- a/mender.go +++ b/mender.go @@ -464,7 +464,7 @@ func (m *mender) CheckUpdate() (*client.UpdateResponse, menderError) { if err != nil { log.Errorf("Unable to verify the existing hardware. Update will continue anyways: %v : %v", defaultDeviceTypeFile, err) } - haveUpdate, err := m.updater.GetScheduledUpdate(m.api.Request(m.authToken), + haveUpdate, err := m.updater.GetScheduledUpdate(m.api.Request(m.authToken, reauthorize(m)), m.config.ServerURL, client.CurrentUpdate{ Artifact: currentArtifactName, DeviceType: deviceType, @@ -477,6 +477,7 @@ func (m *mender) CheckUpdate() (*client.UpdateResponse, menderError) { if remErr := m.authMgr.RemoveAuthToken(); remErr != nil { log.Warn("can not remove rejected authentication token") } + } log.Error("Error receiving scheduled update data: ", err) return nil, NewTransientError(err) @@ -502,7 +503,7 @@ func (m *mender) CheckUpdate() (*client.UpdateResponse, menderError) { func (m *mender) ReportUpdateStatus(update client.UpdateResponse, status string) menderError { s := client.NewStatus() - err := s.Report(m.api.Request(m.authToken), m.config.ServerURL, + err := s.Report(m.api.Request(m.authToken, reauthorize(m)), m.config.ServerURL, client.StatusReport{ DeploymentID: update.ID, Status: status, @@ -516,9 +517,7 @@ func (m *mender) ReportUpdateStatus(update client.UpdateResponse, status string) if remErr := m.authMgr.RemoveAuthToken(); remErr != nil { log.Warn("can not remove rejected authentication token") } - } - - if errCause == client.ErrDeploymentAborted { + } else if errCause == client.ErrDeploymentAborted { return NewFatalError(err) } return NewTransientError(err) @@ -526,9 +525,23 @@ func (m *mender) ReportUpdateStatus(update client.UpdateResponse, status string) return nil } +func reauthorize(m *mender) func() (client.AuthToken, error) { + // force reauthorization + return func() (client.AuthToken, error) { + // assume token is invalid - remove from storage + if err := m.authMgr.RemoveAuthToken(); err != nil { + return noAuthToken, errors.New("Failed to remove auth token") + } + if err := m.Authorize(); err != nil { + return noAuthToken, err + } + return m.authMgr.AuthToken() + } +} + func (m *mender) UploadLog(update client.UpdateResponse, logs []byte) menderError { s := client.NewLog() - err := s.Upload(m.api.Request(m.authToken), m.config.ServerURL, + err := s.Upload(m.api.Request(m.authToken, reauthorize(m)), m.config.ServerURL, client.LogData{ DeploymentID: update.ID, Messages: logs, @@ -649,7 +662,7 @@ func (m *mender) TransitionState(to State, ctx *StateContext) (State, bool) { log.Error(err) } else { report = &client.StatusReportWrapper{ - API: m.api.Request(m.authToken), + API: m.api.Request(m.authToken, reauthorize(m)), URL: m.config.ServerURL, Report: client.StatusReport{ DeploymentID: upd.ID, @@ -733,7 +746,7 @@ func (m *mender) InventoryRefresh() error { return nil } - err = ic.Submit(m.api.Request(m.authToken), m.config.ServerURL, idata) + err = ic.Submit(m.api.Request(m.authToken, reauthorize(m)), m.config.ServerURL, idata) if err != nil { return errors.Wrapf(err, "failed to submit inventory data") } diff --git a/mender_test.go b/mender_test.go index 4df11a867..553dfbbc5 100644 --- a/mender_test.go +++ b/mender_test.go @@ -537,6 +537,7 @@ func TestMenderReportStatus(t *testing.T) { // 3. pretend that deployment was aborted srv.Reset() + srv.Auth.Authorize = true srv.Auth.Token = []byte("tokendata") srv.Auth.Verify = true srv.Status.Aborted = true