-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathauth.go
132 lines (118 loc) · 2.88 KB
/
auth.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
package main
import (
"bufio"
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"sync"
"golang.org/x/oauth2"
)
// CachingTokenSource is an oauth2.TokenSource that persists its token as JSON
// in the file at Path. Its Token method returns the token stored in that file
// when the file exists; otherwise it fetches a token from the embedded
// TokenSource and writes it to Path.
type CachingTokenSource struct {
	// TokenSource is consulted only when no cache file exists at Path.
	oauth2.TokenSource
	// Path is the file the token is read from and written to.
	Path string
}
// Token returns the token cached at cts.Path when that file exists.
// Only when the file does not exist does it fetch a token from the embedded
// TokenSource and write it to cts.Path.
//
// The cached token is returned as-is: its expiry is not checked here. Any
// other failure to read the cache (permissions, malformed JSON, ...) is
// returned rather than falling back to the embedded TokenSource.
//
// If a freshly fetched token cannot be written to the cache, that token is
// still returned alongside the non-nil error, so callers that can tolerate
// an unwritten cache may keep using it.
func (cts *CachingTokenSource) Token() (*oauth2.Token, error) {
	token, err := cts.get()
	if err == nil {
		return token, nil
	}
	if !errors.Is(err, os.ErrNotExist) {
		return nil, fmt.Errorf("load cached token: %w", err)
	}

	token, err = cts.TokenSource.Token()
	if err != nil {
		return nil, err
	}
	if err := cts.put(token); err != nil {
		return token, fmt.Errorf("cache token: %w", err)
	}
	return token, nil
}
// get reads and decodes the JSON token stored at cts.Path.
//
// Open errors are returned unwrapped, so a missing file still satisfies
// errors.Is(err, os.ErrNotExist). On any error the returned token is nil.
func (cts *CachingTokenSource) get() (*oauth2.Token, error) {
	f, err := os.Open(cts.Path)
	if err != nil {
		return nil, err
	}
	// Read-only handle: a Close error cannot lose data.
	defer f.Close()

	// json.Decoder buffers its input internally, so f is read directly.
	var token oauth2.Token
	if err := json.NewDecoder(f).Decode(&token); err != nil {
		return nil, fmt.Errorf("decode %s: %w", cts.Path, err)
	}
	return &token, nil
}
// put writes token to cts.Path as JSON, truncating any previous contents.
//
// A newly created file gets mode 0600 because it holds OAuth2 credentials;
// the mode of an already existing file is left unchanged. Encode, flush and
// close errors are all reported, so a partially written cache is never
// reported as success.
func (cts *CachingTokenSource) put(token *oauth2.Token) error {
	f, err := os.OpenFile(cts.Path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
	if err != nil {
		return err
	}
	w := bufio.NewWriter(f)
	if err := json.NewEncoder(w).Encode(token); err != nil {
		_ = f.Close() // the encode error is the one worth reporting
		return err
	}
	if err := w.Flush(); err != nil {
		_ = f.Close() // the flush error is the one worth reporting
		return err
	}
	// Close can surface delayed write errors on some filesystems.
	return f.Close()
}
// LocalServerTokenSource is an oauth2.TokenSource that runs the interactive
// authorization-code flow: Token prints the consent URL and starts a local
// HTTP server to receive the provider's redirect.
//
// NOTE(review): Token hard-codes the server address ":4000", so
// Config.RedirectURL presumably needs to target that port — confirm.
type LocalServerTokenSource struct {
	// Config is used to build the consent URL and to exchange the
	// authorization code for a token.
	Config oauth2.Config
}
// Token runs the interactive OAuth2 authorization-code flow:
//   - it prints the consent URL (requesting offline access);
//   - it waits for the browser redirect on a local server at ":4000";
//   - it checks that the returned state matches a freshly generated CSRF value;
//   - it exchanges the authorization code for a token.
//
// Errors from state generation and from waiting for the callback are wrapped
// with %w so callers can inspect the underlying cause.
func (p *LocalServerTokenSource) Token() (*oauth2.Token, error) {
	ctx := context.Background()
	state, err := newState()
	if err != nil {
		return nil, fmt.Errorf("generate csrf token: %w", err)
	}
	authURL := p.Config.AuthCodeURL(state, oauth2.AccessTypeOffline)
	fmt.Printf("open this URL in the browser to authenticate.\n\n%s\n", authURL)
	resp, err := waitForCallback(":4000", state)
	if err != nil {
		return nil, fmt.Errorf("wait for callback: %w", err)
	}
	if resp.State != state {
		return nil, errors.New("callback state mismatch")
	}
	// An OAuth2 error redirect (e.g. the user denied consent) carries no code.
	// Fail here with a clear message instead of sending an empty code to the
	// token endpoint.
	if resp.Code == "" {
		return nil, errors.New("callback did not include an authorization code")
	}
	return p.Config.Exchange(ctx, resp.Code, oauth2.AccessTypeOffline)
}
// newState returns a fresh value for the OAuth2 "state" parameter: 24 bytes
// from crypto/rand encoded with URL-safe base64. Any error from the random
// source is returned unchanged together with an empty string.
func newState() (string, error) {
	var raw [24]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return base64.URLEncoding.EncodeToString(raw[:]), nil
}
// callbackResponse holds the "code" and "state" form values of the first
// request received by waitForCallback. A field is empty when the request did
// not carry that value.
type callbackResponse struct {
	// Code is the authorization code that Token exchanges for a token.
	Code string
	// State is compared by the caller against the state it generated.
	State string
}
// waitForCallback starts an HTTP server on addr and blocks until the first
// request arrives. It returns that request's "code" and "state" form values
// and then shuts the server down; a Shutdown error is returned alongside the
// values.
//
// If the server cannot start (for example because addr is already in use),
// the listen error is returned instead of crashing the process. Previously a
// panic in the serving goroutine could not be recovered here.
//
// csrfToken is currently unused; the caller compares the returned State
// against the value it generated.
func waitForCallback(addr, csrfToken string) (resp callbackResponse, err error) {
	// Buffered so the handler never blocks on the hand-off, even if this
	// function has already returned because the server failed.
	c := make(chan callbackResponse, 1)
	// Buffered so the serving goroutine can always report its failure and exit.
	serveErr := make(chan error, 1)
	var once sync.Once
	server := &http.Server{
		Addr: addr,
		Handler: http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
			code := req.FormValue("code")
			state := req.FormValue("state")
			once.Do(func() {
				c <- callbackResponse{Code: code, State: state}
			})
			// Best effort: the values are already captured, so a failed write
			// only affects what the browser displays.
			_, _ = res.Write([]byte("✅ Go back to your terminal."))
		}),
	}
	go func() {
		if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
			serveErr <- err
		}
	}()
	// TODO: Add a timeout
	select {
	case resp = <-c:
	case err = <-serveErr:
		return callbackResponse{}, fmt.Errorf("start callback server: %w", err)
	}
	// Shutdown waits for the in-flight handler to finish writing its response.
	err = server.Shutdown(context.TODO())
	return resp, err
}