Skip to content

Commit f4bda47

Browse files
committed
support buffer protocol
1 parent 97cc18b commit f4bda47

File tree

2 files changed

+97
-38
lines changed

2 files changed

+97
-38
lines changed

dartsclone/_dartsclone.pyx

+86-34
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1-
from libc.stdlib cimport malloc, free
1+
from libc.stdlib cimport calloc, free
2+
3+
cdef extern from "Python.h":
4+
ctypedef struct PyObject
5+
int PyObject_GetBuffer(PyObject *exporter, Py_buffer *view, int flags)
6+
void PyBuffer_Release(Py_buffer *view)
7+
const int PyBUF_C_CONTIGUOUS
28

39

410
cdef class DoubleArray:
@@ -41,26 +47,45 @@ cdef class DoubleArray:
4147
lengths = None,
4248
values = None):
4349
cdef size_t num_keys = len(keys)
44-
cdef const char** _keys = <const char**> malloc(num_keys * sizeof(char*))
50+
cdef const char** _keys = NULL
51+
cdef Py_buffer* _buf = NULL
4552
cdef size_t *_lengths = NULL
4653
cdef int *_values = NULL
47-
for i, key in enumerate(keys):
48-
_keys[i] = key
49-
if lengths is not None:
50-
_lengths = <size_t *> malloc(num_keys * sizeof(size_t))
51-
for i, length in enumerate(lengths):
52-
_lengths[i] = length
53-
if values is not None:
54-
_values = <int *> malloc(num_keys * sizeof(int))
55-
for i, value in enumerate(values):
56-
_values[i] = value
54+
5755
try:
56+
_keys = <const char**> calloc(num_keys, sizeof(char*))
57+
if _keys == NULL:
58+
raise MemoryError("failed to allocate memory for key array")
59+
_buf = <Py_buffer *> calloc(num_keys, sizeof(Py_buffer))
60+
if _buf == NULL:
61+
raise MemoryError("failed to allocate memory for buffer")
62+
for i, key in enumerate(keys):
63+
if PyObject_GetBuffer(<PyObject *>key, &_buf[i], PyBUF_C_CONTIGUOUS) < 0:
64+
return
65+
_keys[i] = <const char *> _buf[i].buf
66+
if lengths is not None:
67+
_lengths = <size_t *> calloc(num_keys, sizeof(size_t))
68+
if _lengths == NULL:
69+
raise MemoryError("failed to allocate memory for length array")
70+
for i, length in enumerate(lengths):
71+
_lengths[i] = length
72+
if values is not None:
73+
_values = <int *> calloc(num_keys, sizeof(int))
74+
if _values == NULL:
75+
raise MemoryError("failed to allocate memory for value array")
76+
for i, value in enumerate(values):
77+
_values[i] = value
5878
self.wrapped.build(num_keys, _keys, <const size_t*> _lengths, <const int*> _values, NULL)
5979
finally:
60-
free(_keys)
61-
if lengths is not None:
80+
if _keys != NULL:
81+
free(_keys)
82+
if _buf != NULL:
83+
for i in range(num_keys):
84+
PyBuffer_Release(&_buf[i])
85+
free(_buf)
86+
if _lengths != NULL:
6287
free(_lengths)
63-
if values is not None:
88+
if _values != NULL:
6489
free(_values)
6590

6691
def open(self, file_name,
@@ -88,39 +113,66 @@ cdef class DoubleArray:
88113
size_t length = 0,
89114
size_t node_pos = 0,
90115
pair_type=True):
91-
cdef const char *_key = key
92-
if pair_type:
93-
return self.__exact_match_search_pair_type(_key, length, node_pos)
94-
else:
95-
return self.__exact_match_search(_key, length, node_pos)
116+
cdef Py_buffer buf
117+
if PyObject_GetBuffer(<PyObject *>key, &buf, PyBUF_C_CONTIGUOUS) < 0:
118+
return
119+
try:
120+
if length == 0:
121+
if buf.len == 0:
122+
raise ValueError("buffer cannot be empty")
123+
length = buf.len
124+
if pair_type:
125+
return self.__exact_match_search_pair_type(<const char *>buf.buf, length, node_pos)
126+
else:
127+
return self.__exact_match_search(<const char *>buf.buf, length, node_pos)
128+
finally:
129+
PyBuffer_Release(&buf)
96130

97131
def common_prefix_search(self, key,
98132
size_t max_num_results = 0,
99133
size_t length = 0,
100134
size_t node_pos = 0,
101135
pair_type=True):
102-
cdef const char *_key = key
103-
if max_num_results == 0:
104-
max_num_results = len(key)
105-
if pair_type:
106-
return self.__common_prefix_search_pair_type(_key, max_num_results, length, node_pos)
107-
else:
108-
return self.__common_prefix_search(_key, max_num_results, length, node_pos)
136+
cdef Py_buffer buf
137+
if PyObject_GetBuffer(<PyObject *>key, &buf, PyBUF_C_CONTIGUOUS) < 0:
138+
return
139+
try:
140+
if length == 0:
141+
if buf.len == 0:
142+
raise ValueError("buffer cannot be empty")
143+
length = buf.len
144+
if max_num_results == 0:
145+
max_num_results = len(key)
146+
if pair_type:
147+
return self.__common_prefix_search_pair_type(<const char *>buf.buf, max_num_results, length, node_pos)
148+
else:
149+
return self.__common_prefix_search(<const char *>buf.buf, max_num_results, length, node_pos)
150+
finally:
151+
PyBuffer_Release(&buf)
109152

