66package microsoft // import "golang.org/x/oauth2/microsoft"
77
88import (
9+ "context"
10+ "crypto/sha1"
11+ "crypto/x509"
12+ "encoding/base64"
13+ "encoding/json"
14+ "encoding/pem"
15+ "fmt"
16+ "io"
17+ "io/ioutil"
18+ "net/http"
19+ "net/url"
20+ "strings"
21+ "time"
22+
923 "golang.org/x/oauth2"
24+ "golang.org/x/oauth2/internal"
25+ "golang.org/x/oauth2/jws"
1026)
1127
12- // LiveConnectEndpoint is Windows's Live ID OAuth 2.0 endpoint.
13- var LiveConnectEndpoint = oauth2.Endpoint {
14- AuthURL : "https://login.live.com/oauth20_authorize.srf" ,
15- TokenURL : "https://login.live.com/oauth20_token.srf" ,
16- }
17-
1828// AzureADEndpoint returns a new oauth2.Endpoint for the given tenant at Azure Active Directory.
1929// If tenant is empty, it uses the tenant called `common`.
2030//
@@ -29,3 +39,161 @@ func AzureADEndpoint(tenant string) oauth2.Endpoint {
2939 TokenURL : "https://login.microsoftonline.com/" + tenant + "/oauth2/v2.0/token" ,
3040 }
3141}
42+
43+ // Config is the configuration for using client credentials flow with a client assertion.
44+ //
45+ // For more information see:
46+ // https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-certificate-credentials
47+ type Config struct {
48+ // ClientID is the application's ID.
49+ ClientID string
50+
51+ // PrivateKey contains the contents of an RSA private key or the
52+ // contents of a PEM file that contains a private key. The provided
53+ // private key is used to sign JWT assertions.
54+ // PEM containers with a passphrase are not supported.
55+ // You can use pkcs12.Decode to extract the private key and certificate
56+ // from a PKCS #12 archive, or alternatively with OpenSSL:
57+ //
58+ // $ openssl pkcs12 -in key.p12 -out key.pem -nodes
59+ //
60+ PrivateKey []byte
61+
62+ // Certificate contains the (optionally PEM encoded) X509 certificate registered
63+ // for the application with which you are authenticating.
64+ Certificate []byte
65+
66+ // Scopes optionally specifies a list of requested permission scopes.
67+ Scopes []string
68+
69+ // TokenURL is the token endpoint. Typically you can use the AzureADEndpoint
70+ // function to obtain this value, but it may change for non-public clouds.
71+ TokenURL string
72+
73+ // Expires optionally specifies how long the token is valid for.
74+ Expires time.Duration
75+
76+ // Audience optionally specifies the intended audience of the
77+ // request. If empty, the value of TokenURL is used as the
78+ // intended audience.
79+ Audience string
80+ }
81+
82+ // TokenSource returns a TokenSource using the configuration
83+ // in c and the HTTP client from the provided context.
84+ func (c * Config ) TokenSource (ctx context.Context ) oauth2.TokenSource {
85+ return oauth2 .ReuseTokenSource (nil , assertionSource {ctx , c })
86+ }
87+
88+ // Client returns an HTTP client wrapping the context's
89+ // HTTP transport and adding Authorization headers with tokens
90+ // obtained from c.
91+ //
92+ // The returned client and its Transport should not be modified.
93+ func (c * Config ) Client (ctx context.Context ) * http.Client {
94+ return oauth2 .NewClient (ctx , c .TokenSource (ctx ))
95+ }
96+
97+ // assertionSource is a source that always does a signed request for a token.
98+ // It should typically be wrapped with a reuseTokenSource.
99+ type assertionSource struct {
100+ ctx context.Context
101+ conf * Config
102+ }
103+
104+ // Token refreshes the token by using a new client credentials request with signed assertion.
105+ func (a assertionSource ) Token () (* oauth2.Token , error ) {
106+ crt := a .conf .Certificate
107+ if der , _ := pem .Decode (a .conf .Certificate ); der != nil {
108+ crt = der .Bytes
109+ }
110+ cert , err := x509 .ParseCertificate (crt )
111+ if err != nil {
112+ return nil , fmt .Errorf ("oauth2: cannot parse certificate: %v" , err )
113+ }
114+ s := sha1 .Sum (cert .Raw )
115+ fp := base64 .URLEncoding .EncodeToString (s [:])
116+ h := jws.Header {
117+ Algorithm : "RS256" ,
118+ Typ : "JWT" ,
119+ KeyID : fp ,
120+ }
121+
122+ claimSet := & jws.ClaimSet {
123+ Iss : a .conf .ClientID ,
124+ Sub : a .conf .ClientID ,
125+ Aud : a .conf .TokenURL ,
126+ }
127+ if t := a .conf .Expires ; t > 0 {
128+ claimSet .Exp = time .Now ().Add (t ).Unix ()
129+ }
130+ if aud := a .conf .Audience ; aud != "" {
131+ claimSet .Aud = aud
132+ }
133+
134+ pk , err := internal .ParseKey (a .conf .PrivateKey )
135+ if err != nil {
136+ return nil , err
137+ }
138+
139+ payload , err := jws .Encode (& h , claimSet , pk )
140+ if err != nil {
141+ return nil , err
142+ }
143+
144+ hc := oauth2 .NewClient (a .ctx , nil )
145+ v := url.Values {
146+ "client_assertion" : {payload },
147+ "client_assertion_type" : {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer" },
148+ "client_id" : {a .conf .ClientID },
149+ "grant_type" : {"client_credentials" },
150+ "scope" : {strings .Join (a .conf .Scopes , " " )},
151+ }
152+ resp , err := hc .PostForm (a .conf .TokenURL , v )
153+ if err != nil {
154+ return nil , fmt .Errorf ("oauth2: cannot fetch token: %v" , err )
155+ }
156+
157+ defer resp .Body .Close ()
158+ body , err := ioutil .ReadAll (io .LimitReader (resp .Body , 1 << 20 ))
159+ if err != nil {
160+ return nil , fmt .Errorf ("oauth2: cannot fetch token: %v" , err )
161+ }
162+
163+ if c := resp .StatusCode ; c < 200 || c > 299 {
164+ return nil , & oauth2.RetrieveError {
165+ Response : resp ,
166+ Body : body ,
167+ }
168+ }
169+
170+ var tokenRes struct {
171+ AccessToken string `json:"access_token"`
172+ TokenType string `json:"token_type"`
173+ IDToken string `json:"id_token"`
174+ Scope string `json:"scope"`
175+ ExpiresIn int64 `json:"expires_in"` // relative seconds from now
176+ ExpiresOn int64 `json:"expires_on"` // timestamp
177+ }
178+ if err := json .Unmarshal (body , & tokenRes ); err != nil {
179+ return nil , fmt .Errorf ("oauth2: cannot fetch token: %v" , err )
180+ }
181+
182+ token := & oauth2.Token {
183+ AccessToken : tokenRes .AccessToken ,
184+ TokenType : tokenRes .TokenType ,
185+ }
186+ if secs := tokenRes .ExpiresIn ; secs > 0 {
187+ token .Expiry = time .Now ().Add (time .Duration (secs ) * time .Second )
188+ }
189+ if v := tokenRes .IDToken ; v != "" {
190+ // decode returned id token to get expiry
191+ claimSet , err := jws .Decode (v )
192+ if err != nil {
193+ return nil , fmt .Errorf ("oauth2: error decoding JWT token: %v" , err )
194+ }
195+ token .Expiry = time .Unix (claimSet .Exp , 0 )
196+ }
197+
198+ return token , nil
199+ }
0 commit comments