Skip to content

Commit 48f8611

Browse files
qmuntaldagood
andauthored
Move error handling to C code (#265)
* move error handling to C code * add error test * fix TestErrorAllocs * undo version loading * fix TestErrorMultithread * autogenerate mkcgoNoEscape * add nosplit * Apply suggestions from code review Co-authored-by: Davis Goodin <[email protected]> * decouple error implementation * improve go_hash_sum error messages --------- Co-authored-by: Davis Goodin <[email protected]>
1 parent 4179116 commit 48f8611

12 files changed

+1895
-935
lines changed

cmd/mkcgo/generate.go

+148-46
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,13 @@ func generateGo(src *mkcgo.Source, w io.Writer) {
1919
// This block outputs C header includes and forward declarations for loader functions.
2020
fmt.Fprintf(w, "/*\n")
2121
fmt.Fprintf(w, "#cgo CFLAGS: -Wno-attributes\n\n")
22-
if *includeHeader != "" {
23-
fmt.Fprintf(w, "#include \"%s\"\n", *includeHeader)
24-
}
2522
for _, file := range src.Files {
2623
fmt.Fprintf(w, "#include %q\n", file)
2724
}
28-
fmt.Fprintf(w, "\n")
29-
for _, tag := range src.Tags() {
30-
fmt.Fprintf(w, "void __mkcgoLoad_%s(void* handle);\n", tag)
31-
fmt.Fprintf(w, "void __mkcgoUnload_%s();\n", tag)
32-
}
33-
fmt.Fprintf(w, "\n")
34-
for _, fn := range src.Funcs {
35-
if fn.Optional {
36-
fmt.Fprintf(w, "int %s_Available();\n", fn.ImportName)
37-
}
25+
if *includeHeader != "" {
26+
fmt.Fprintf(w, "#include \"%s\"\n", *includeHeader)
3827
}
28+
fmt.Fprintf(w, "#include \"%s\"\n", autogeneratedFileName(".h"))
3929
fmt.Fprintf(w, "*/\n")
4030
fmt.Fprintf(w, "import \"C\"\n")
4131
fmt.Fprintf(w, "import \"unsafe\"\n\n")
@@ -56,10 +46,14 @@ func generateGo(src *mkcgo.Source, w io.Writer) {
5646
fmt.Fprintf(w, "}\n\n")
5747
}
5848

59-
typedefs := make(map[string]string, len(src.TypeDefs))
60-
for _, def := range src.TypeDefs {
61-
typedefs[def.Name] = def.Type
62-
}
49+
// Generate error wrapper noescape function, which hides the
50+
// error state pointer from the Go garbage collector.
51+
// An instance of https://github.com/golang/go/blob/d704ef76068eb7da15520b08dc7df98f45f85ffa/src/runtime/stubs.go#L194-L201
52+
fmt.Fprintf(w, "//go:nosplit\n")
53+
fmt.Fprintf(w, "func %s(p *C.%s) *C.%s {\n", mkcgoNoEscape, mkcgoErrState, mkcgoErrState)
54+
fmt.Fprintf(w, "\tx := uintptr(unsafe.Pointer(p))\n")
55+
fmt.Fprintf(w, "\treturn (*C.%s)(unsafe.Pointer(x ^ 0))\n", mkcgoErrState)
56+
fmt.Fprintf(w, "}\n\n")
6357

6458
// Generate function wrappers.
6559
for _, fn := range src.Funcs {
@@ -73,7 +67,7 @@ func generateGo(src *mkcgo.Source, w io.Writer) {
7367
fmt.Fprintf(w, "\treturn C.%s_Available() != 0\n", fn.ImportName)
7468
fmt.Fprintf(w, "}\n\n")
7569
}
76-
generateGoFn(typedefs, fn, w)
70+
generateGoFn(fn, w)
7771
}
7872
}
7973

@@ -87,11 +81,15 @@ func generateGo124(src *mkcgo.Source, w io.Writer) {
8781
// This block outputs C header includes and forward declarations for loader functions.
8882
fmt.Fprintf(w, "/*\n")
8983
for _, fn := range src.Funcs {
84+
name := fn.CName
85+
if fnNeedErrWrapper(fn) {
86+
name = fnCErrWrapperName(fn)
87+
}
9088
if fn.NoEscape {
91-
fmt.Fprintf(w, "#cgo noescape %s\n", fn.CName)
89+
fmt.Fprintf(w, "#cgo noescape %s\n", name)
9290
}
9391
if fn.NoCallback {
94-
fmt.Fprintf(w, "#cgo nocallback %s\n", fn.CName)
92+
fmt.Fprintf(w, "#cgo nocallback %s\n", name)
9593
}
9694
}
9795
fmt.Fprintf(w, "*/\n")
@@ -148,20 +146,71 @@ func generateGoAliases(funcs []*mkcgo.Func, w io.Writer) {
148146
}
149147
}
150148

151-
// generateC creates the C source file content.
152-
func generateC(src *mkcgo.Source, w io.Writer) {
149+
// generateCHeader generates C header file content with
150+
// the C functions defined in the autogenerated C source file.
151+
func generateCHeader(src *mkcgo.Source, w io.Writer) {
153152
// Header and includes.
154153
fmt.Fprintf(w, "// Code generated by mkcgo. DO NOT EDIT.\n\n")
154+
155+
fmt.Fprintf(w, "#ifndef MKCGO_H // only include this header once\n")
156+
fmt.Fprintf(w, "#define MKCGO_H\n\n")
157+
158+
for _, file := range src.Files {
159+
fmt.Fprintf(w, "#include %q\n", file)
160+
}
155161
if *includeHeader != "" {
156162
fmt.Fprintf(w, "#include \"%s\"\n", *includeHeader)
157163
}
158-
for _, file := range src.Files {
159-
fmt.Fprintf(w, "#include %q\n", file)
164+
fmt.Fprintf(w, "\n")
165+
166+
// Custom types
167+
fmt.Fprintf(w, "typedef void* %s;\n", mkcgoErrState)
168+
fmt.Fprintf(w, "%s mkcgo_err_retrieve();\n", mkcgoErrState)
169+
fmt.Fprintf(w, "void mkcgo_err_free(%s);\n", mkcgoErrState)
170+
fmt.Fprintf(w, "void mkcgo_err_clear();\n\n")
171+
172+
// Add forward declarations for loader functions.
173+
for _, tag := range src.Tags() {
174+
fmt.Fprintf(w, "void __mkcgoLoad_%s(void* handle);\n", tag)
175+
fmt.Fprintf(w, "void __mkcgoUnload_%s();\n", tag)
176+
}
177+
fmt.Fprintf(w, "\n")
178+
179+
// Add forward declarations for optional functions.
180+
for _, fn := range src.Funcs {
181+
if fn.Optional {
182+
fmt.Fprintf(w, "int %s_Available();\n", fn.ImportName)
183+
}
184+
}
185+
fmt.Fprintf(w, "\n")
186+
187+
// Add forward declarations for function wrappers returning errors.
188+
for _, fn := range src.Funcs {
189+
if !fnNeedErrWrapper(fn) {
190+
continue
191+
}
192+
fmt.Fprintf(w, "%s %s(%s);\n", fn.Ret.Type, fnCErrWrapperName(fn), fnCErrWrapperParams(fn, false))
160193
}
194+
fmt.Fprintf(w, "\n")
195+
fmt.Fprintf(w, "#endif // MKCGO_H\n")
196+
}
197+
198+
// generateC creates the C source file content.
199+
func generateC(src *mkcgo.Source, w io.Writer) {
200+
// Header and includes.
201+
fmt.Fprintf(w, "// Code generated by mkcgo. DO NOT EDIT.\n\n")
202+
161203
fmt.Fprintf(w, "#include <stddef.h>\n")
162204
fmt.Fprintf(w, "#include <stdlib.h>\n")
163205
fmt.Fprintf(w, "#include <stdint.h>\n")
164206
fmt.Fprintf(w, "#include <stdio.h>\n")
207+
for _, file := range src.Files {
208+
fmt.Fprintf(w, "#include %q\n", file)
209+
}
210+
if *includeHeader != "" {
211+
fmt.Fprintf(w, "#include \"%s\"\n", *includeHeader)
212+
}
213+
fmt.Fprintf(w, "#include \"%s\"\n", autogeneratedFileName(".h"))
165214
fmt.Fprintf(w, "\n")
166215

167216
// Platform-specific includes.
@@ -238,6 +287,10 @@ func generateC(src *mkcgo.Source, w io.Writer) {
238287
}
239288

240289
// Generate C function wrappers.
290+
typedefs := make(map[string]string, len(src.TypeDefs))
291+
for _, def := range src.TypeDefs {
292+
typedefs[def.Name] = def.Type
293+
}
241294
for _, fn := range src.Funcs {
242295
if fn.Variadic() {
243296
// cgo doesn't support variadic functions
@@ -250,11 +303,12 @@ func generateC(src *mkcgo.Source, w io.Writer) {
250303
fmt.Fprintf(w, "}\n\n")
251304
}
252305
generateCFn(fn, w)
306+
generateCFnErrorWrapper(typedefs, fn, w)
253307
}
254308
}
255309

256310
// generateGoFn generates Go function f.
257-
func generateGoFn(typedefs map[string]string, fn *mkcgo.Func, w io.Writer) {
311+
func generateGoFn(fn *mkcgo.Func, w io.Writer) {
258312
fnCall := fmt.Sprintf("C.%s(%s)", fn.CName, fnToGoArgs(fn))
259313
// Function definition
260314
fmt.Fprintf(w, "func %s(%s)", fn.GoName, fnToGoParams(fn))
@@ -296,21 +350,13 @@ func generateGoFn(typedefs map[string]string, fn *mkcgo.Func, w io.Writer) {
296350
fmt.Fprintf(w, "}\n\n")
297351
return
298352
}
299-
fmt.Fprintf(w, "\t_ret := C.%s(%s)\n", fn.CName, fnToGoArgs(fn))
300-
301-
// Error handling
302-
errCond := "<= 0"
303-
if fn.ErrCond != "" {
304-
errCond = fn.ErrCond
305-
} else if strings.Contains(goType, "unsafe.Pointer") {
306-
errCond = "== nil"
307-
} else if typ, ok := typedefs[goType]; ok && typ == "void*" {
308-
errCond = "== nil"
353+
fmt.Fprintf(w, "\tvar _err C.%s\n", mkcgoErrState)
354+
fmt.Fprintf(w, "\t_ret := C.%s(", fnCErrWrapperName(fn))
355+
args := fnToGoArgs(fn)
356+
if len(args) > 0 {
357+
args += ", "
309358
}
310-
fmt.Fprintf(w, "\tvar _err error\n")
311-
fmt.Fprintf(w, "\tif _ret %s {\n", errCond)
312-
fmt.Fprintf(w, "\t\t_err = newOpenSSLError(\"%s\")\n", fn.CName)
313-
fmt.Fprintf(w, "\t}\n")
359+
fmt.Fprintf(w, "%s%s(&_err))\n", args, mkcgoNoEscape)
314360

315361
// Return the value
316362
fmt.Fprintf(w, "\treturn ")
@@ -322,16 +368,38 @@ func generateGoFn(typedefs map[string]string, fn *mkcgo.Func, w io.Writer) {
322368
} else {
323369
fmt.Fprintf(w, "_ret")
324370
}
325-
fmt.Fprintf(w, ", _err\n")
371+
fmt.Fprintf(w, ", newMkcgoErr(%q, _err)\n", fn.CName)
326372
fmt.Fprintf(w, "}\n\n")
327373
}
328374

329375
func generateCFn(fn *mkcgo.Func, w io.Writer) {
330-
fmt.Fprintf(w, "%s %s(%s) {\n\t", fn.Ret.Type, fn.CName, fnToCArgs(fn, true))
376+
fmt.Fprintf(w, "%s %s(%s) {\n\t", fn.Ret.Type, fn.CName, fnToCArgs(fn, true, true))
331377
if !retIsVoid(fn.Ret) {
332378
fmt.Fprintf(w, "return ")
333379
}
334-
fmt.Fprintf(w, "_g_%s(%s);\n", fn.ImportName, fnToCArgs(fn, false))
380+
fmt.Fprintf(w, "_g_%s(%s);\n", fn.ImportName, fnToCArgs(fn, false, true))
381+
fmt.Fprintf(w, "}\n\n")
382+
}
383+
384+
// generateCFnErrorWrapper generates C function wrapper for function f
385+
// that returns an error state.
386+
func generateCFnErrorWrapper(typedefs map[string]string, fn *mkcgo.Func, w io.Writer) {
387+
if !fnNeedErrWrapper(fn) {
388+
return
389+
}
390+
fmt.Fprintf(w, "%s %s(%s) {\n", fn.Ret.Type, fnCErrWrapperName(fn), fnCErrWrapperParams(fn, true))
391+
fmt.Fprintf(w, "\tmkcgo_err_clear();\n") // clear any previous error
392+
fmt.Fprintf(w, "\t%s _ret = _g_%s(%s);\n", fn.Ret.Type, fn.ImportName, fnToCArgs(fn, false, true))
393+
errCond := "<= 0"
394+
if fn.ErrCond != "" {
395+
errCond = fn.ErrCond
396+
} else if strings.Contains(fn.Ret.Type, "*") {
397+
errCond = "== NULL"
398+
} else if typ, ok := typedefs[fn.Ret.Type]; ok && typ == "void*" {
399+
errCond = "== NULL"
400+
}
401+
fmt.Fprintf(w, "\tif (_ret %s) *_err_state = mkcgo_err_retrieve();\n", errCond)
402+
fmt.Fprintf(w, "\treturn _ret;\n")
335403
fmt.Fprintf(w, "}\n\n")
336404
}
337405

@@ -436,12 +504,15 @@ func cTypeToGo(t string, cgo bool) (string, bool) {
436504
}
437505

438506
// paramToC returns C source code of parameter p.
439-
func paramToC(i int, p *mkcgo.Param, addType bool) string {
507+
func paramToC(i int, p *mkcgo.Param, addType, addName bool) string {
508+
if p.Type == "..." {
509+
return ""
510+
}
440511
var s string
441512
if addType {
442513
s += p.Type
443514
}
444-
if p.Type != "void" && p.Type != "..." {
515+
if addName && p.Type != "void" {
445516
if len(s) > 0 {
446517
s += " "
447518
}
@@ -470,9 +541,9 @@ func fnToGoArgs(fn *mkcgo.Func) string {
470541
}
471542

472543
// fnToCArgs returns source code for C parameters for function f.
473-
func fnToCArgs(fn *mkcgo.Func, addType bool) string {
544+
func fnToCArgs(fn *mkcgo.Func, addType, addName bool) string {
474545
return join(fn.Params, func(i int, p *mkcgo.Param) string {
475-
return paramToC(i, p, addType)
546+
return paramToC(i, p, addType, addName)
476547
}, ", ")
477548
}
478549

@@ -492,3 +563,34 @@ func join(ps []*mkcgo.Param, fn func(int, *mkcgo.Param) string, sep string) stri
492563
}
493564
return strings.Join(params, sep)
494565
}
566+
567+
const mkcgoNoEscape = "mkcgoNoEscape"
568+
const mkcgoErrState = "mkcgo_err_state"
569+
570+
// fnCErrWrapperParams returns source code for C parameters for function f
571+
// with the error state added as the last parameter.
572+
func fnCErrWrapperParams(fn *mkcgo.Func, addName bool) string {
573+
errArg := mkcgoErrState + " *"
574+
if addName {
575+
errArg += "_err_state"
576+
}
577+
args := fnToCArgs(fn, true, addName)
578+
if len(args) == 0 {
579+
args = errArg
580+
} else if args == "void" {
581+
args = errArg
582+
} else {
583+
args += ", " + errArg
584+
}
585+
return args
586+
}
587+
588+
// fnCErrWrapperName returns the name of the error wrapper function for function f.
589+
func fnCErrWrapperName(fn *mkcgo.Func) string {
590+
return "_mkcgo_err_" + fn.CName
591+
}
592+
593+
// fnNeedErrWrapper reports whether function fn needs an error wrapper.
594+
func fnNeedErrWrapper(fn *mkcgo.Func) bool {
595+
return !fn.NoError && !retIsVoid(fn.Ret)
596+
}

cmd/mkcgo/main.go

+23-15
Original file line numberDiff line numberDiff line change
@@ -40,45 +40,53 @@ func main() {
4040
log.Fatal(err)
4141
}
4242

43-
var gobuf, go124buf, cbuf bytes.Buffer
43+
var gobuf, go124buf, hbuf, cbuf bytes.Buffer
4444
generateGo(src, &gobuf)
4545
generateGo124(src, &go124buf)
46+
generateCHeader(src, &hbuf)
4647
generateC(src, &cbuf)
4748

4849
// Format the generated Go source code.
4950
godata := goformat(gobuf.Bytes())
5051
go124data := goformat(go124buf.Bytes())
5152

52-
var baseName string
53-
if *fileName == "" {
54-
baseName = "mkcgo"
55-
} else {
56-
baseName = strings.TrimSuffix(*fileName, ".go")
57-
}
58-
5953
for _, d := range []struct {
60-
name string
61-
data []byte
54+
suffix string
55+
data []byte
6256
}{
63-
{baseName + ".go", godata},
64-
{baseName + "_go124.go", go124data},
65-
{baseName + ".c", cbuf.Bytes()},
57+
{".go", godata},
58+
{"_go124.go", go124data},
59+
{".h", hbuf.Bytes()},
60+
{".c", cbuf.Bytes()},
6661
} {
62+
name := autogeneratedFileName(d.suffix)
6763
var err error
6864
if *fileName == "" {
6965
// Write output. If no explicit output file is specified,
7066
// // write both Go and C output to stdout.
71-
os.Stdout.WriteString("// === " + d.name + " ===\n\n")
67+
os.Stdout.WriteString("// === " + name + " ===\n\n")
7268
_, err = os.Stdout.Write(d.data)
7369
} else {
74-
err = os.WriteFile(d.name, d.data, 0o644)
70+
err = os.WriteFile(name, d.data, 0o644)
7571
}
7672
if err != nil {
7773
log.Fatal(err)
7874
}
7975
}
8076
}
8177

78+
// autogeneratedFileName returns the name of the autogenerated file
79+
// using the provided suffix.
80+
func autogeneratedFileName(suffix string) string {
81+
var baseName string
82+
if *fileName == "" {
83+
baseName = "mkcgo"
84+
} else {
85+
baseName = strings.TrimSuffix(*fileName, ".go")
86+
}
87+
return baseName + suffix
88+
}
89+
8290
func writeTempSourceFile(data []byte) (string, error) {
8391
f, err := os.CreateTemp("", "mkcgo-generated-*.go")
8492
if err != nil {

0 commit comments

Comments
 (0)