110153
def traverse(self, key,
111154
size_t node_pos,
112155
size_t key_pos,
113156
size_t length = 0):
114-
cdef const char *_key = key
157+
cdef Py_buffer buf
115158
cdef int result
116-
with nogil:
117-
result = self.wrapped.traverse(_key, node_pos, key_pos, length)
118-
return result
159+
if PyObject_GetBuffer(<PyObject *>key, &buf, PyBUF_C_CONTIGUOUS) < 0:
160+
return
161+
try:
162+
if length == 0:
163+
if buf.len == 0:
164+
raise ValueError("buffer cannot be empty")
165+
length = buf.len
166+
with nogil:
167+
result = self.wrapped.traverse(<const char *>buf.buf, node_pos, key_pos, length)
168+
return result
169+
finally:
170+
PyBuffer_Release(&buf)
119171

120172
def __exact_match_search(self, const char *key,
121173
size_t length = 0,
122174
size_t node_pos = 0):
123-
cdef int result
175+
cdef int result = 0
124176
with nogil:
125177
self.wrapped.exact_match_search(key, result, length, node_pos)
126178
return result
@@ -137,7 +189,7 @@ cdef class DoubleArray:
137189
size_t max_num_results,
138190
size_t length,
139191
size_t node_pos):
140-
cdef int *results = <int *> malloc(max_num_results * sizeof(int))
192+
cdef int *results = <int *> calloc(max_num_results, sizeof(int))
141193
cdef int result_len
142194
try:
143195
with nogil:
@@ -153,7 +205,7 @@ cdef class DoubleArray:
153205
size_t max_num_results,
154206
size_t length,
155207
size_t node_pos):
156-
cdef result_pair_type *results = <result_pair_type *> malloc(max_num_results * sizeof(result_pair_type))
208+
cdef result_pair_type *results = <result_pair_type *> calloc(max_num_results, sizeof(result_pair_type))
157209
cdef result_pair_type result
158210
cdef int result_len
159211
try:

test/test_darts.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class DoubleArrayTest(unittest.TestCase):
1212
def test_darts_no_values(self):
1313
keys = ['test', 'テスト', 'テストケース']
1414
darts = DoubleArray()
15-
darts.build(sorted([key.encode() for key in keys]))
15+
darts.build([key.encode() for key in keys])
1616
self.assertEqual(1, darts.exact_match_search('テスト'.encode(), pair_type=False))
1717
self.assertEqual(0, darts.common_prefix_search('testcase'.encode(), pair_type=False)[0])
1818
self.assertEqual(0, darts.exact_match_search('test'.encode(), pair_type=False))
@@ -21,7 +21,7 @@ def test_darts_no_values(self):
2121
def test_darts_with_values(self):
2222
keys = ['test', 'テスト', 'テストケース']
2323
darts = DoubleArray()
24-
darts.build(sorted([key.encode() for key in keys]), values=[3, 5, 1])
24+
darts.build([key.encode() for key in keys], values=[3, 5, 1])
2525
self.assertEqual(5, darts.exact_match_search('テスト'.encode(), pair_type=False))
2626
self.assertEqual(3, darts.common_prefix_search('testcase'.encode(), pair_type=False)[0])
2727
self.assertEqual(1, darts.exact_match_search('テストケース'.encode(), pair_type=False))
@@ -30,7 +30,7 @@ def test_darts_with_values(self):
3030
def test_darts_save(self):
3131
keys = ['test', 'テスト', 'テストケース']
3232
darts = DoubleArray()
33-
darts.build(sorted([key.encode() for key in keys]), values=[3, 5, 1])
33+
darts.build([key.encode() for key in keys], values=[3, 5, 1])
3434
with tempfile.NamedTemporaryFile('wb') as output_file:
3535
darts.save(output_file.name)
3636
output_file.flush()
@@ -54,13 +54,20 @@ def test_darts_pickle(self):
5454
def test_darts_array(self):
5555
keys = ['test', 'テスト', 'テストケース']
5656
darts = DoubleArray()
57-
darts.build(sorted([key.encode() for key in keys]), values=[3, 5, 1])
57+
darts.build([key.encode() for key in keys], values=[3, 5, 1])
5858
array = darts.array()
5959
darts = DoubleArray()
6060
darts.set_array(array)
6161
self.assertEqual(5, darts.exact_match_search('テスト'.encode(), pair_type=False))
6262
self.assertEqual(3, darts.common_prefix_search('testcase'.encode(), pair_type=False)[0])
6363

64+
def test_darts_buffers(self):
65+
keys = ['test', 'テスト', 'テストケース']
66+
darts = DoubleArray()
67+
darts.build([memoryview(key.encode()) for key in keys], values=[3, 5, 1])
68+
self.assertEqual(5, darts.exact_match_search(memoryview('テスト'.encode()), pair_type=False))
69+
self.assertEqual(3, darts.common_prefix_search(memoryview('testcase'.encode()), pair_type=False)[0])
70+
6471

6572
if __name__ == "__main__":
6673
unittest.main()

0 commit comments

Comments
 (0)