Skip to content

Commit 744c0be

Browse files
authored
Merge pull request #6 from cpunion/fix
better func names, fix ref counting and memory free, add tests
2 parents 3b8cc1d + e8c024e commit 744c0be

15 files changed

+319
-76
lines changed

_demo/autoderef/autoderef.go

+8-8
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ func main() {
1313
gp.Initialize()
1414
defer gp.Finalize()
1515
fooMod := foo.InitFooModule()
16-
gp.GetModuleDict().Set(gp.MakeStr("foo").Object, fooMod.Object)
16+
gp.GetModuleDict().SetString("foo", fooMod)
1717

1818
Main1(fooMod)
1919
Main2()
@@ -22,21 +22,21 @@ func main() {
2222

2323
func Main1(fooMod gp.Module) {
2424
fmt.Printf("=========== Main1 ==========\n")
25-
sum := fooMod.Call("add", gp.MakeLong(1), gp.MakeLong(2)).AsLong()
25+
sum := fooMod.Call("add", 1, 2).AsLong()
2626
fmt.Printf("Sum of 1 + 2: %d\n", sum.Int64())
2727

2828
dict := fooMod.Dict()
2929
Point := dict.Get(gp.MakeStr("Point")).AsFunc()
3030

31-
point := Point.Call(gp.MakeLong(3), gp.MakeLong(4))
31+
point := Point.Call(3, 4)
3232
fmt.Printf("dir(point): %v\n", point.Dir())
33-
fmt.Printf("x: %v, y: %v\n", point.GetAttr("x"), point.GetAttr("y"))
33+
fmt.Printf("x: %v, y: %v\n", point.Attr("x"), point.Attr("y"))
3434

3535
distance := point.Call("distance").AsFloat()
3636
fmt.Printf("Distance of 3 * 4: %f\n", distance.Float64())
3737

38-
point.Call("move", gp.MakeFloat(1), gp.MakeFloat(2))
39-
fmt.Printf("x: %v, y: %v\n", point.GetAttr("x"), point.GetAttr("y"))
38+
point.Call("move", 1, 2)
39+
fmt.Printf("x: %v, y: %v\n", point.Attr("x"), point.Attr("y"))
4040

4141
distance = point.Call("distance").AsFloat()
4242
fmt.Printf("Distance of 4 * 6: %f\n", distance.Float64())
@@ -45,7 +45,7 @@ func Main1(fooMod gp.Module) {
4545

4646
func Main2() {
4747
fmt.Printf("=========== Main2 ==========\n")
48-
gp.RunString(`
48+
_ = gp.RunString(`
4949
import foo
5050
point = foo.Point(3, 4)
5151
print("dir(point):", dir(point))
@@ -92,7 +92,7 @@ for i in range(10):
9292
fmt.Printf("Iteration %d in python\n", i+1)
9393
}
9494

95-
memory_allocation_test := mod.GetFuncAttr("memory_allocation_test")
95+
memory_allocation_test := mod.AttrFunc("memory_allocation_test")
9696

9797
for i := 0; i < 100; i++ {
9898
// 100MB every time

_demo/gradio/gradio.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ func main() {
6060
})
6161
textbox := gr.Call("Textbox")
6262
examples := gr.Call("Examples", [][]string{{"Chicago"}, {"Little Rock"}, {"San Francisco"}}, textbox)
63-
dataset := examples.GetAttr("dataset")
63+
dataset := examples.Attr("dataset")
6464
dropdown.Call("change", fn, dropdown, dataset)
6565
})
6666
demo.Call("launch")

_demo/plot/plot.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,6 @@ func main() {
2222
gp.Initialize()
2323
defer gp.Finalize()
2424
plt := Plt()
25-
plt.Plot(gp.MakeTuple(5, 10), gp.MakeTuple(10, 15), gp.KwArgs{"color": "red"})
25+
plt.Plot([]int{5, 10}, []int{10, 15}, gp.KwArgs{"color": "red"})
2626
plt.Show()
2727
}

adap_go.go

+4
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ func AllocCStr(s string) *C.char {
1818
return C.CString(s)
1919
}
2020

21+
func AllocCStrDontFree(s string) *C.char {
22+
return C.CString(s)
23+
}
24+
2125
func AllocWCStr(s string) *C.wchar_t {
2226
runes := []rune(s)
2327
wchars := make([]uint16, len(runes)+1)

bytes.go

+5-3
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,14 @@ func BytesFromStr(s string) Bytes {
2222

2323
func MakeBytes(bytes []byte) Bytes {
2424
ptr := C.CBytes(bytes)
25-
return newBytes(C.PyBytes_FromStringAndSize((*C.char)(ptr), C.Py_ssize_t(len(bytes))))
25+
o := C.PyBytes_FromStringAndSize((*C.char)(ptr), C.Py_ssize_t(len(bytes)))
26+
C.free(unsafe.Pointer(ptr))
27+
return newBytes(o)
2628
}
2729

2830
func (b Bytes) Bytes() []byte {
29-
var p *byte
30-
var l int
31+
p := (*byte)(unsafe.Pointer(C.PyBytes_AsString(b.obj)))
32+
l := int(C.PyBytes_Size(b.obj))
3133
return C.GoBytes(unsafe.Pointer(p), C.int(l))
3234
}
3335

dict.go

+24-12
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ package gp
44
#include <Python.h>
55
*/
66
import "C"
7-
import "fmt"
7+
import (
8+
"fmt"
9+
"unsafe"
10+
)
811

912
type Dict struct {
1013
Object
@@ -47,26 +50,32 @@ func (d Dict) Get(key Objecter) Object {
4750
return newObject(v)
4851
}
4952

50-
func (d Dict) Set(key, value Object) {
51-
C.Py_IncRef(key.obj)
52-
C.Py_IncRef(value.obj)
53-
C.PyDict_SetItem(d.obj, key.obj, value.obj)
53+
func (d Dict) Set(key, value Objecter) {
54+
keyObj := key.Obj()
55+
valueObj := value.Obj()
56+
C.PyDict_SetItem(d.obj, keyObj, valueObj)
5457
}
5558

56-
func (d Dict) SetString(key string, value Object) {
57-
C.Py_IncRef(value.obj)
58-
C.PyDict_SetItemString(d.obj, AllocCStr(key), value.obj)
59+
func (d Dict) SetString(key string, value Objecter) {
60+
valueObj := value.Obj()
61+
ckey := AllocCStr(key)
62+
r := C.PyDict_SetItemString(d.obj, ckey, valueObj)
63+
C.free(unsafe.Pointer(ckey))
64+
if r != 0 {
65+
panic(fmt.Errorf("failed to set item string: %v", r))
66+
}
5967
}
6068

6169
func (d Dict) GetString(key string) Object {
62-
v := C.PyDict_GetItemString(d.obj, AllocCStr(key))
70+
ckey := AllocCStr(key)
71+
v := C.PyDict_GetItemString(d.obj, ckey)
6372
C.Py_IncRef(v)
73+
C.free(unsafe.Pointer(ckey))
6474
return newObject(v)
6575
}
6676

67-
func (d Dict) Del(key Object) {
68-
C.PyDict_DelItem(d.obj, key.obj)
69-
C.Py_DecRef(key.obj)
77+
func (d Dict) Del(key Objecter) {
78+
C.PyDict_DelItem(d.obj, key.Obj())
7079
}
7180

7281
func (d Dict) ForEach(fn func(key, value Object)) {
@@ -84,6 +93,9 @@ func (d Dict) ForEach(fn func(key, value Object)) {
8493
C.Py_IncRef(item)
8594
key := C.PyTuple_GetItem(item, 0)
8695
value := C.PyTuple_GetItem(item, 1)
96+
C.Py_IncRef(key)
97+
C.Py_IncRef(value)
98+
C.Py_DecRef(item)
8799
fn(newObject(key), newObject(value))
88100
}
89101
}

float.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,6 @@ func (f Float) Float64() float64 {
2222
}
2323

2424
func (f Float) IsInteger() Bool {
25-
fn := Cast[Func](f.GetAttr("is_integer"))
25+
fn := Cast[Func](f.Attr("is_integer"))
2626
return Cast[Bool](fn.callNoArgs())
2727
}

function.go

+10-12
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ func wrapperInit(self, args *C.PyObject) C.int {
154154
}
155155

156156
//export getterMethod
157-
func getterMethod(self *C.PyObject, closure unsafe.Pointer, methodId C.int) *C.PyObject {
157+
func getterMethod(self *C.PyObject, _closure unsafe.Pointer, methodId C.int) *C.PyObject {
158158
typeMeta := typeMetaMap[(*C.PyObject)(unsafe.Pointer(self.ob_type))]
159159
if typeMeta == nil {
160160
SetError(fmt.Errorf("type %v not registered", FromPy(self)))
@@ -177,7 +177,7 @@ func getterMethod(self *C.PyObject, closure unsafe.Pointer, methodId C.int) *C.P
177177
}
178178

179179
//export setterMethod
180-
func setterMethod(self, value *C.PyObject, closure unsafe.Pointer, methodId C.int) C.int {
180+
func setterMethod(self, value *C.PyObject, _closure unsafe.Pointer, methodId C.int) C.int {
181181
typeMeta := typeMetaMap[(*C.PyObject)(unsafe.Pointer(self.ob_type))]
182182
if typeMeta == nil {
183183
SetError(fmt.Errorf("type %v not registered", FromPy(self)))
@@ -260,6 +260,7 @@ func wrapperMethod_(typeMeta *typeMeta, methodMeta *slotMeta, self, args *C.PyOb
260260

261261
for i := 0; i < int(argc); i++ {
262262
arg := C.PyTuple_GetItem(args, C.Py_ssize_t(i))
263+
C.Py_IncRef(arg)
263264
argType := methodType.In(i + argIndex)
264265
argPy := FromPy(arg)
265266
goValue := reflect.New(argType).Elem()
@@ -316,7 +317,7 @@ func getMethods_(t reflect.Type, methods map[uint]*slotMeta) (ret []C.PyMethodDe
316317
methodPtr := C.wrapperMethods[methodId]
317318

318319
ret = append(ret, C.PyMethodDef{
319-
ml_name: C.CString(pythonName),
320+
ml_name: AllocCStrDontFree(pythonName),
320321
ml_meth: (C.PyCFunction)(unsafe.Pointer(methodPtr)),
321322
ml_flags: C.METH_VARARGS,
322323
ml_doc: nil,
@@ -402,7 +403,7 @@ func getMembers(t reflect.Type, methods map[uint]*slotMeta) (members *C.PyMember
402403
if memberType != -1 {
403404
// create as member variable for C-compatible types
404405
membersList = append(membersList, C.PyMemberDef{
405-
name: C.CString(pythonName),
406+
name: AllocCStrDontFree(pythonName),
406407
_type: memberType,
407408
offset: C.Py_ssize_t(baseOffset + field.Offset),
408409
})
@@ -429,7 +430,7 @@ func getMembers(t reflect.Type, methods map[uint]*slotMeta) (members *C.PyMember
429430
index: i,
430431
}
431432
getsetsList = append(getsetsList, C.PyGetSetDef{
432-
name: C.CString(pythonName),
433+
name: AllocCStrDontFree(pythonName),
433434
get: C.getterMethods[getId],
434435
set: C.setterMethods[setId],
435436
doc: nil,
@@ -480,9 +481,6 @@ func AddType[T any](m Module, init any, name string, doc string) Object {
480481
methods: make(map[uint]*slotMeta),
481482
}
482483

483-
cname := C.CString(name)
484-
defer C.free(unsafe.Pointer(cname))
485-
486484
slots := make([]C.PyType_Slot, 0)
487485
if init != nil {
488486
slots = append(slots, C.PyType_Slot{slot: C.Py_tp_init, pfunc: unsafe.Pointer(C.wrapperInit)})
@@ -512,7 +510,7 @@ func AddType[T any](m Module, init any, name string, doc string) Object {
512510
}
513511

514512
spec := &C.PyType_Spec{
515-
name: cname,
513+
name: C.CString(name),
516514
basicsize: C.int(unsafe.Sizeof(wrapper)),
517515
flags: C.Py_TPFLAGS_DEFAULT,
518516
slots: slotsPtr,
@@ -526,7 +524,7 @@ func AddType[T any](m Module, init any, name string, doc string) Object {
526524
typeMetaMap[typeObj] = meta
527525
pyTypeMap[ty] = typeObj
528526

529-
if C.PyModule_AddObject(m.obj, cname, typeObj) < 0 {
527+
if C.PyModule_AddObject(m.obj, C.CString(name), typeObj) < 0 {
530528
C.Py_DecRef(typeObj)
531529
panic(fmt.Sprintf("Failed to add type %s to module", name))
532530
}
@@ -599,14 +597,14 @@ func (m Module) AddMethod(name string, fn any, doc string) Func {
599597

600598
func SetError(err error) {
601599
errStr := C.CString(err.Error())
602-
defer C.free(unsafe.Pointer(errStr))
603600
C.PyErr_SetString(C.PyExc_RuntimeError, errStr)
601+
C.free(unsafe.Pointer(errStr))
604602
}
605603

606604
func SetTypeError(err error) {
607605
errStr := C.CString(err.Error())
608-
defer C.free(unsafe.Pointer(errStr))
609606
C.PyErr_SetString(C.PyExc_TypeError, errStr)
607+
C.free(unsafe.Pointer(errStr))
610608
}
611609

612610
// FetchError returns the current Python error as a Go error

list.go

+6-4
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,16 @@ func (l List) GetItem(index int) Object {
2828
return newObject(v)
2929
}
3030

31-
func (l List) SetItem(index int, item Object) {
32-
C.PyList_SetItem(l.obj, C.Py_ssize_t(index), item.obj)
31+
func (l List) SetItem(index int, item Objecter) {
32+
itemObj := item.Obj()
33+
C.Py_IncRef(itemObj)
34+
C.PyList_SetItem(l.obj, C.Py_ssize_t(index), itemObj)
3335
}
3436

3537
func (l List) Len() int {
3638
return int(C.PyList_Size(l.obj))
3739
}
3840

39-
func (l List) Append(obj Object) {
40-
C.PyList_Append(l.obj, obj.obj)
41+
func (l List) Append(obj Objecter) {
42+
C.PyList_Append(l.obj, obj.Obj())
4143
}

long.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package gp
44
#include <Python.h>
55
*/
66
import "C"
7+
import "unsafe"
78

89
type Long struct {
910
Object
@@ -35,7 +36,9 @@ func LongFromFloat64(v float64) Long {
3536

3637
func LongFromString(s string, base int) Long {
3738
cstr := AllocCStr(s)
38-
return newLong(C.PyLong_FromString(cstr, nil, C.int(base)))
39+
o := C.PyLong_FromString(cstr, nil, C.int(base))
40+
C.free(unsafe.Pointer(cstr))
41+
return newLong(o)
3942
}
4043

4144
func LongFromUnicode(u Object, base int) Long {

module.go

+10-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package gp
44
#include <Python.h>
55
*/
66
import "C"
7+
import "unsafe"
78

89
type Module struct {
910
Object
@@ -14,7 +15,9 @@ func newModule(obj *PyObject) Module {
1415
}
1516

1617
func ImportModule(name string) Module {
17-
mod := C.PyImport_ImportModule(AllocCStr(name))
18+
cname := AllocCStr(name)
19+
mod := C.PyImport_ImportModule(cname)
20+
C.free(unsafe.Pointer(cname))
1821
return newModule(mod)
1922
}
2023

@@ -27,11 +30,15 @@ func (m Module) Dict() Dict {
2730
}
2831

2932
func (m Module) AddObject(name string, obj Object) int {
30-
return int(C.PyModule_AddObject(m.obj, AllocCStr(name), obj.obj))
33+
cname := AllocCStr(name)
34+
r := int(C.PyModule_AddObject(m.obj, cname, obj.obj))
35+
C.free(unsafe.Pointer(cname))
36+
return r
3137
}
3238

3339
func CreateModule(name string) Module {
34-
return newModule(C.PyModule_New(AllocCStr(name)))
40+
mod := C.PyModule_New(AllocCStrDontFree(name))
41+
return newModule(mod)
3542
}
3643

3744
func GetModuleDict() Dict {

0 commit comments

Comments
 (0)