Skip to content

Commit d48c3ef

Browse files
committed
feat: use config file
1 parent bcb3ce5 commit d48c3ef

File tree

1 file changed

+34
-41
lines changed

1 file changed

+34
-41
lines changed

Diff for: main.go

+34-41
Original file line numberDiff line numberDiff line change
@@ -11,44 +11,33 @@ import (
1111
"net/http"
1212
"net/http/httputil"
1313
"os"
14-
"strings"
1514

1615
"github.com/caddyserver/certmagic"
1716
"github.com/google/go-sev-guest/abi"
1817
"github.com/google/go-sev-guest/client"
1918
log "github.com/sirupsen/logrus"
19+
"gopkg.in/yaml.v3"
2020

2121
"github.com/tinfoilanalytics/verifier/pkg/attestation"
2222
)
2323

2424
var version = "dev"
2525

26-
var (
27-
listenAddr = flag.String("l", ":443", "listen address")
28-
staging = flag.Bool("s", false, "use staging CA")
29-
upstream = flag.Int("u", 8080, "upstream port")
30-
allowedPaths = flag.String("p", "", "Paths to proxy to the upstream server (all if empty)")
31-
certCache = flag.String("c", "/mnt/ramdisk/certs", "certificate cache directory")
32-
verbose = flag.Bool("v", false, "verbose logging")
33-
34-
35-
)
36-
37-
// cmdlineParam returns the value of a parameter from the kernel command line
38-
func cmdlineParam(key string) (string, error) {
39-
cmdline, err := os.ReadFile("/proc/cmdline")
40-
if err != nil {
41-
return "", err
42-
}
26+
var config struct {
27+
Domain string `yaml:"domain"`
28+
ListenPort int `yaml:"listen-port"`
29+
UpstreamPort int `yaml:"upstream-port"`
30+
Paths []string `yaml:"paths"`
31+
StagingCA bool `yaml:"staging-ca"`
32+
Verbose bool `yaml:"verbose"`
33+
}
4334

44-
for _, p := range strings.Split(string(cmdline), " ") {
45-
if strings.HasPrefix(p, key+"=") {
46-
return strings.TrimPrefix(p, key+"="), nil
47-
}
48-
}
35+
var (
36+
configFile = flag.String("c", "/mnt/ramdisk/shim.yml", "Path to config file")
4937

50-
return "", fmt.Errorf("missing %s", key)
51-
}
38+
39+
certCache = "/mnt/ramdisk/certs"
40+
)
5241

5342
// attestationReport gets a SEV-SNP signed attestation report over a TLS certificate fingerprint
5443
func attestationReport(certFP string) (*attestation.Document, error) {
@@ -90,36 +79,39 @@ func cors(w http.ResponseWriter, r *http.Request) {
9079

9180
func main() {
9281
flag.Parse()
93-
if *verbose {
94-
log.SetLevel(log.DebugLevel)
95-
}
9682

97-
domain, err := cmdlineParam("tinfoil-domain")
83+
configBytes, err := os.ReadFile(*configFile)
9884
if err != nil {
99-
log.Fatal(err)
85+
log.Fatalf("Failed to read config file: %v", err)
86+
}
87+
if err := yaml.Unmarshal(configBytes, &config); err != nil {
88+
log.Fatalf("Failed to unmarshal config: %v", err)
89+
}
90+
91+
if config.Verbose {
92+
log.SetLevel(log.DebugLevel)
10093
}
10194

102-
paths := strings.Split(*allowedPaths, ",")
103-
log.Printf("Starting SEV-SNP attestation shim %s domain %s paths %s", version, domain, paths)
95+
log.Printf("Starting SEV-SNP attestation shim %s domain %s paths %s", version, config.Domain, config.Paths)
10496

10597
mux := http.NewServeMux()
10698

10799
// Request TLS certificate
108-
certmagic.Default.Storage = &certmagic.FileStorage{Path: *certCache}
100+
certmagic.Default.Storage = &certmagic.FileStorage{Path: certCache}
109101
certmagic.DefaultACME.Email = email
110-
if *staging {
102+
if config.StagingCA {
111103
certmagic.DefaultACME.CA = certmagic.LetsEncryptStagingCA
112104
} else {
113105
certmagic.DefaultACME.CA = certmagic.LetsEncryptProductionCA
114106
}
115-
tlsConfig, err := certmagic.TLS([]string{domain})
107+
tlsConfig, err := certmagic.TLS([]string{config.Domain})
116108
if err != nil {
117109
log.Fatalf("Failed to get TLS config: %v", err)
118110
}
119111

120112
// Get certificate from TLS config
121113
cert, err := tlsConfig.GetCertificate(&tls.ClientHelloInfo{
122-
ServerName: domain,
114+
ServerName: config.Domain,
123115
})
124116
if err != nil {
125117
log.Fatalf("Failed to get certificate: %v", err)
@@ -137,9 +129,9 @@ func main() {
137129
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
138130
// cors(w, r)
139131

140-
if len(paths) > 0 {
132+
if len(config.Paths) > 0 {
141133
allowed := false
142-
for _, path := range paths {
134+
for _, path := range config.Paths {
143135
if r.URL.Path == path {
144136
allowed = true
145137
break
@@ -155,7 +147,7 @@ func main() {
155147
Director: func(req *http.Request) {
156148
log.Debugf("Orig to %+v", req.Header)
157149
req.URL.Scheme = "http"
158-
req.URL.Host = fmt.Sprintf("127.0.0.1:%d", *upstream)
150+
req.URL.Host = fmt.Sprintf("127.0.0.1:%d", config.UpstreamPort)
159151
req.Header.Set("Host", "localhost")
160152
req.Host = "localhost"
161153
log.Debugf("Proxying request to %+v", req.URL.String())
@@ -171,11 +163,12 @@ func main() {
171163
json.NewEncoder(w).Encode(att)
172164
})
173165

166+
listenAddr := fmt.Sprintf(":%d", config.ListenPort)
174167
httpServer := &http.Server{
175-
Addr: *listenAddr,
168+
Addr: listenAddr,
176169
Handler: mux,
177170
TLSConfig: tlsConfig,
178171
}
179-
log.Printf("Listening on %s", *listenAddr)
172+
log.Printf("Listening on %s", listenAddr)
180173
log.Fatal(httpServer.ListenAndServeTLS("", ""))
181174
}

0 commit comments

Comments
 (0)