@@ -11,44 +11,33 @@ import (
11
11
"net/http"
12
12
"net/http/httputil"
13
13
"os"
14
- "strings"
15
14
16
15
"github.com/caddyserver/certmagic"
17
16
"github.com/google/go-sev-guest/abi"
18
17
"github.com/google/go-sev-guest/client"
19
18
log "github.com/sirupsen/logrus"
19
+ "gopkg.in/yaml.v3"
20
20
21
21
"github.com/tinfoilanalytics/verifier/pkg/attestation"
22
22
)
23
23
24
24
var version = "dev"
25
25
26
- var (
27
- listenAddr = flag .String ("l" , ":443" , "listen address" )
28
- staging = flag .Bool ("s" , false , "use staging CA" )
29
- upstream = flag .Int ("u" , 8080 , "upstream port" )
30
- allowedPaths = flag .String ("p" , "" , "Paths to proxy to the upstream server (all if empty)" )
31
- certCache = flag .String ("c" , "/mnt/ramdisk/certs" , "certificate cache directory" )
32
- verbose = flag .Bool ("v" , false , "verbose logging" )
33
-
34
-
35
- )
36
-
37
- // cmdlineParam returns the value of a parameter from the kernel command line
38
- func cmdlineParam (key string ) (string , error ) {
39
- cmdline , err := os .ReadFile ("/proc/cmdline" )
40
- if err != nil {
41
- return "" , err
42
- }
26
+ var config struct {
27
+ Domain string `yaml:"domain"`
28
+ ListenPort int `yaml:"listen-port"`
29
+ UpstreamPort int `yaml:"upstream-port"`
30
+ Paths []string `yaml:"paths"`
31
+ StagingCA bool `yaml:"staging-ca"`
32
+ Verbose bool `yaml:"verbose"`
33
+ }
43
34
44
- for _ , p := range strings .Split (string (cmdline ), " " ) {
45
- if strings .HasPrefix (p , key + "=" ) {
46
- return strings .TrimPrefix (p , key + "=" ), nil
47
- }
48
- }
35
+ var (
36
+ configFile = flag .String ("c" , "/mnt/ramdisk/shim.yml" , "Path to config file" )
49
37
50
- return "" , fmt .Errorf ("missing %s" , key )
51
- }
38
+
39
+ certCache = "/mnt/ramdisk/certs"
40
+ )
52
41
53
42
// attestationReport gets a SEV-SNP signed attestation report over a TLS certificate fingerprint
54
43
func attestationReport (certFP string ) (* attestation.Document , error ) {
@@ -90,36 +79,39 @@ func cors(w http.ResponseWriter, r *http.Request) {
90
79
91
80
func main () {
92
81
flag .Parse ()
93
- if * verbose {
94
- log .SetLevel (log .DebugLevel )
95
- }
96
82
97
- domain , err := cmdlineParam ( "tinfoil-domain" )
83
+ configBytes , err := os . ReadFile ( * configFile )
98
84
if err != nil {
99
- log .Fatal (err )
85
+ log .Fatalf ("Failed to read config file: %v" , err )
86
+ }
87
+ if err := yaml .Unmarshal (configBytes , & config ); err != nil {
88
+ log .Fatalf ("Failed to unmarshal config: %v" , err )
89
+ }
90
+
91
+ if config .Verbose {
92
+ log .SetLevel (log .DebugLevel )
100
93
}
101
94
102
- paths := strings .Split (* allowedPaths , "," )
103
- log .Printf ("Starting SEV-SNP attestation shim %s domain %s paths %s" , version , domain , paths )
95
+ log .Printf ("Starting SEV-SNP attestation shim %s domain %s paths %s" , version , config .Domain , config .Paths )
104
96
105
97
mux := http .NewServeMux ()
106
98
107
99
// Request TLS certificate
108
- certmagic .Default .Storage = & certmagic.FileStorage {Path : * certCache }
100
+ certmagic .Default .Storage = & certmagic.FileStorage {Path : certCache }
109
101
certmagic .DefaultACME .Email = email
110
- if * staging {
102
+ if config . StagingCA {
111
103
certmagic .DefaultACME .CA = certmagic .LetsEncryptStagingCA
112
104
} else {
113
105
certmagic .DefaultACME .CA = certmagic .LetsEncryptProductionCA
114
106
}
115
- tlsConfig , err := certmagic .TLS ([]string {domain })
107
+ tlsConfig , err := certmagic .TLS ([]string {config . Domain })
116
108
if err != nil {
117
109
log .Fatalf ("Failed to get TLS config: %v" , err )
118
110
}
119
111
120
112
// Get certificate from TLS config
121
113
cert , err := tlsConfig .GetCertificate (& tls.ClientHelloInfo {
122
- ServerName : domain ,
114
+ ServerName : config . Domain ,
123
115
})
124
116
if err != nil {
125
117
log .Fatalf ("Failed to get certificate: %v" , err )
@@ -137,9 +129,9 @@ func main() {
137
129
mux .HandleFunc ("/" , func (w http.ResponseWriter , r * http.Request ) {
138
130
// cors(w, r)
139
131
140
- if len (paths ) > 0 {
132
+ if len (config . Paths ) > 0 {
141
133
allowed := false
142
- for _ , path := range paths {
134
+ for _ , path := range config . Paths {
143
135
if r .URL .Path == path {
144
136
allowed = true
145
137
break
@@ -155,7 +147,7 @@ func main() {
155
147
Director : func (req * http.Request ) {
156
148
log .Debugf ("Orig to %+v" , req .Header )
157
149
req .URL .Scheme = "http"
158
- req .URL .Host = fmt .Sprintf ("127.0.0.1:%d" , * upstream )
150
+ req .URL .Host = fmt .Sprintf ("127.0.0.1:%d" , config . UpstreamPort )
159
151
req .Header .Set ("Host" , "localhost" )
160
152
req .Host = "localhost"
161
153
log .Debugf ("Proxying request to %+v" , req .URL .String ())
@@ -171,11 +163,12 @@ func main() {
171
163
json .NewEncoder (w ).Encode (att )
172
164
})
173
165
166
+ listenAddr := fmt .Sprintf (":%d" , config .ListenPort )
174
167
httpServer := & http.Server {
175
- Addr : * listenAddr ,
168
+ Addr : listenAddr ,
176
169
Handler : mux ,
177
170
TLSConfig : tlsConfig ,
178
171
}
179
- log .Printf ("Listening on %s" , * listenAddr )
172
+ log .Printf ("Listening on %s" , listenAddr )
180
173
log .Fatal (httpServer .ListenAndServeTLS ("" , "" ))
181
174
}
0 commit comments