# bulk_insert.py
# Forked from gaolk/graph-database-benchmark

import csv
import io
import os
import struct
from timeit import default_timer as timer

import click
import redis

# Global variables
CONFIGS = None     # thresholds for batching Redis queries
NODE_DICT = {}     # global node dictionary
TOP_NODE_ID = 0    # next ID to assign to a node
QUERY_BUF = None   # buffer for the query currently being constructed


# Custom error class for invalid inputs
class CSVError(Exception):
    pass

# Official enum support varies widely between 2.7 and 3.x, so we'll use a custom class
class Type:
    NULL = 0
    BOOL = 1
    NUMERIC = 2
    STRING = 3


# User-configurable thresholds for when to send queries to Redis
class Configs(object):
    def __init__(self, max_token_count, max_buffer_size, max_token_size):
        # Maximum number of tokens per query
        # 1024 * 1024 is the hard-coded Redis maximum. We'll set a slightly lower limit so
        # that we can safely ignore tokens that aren't binary strings
        # ("GRAPH.BULK", "BEGIN", graph name, counts)
        self.max_token_count = min(max_token_count, 1024 * 1023)
        # Maximum size in bytes per query
        self.max_buffer_size = max_buffer_size * 1000000
        # Maximum size in bytes per token
        # 512 megabytes is a hard-coded Redis maximum
        self.max_token_size = min(max_token_size * 1000000, 512 * 1000000)
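
# Illustrative example (not executed): with the default CLI options below, the limits resolve to
#   Configs(1024, 2048, 500)
#   -> max_token_count = 1024 tokens per query
#   -> max_buffer_size = 2,048,000,000 bytes per query
#   -> max_token_size  = 500,000,000 bytes per token (capped at 512,000,000)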

# QueryBuffer is the class that processes input CSVs and emits their binary formats to the Redis client.
class QueryBuffer(object):
    def __init__(self, graphname, client):
        # Redis client and data for each query
        self.client = client

        # Sizes for buffer currently being constructed
        self.redis_token_count = 0
        self.buffer_size = 0

        # The first query should include a "BEGIN" token
        self.graphname = graphname
        self.initial_query = True

        self.node_count = 0
        self.relation_count = 0

        self.labels = []    # List containing all pending Label objects
        self.reltypes = []  # List containing all pending RelationType objects

        self.nodes_created = 0      # Total number of nodes created
        self.relations_created = 0  # Total number of relations created

    # Send all pending inserts to Redis
    def send_buffer(self):
        # Do nothing if we have no entities
        if self.node_count == 0 and self.relation_count == 0:
            return

        args = [self.node_count, self.relation_count, len(self.labels),
                len(self.reltypes)] + self.labels + self.reltypes
        # Prepend a "BEGIN" token if this is the first query
        if self.initial_query:
            args.insert(0, "BEGIN")
            self.initial_query = False

        result = self.client.execute_command("GRAPH.BULK", self.graphname, *args)
        stats = result.split(', '.encode())
        self.nodes_created += int(stats[0].split(' '.encode())[0])
        self.relations_created += int(stats[1].split(' '.encode())[0])
        self.clear_buffer()

    # Delete all entities that have been inserted
    def clear_buffer(self):
        self.redis_token_count = 0
        self.buffer_size = 0
        # All constructed entities have been inserted, so clear buffers
        self.node_count = 0
        self.relation_count = 0
        del self.labels[:]
        del self.reltypes[:]

    def report_completion(self, runtime):
        print("Construction of graph '%s' complete: %d nodes created, %d relations created in %f seconds"
              % (self.graphname, self.nodes_created, self.relations_created, runtime))
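
# Illustrative sketch of what send_buffer() issues for the first batch (counts and blobs are hypothetical):
#   GRAPH.BULK <graphname> BEGIN <node_count> <relation_count> <label_count> <reltype_count> \
#       <label blob 1> ... <reltype blob 1> ...
# where each blob is a binary token built by a Label or RelationType instance below.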

# Superclass for label and relation CSV files
class EntityFile(object):
    def __init__(self, filename):
        # The label or relation type string is the basename of the file
        self.entity_str = os.path.splitext(os.path.basename(filename))[0].encode('utf-8')
        # Input file handling
        self.infile = io.open(filename, 'rt', encoding='utf-8')
        # Initialize CSV reader that ignores leading whitespace in each field
        # and does not modify input quote characters
        self.reader = csv.reader(self.infile, skipinitialspace=True, quoting=csv.QUOTE_NONE)

        self.prop_offset = 0  # Starting index of properties in row
        self.prop_count = 0   # Number of properties per entity

        self.packed_header = b""  # Binary header token (set by pack_header)
        self.binary_entities = []
        self.binary_size = 0   # Size of binary token
        self.count_entities()  # Number of entities/rows in file

    # Count number of rows in file.
    def count_entities(self):
        self.entities_count = 0
        self.entities_count = sum(1 for line in self.infile)
        # Discard header row
        self.entities_count -= 1
        # Seek back to the start of the file
        self.infile.seek(0)
        return self.entities_count

    # Simple input validations for each row of a CSV file
    def validate_row(self, expected_col_count, row):
        # Each row should have the same number of fields
        if len(row) != expected_col_count:
            raise CSVError("%s:%d Expected %d columns, encountered %d ('%s')"
                           % (self.infile.name, self.reader.line_num, expected_col_count, len(row), ','.join(row)))

    # If part of a CSV file was sent to Redis, delete the processed entities and update the binary size
    def reset_partial_binary(self):
        self.binary_entities = []
        self.binary_size = len(self.packed_header)

    # Convert property keys from a CSV file header into a binary string
    def pack_header(self, header):
        prop_count = len(header) - self.prop_offset
        # String format
        fmt = "=%dsI" % (len(self.entity_str) + 1)  # Unaligned native, entity string, count of properties
        args = [self.entity_str, prop_count]
        for p in header[self.prop_offset:]:
            prop = p.encode('utf-8')
            fmt += "%ds" % (len(prop) + 1)  # Encode string with a null terminator
            args.append(prop)
        return struct.pack(fmt, *args)

    # Convert a list of properties into a binary string
    def pack_props(self, line):
        props = []
        for field in line[self.prop_offset:]:
            props.append(prop_to_binary(field))
        return b''.join(p for p in props)

    def to_binary(self):
        return self.packed_header + b''.join(self.binary_entities)
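
# Illustrative sketch (hypothetical file "person.csv" with header "_identifier,name,age"):
#   pack_header(["_identifier", "name", "age"]) with prop_offset == 1 packs roughly
#     struct.pack("=7sI5s4s", b"person", 2, b"name", b"age")
#   i.e. the NUL-terminated entity string, the property count, then each NUL-terminated property key.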

# Handler class for processing label CSV files.
class Label(EntityFile):
    def __init__(self, infile):
        super(Label, self).__init__(infile)
        expected_col_count = self.process_header()
        self.process_entities(expected_col_count)
        self.infile.close()

    def process_header(self):
        # Header format:
        # node identifier (which may be a property key), then all other property keys
        header = next(self.reader)
        expected_col_count = len(header)
        # If the identifier field begins with an underscore, don't add it as a property.
        if header[0][0] == '_':
            self.prop_offset = 1
        self.packed_header = self.pack_header(header)
        self.binary_size += len(self.packed_header)
        return expected_col_count
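
    # Illustrative node CSV (hypothetical "person.csv"); the leading underscore marks a
    # pure identifier column that is not stored as a property:
    #   _identifier,name,age
    #   p1,Alice,33
    #   p2,Bob,41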

    def process_entities(self, expected_col_count):
        global NODE_DICT
        global TOP_NODE_ID
        global QUERY_BUF

        entities_created = 0
        with click.progressbar(self.reader, length=self.entities_count, label=self.entity_str) as reader:
            for row in reader:
                self.validate_row(expected_col_count, row)
                # Add identifier->ID pair to dictionary if we are building relations
                if NODE_DICT is not None:
                    if row[0] in NODE_DICT:
                        print("Node identifier '%s' was used multiple times - second occurrence at %s:%d"
                              % (row[0], self.infile.name, self.reader.line_num))
                        exit(1)
                    NODE_DICT[row[0]] = TOP_NODE_ID
                    TOP_NODE_ID += 1

                row_binary = self.pack_props(row)
                row_binary_len = len(row_binary)
                # If the addition of this entity will make the binary token grow too large,
                # send the buffer now.
                if self.binary_size + row_binary_len > CONFIGS.max_token_size:
                    QUERY_BUF.labels.append(self.to_binary())
                    QUERY_BUF.send_buffer()
                    self.reset_partial_binary()
                    # Push the label onto the query buffer again, as there are more entities to process.
                    QUERY_BUF.labels.append(self.to_binary())

                QUERY_BUF.node_count += 1
                entities_created += 1
                self.binary_size += row_binary_len
                self.binary_entities.append(row_binary)
        QUERY_BUF.labels.append(self.to_binary())
        print("%d nodes created with label '%s'" % (entities_created, self.entity_str))

# Handler class for processing relation CSV files.
class RelationType(EntityFile):
    def __init__(self, infile):
        super(RelationType, self).__init__(infile)
        expected_col_count = self.process_header()
        self.process_entities(expected_col_count)
        self.infile.close()

    def process_header(self):
        # Header format:
        # source identifier, dest identifier, properties[0..n]
        header = next(self.reader)
        # Assume rectangular CSVs
        expected_col_count = len(header)
        self.prop_count = expected_col_count - 2
        if self.prop_count < 0:
            raise CSVError("Relation file '%s' should have at least 2 elements in header line."
                           % (self.infile.name))

        self.prop_offset = 2
        self.packed_header = self.pack_header(header)  # skip src and dest identifiers
        self.binary_size += len(self.packed_header)
        return expected_col_count
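
    # Illustrative relation CSV (hypothetical "knows.csv"); the first two columns are the
    # source and destination node identifiers defined in the node files, and any further
    # columns are relation properties:
    #   src,dest,since
    #   p1,p2,2015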

    def process_entities(self, expected_col_count):
        entities_created = 0
        with click.progressbar(self.reader, length=self.entities_count, label=self.entity_str) as reader:
            for row in reader:
                self.validate_row(expected_col_count, row)
                try:
                    src = NODE_DICT[row[0]]
                    dest = NODE_DICT[row[1]]
                except KeyError as e:
                    print("Relationship specified a non-existent identifier.")
                    raise e
                fmt = "=QQ"  # 8-byte unsigned ints for src and dest
                row_binary = struct.pack(fmt, src, dest) + self.pack_props(row)
                row_binary_len = len(row_binary)
                # If the addition of this entity will make the binary token grow too large,
                # send the buffer now.
                if self.binary_size + row_binary_len > CONFIGS.max_token_size:
                    QUERY_BUF.reltypes.append(self.to_binary())
                    QUERY_BUF.send_buffer()
                    self.reset_partial_binary()
                    # Push the reltype onto the query buffer again, as there are more entities to process.
                    QUERY_BUF.reltypes.append(self.to_binary())

                QUERY_BUF.relation_count += 1
                entities_created += 1
                self.binary_size += row_binary_len
                self.binary_entities.append(row_binary)
        QUERY_BUF.reltypes.append(self.to_binary())
        print("%d relations created for type '%s'" % (entities_created, self.entity_str))

# Convert a single CSV property field into a binary stream.
# Supported property types are string, numeric, boolean, and NULL.
def prop_to_binary(prop_str):
    # All format strings start with an unsigned char to represent our Type enum
    format_str = "=B"
    if not prop_str:
        # An empty field indicates a NULL property
        return struct.pack(format_str, Type.NULL)

    # If the field can be cast to a float, treat it as numeric
    try:
        numeric_prop = float(prop_str)
        return struct.pack(format_str + "d", Type.NUMERIC, numeric_prop)
    except ValueError:
        pass

    # If the field is 'false' or 'true', it is a boolean
    if prop_str.lower() == 'false':
        return struct.pack(format_str + '?', Type.BOOL, False)
    elif prop_str.lower() == 'true':
        return struct.pack(format_str + '?', Type.BOOL, True)

    # If we've reached this point, the property is a string
    # Encoding len+1 adds a null terminator to the string
    encoded_str = prop_str.encode('utf-8')
    format_str += "%ds" % (len(encoded_str) + 1)
    return struct.pack(format_str, Type.STRING, encoded_str)
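
# Illustrative examples of the encoding above (bytes shown are what struct.pack would emit):
#   prop_to_binary("")      -> b'\x00'                      (Type.NULL)
#   prop_to_binary("3.5")   -> b'\x02' + 8-byte double 3.5  (Type.NUMERIC)
#   prop_to_binary("true")  -> b'\x01\x01'                  (Type.BOOL, True)
#   prop_to_binary("hi")    -> b'\x03hi\x00'                (Type.STRING, NUL-terminated)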

# For each node input file, validate contents and convert to binary format.
# If any buffer limits have been reached, flush all enqueued inserts to Redis.
def process_entity_csvs(cls, csvs):
    global QUERY_BUF
    for in_csv in csvs:
        # Build entity descriptor from input CSV
        entity = cls(in_csv)
        added_size = entity.binary_size
        # Check to see if the addition of this data will exceed the buffer's capacity
        if (QUERY_BUF.buffer_size + added_size >= CONFIGS.max_buffer_size
                or QUERY_BUF.redis_token_count + len(entity.binary_entities) >= CONFIGS.max_token_count):
            # Send and flush the buffer if appropriate
            QUERY_BUF.send_buffer()
        # Add binary data to list and update all counts
        QUERY_BUF.redis_token_count += len(entity.binary_entities)
        QUERY_BUF.buffer_size += added_size

# Command-line arguments
@click.command()
@click.argument('graph')
# Redis server connection settings
@click.option('--host', '-h', default='127.0.0.1', help='Redis server host')
@click.option('--port', '-p', default=6379, help='Redis server port')
@click.option('--password', '-a', default=None, help='Redis server password')
# CSV file paths
@click.option('--nodes', '-n', required=True, multiple=True, help='Path to node csv file')
@click.option('--relations', '-r', multiple=True, help='Path to relation csv file')
# Buffer size restrictions
@click.option('--max-token-count', '-c', default=1024,
              help='max number of processed CSVs to send per query (default 1024)')
@click.option('--max-buffer-size', '-b', default=2048, help='max buffer size in megabytes (default 2048)')
@click.option('--max-token-size', '-t', default=500, help='max size of each token in megabytes (default 500, max 512)')
def bulk_insert(graph, host, port, password, nodes, relations, max_token_count, max_buffer_size, max_token_size):
    global CONFIGS
    global NODE_DICT
    global TOP_NODE_ID
    global QUERY_BUF

    TOP_NODE_ID = 0  # Reset global ID variable (in case we are calling bulk_insert from unit tests)
    CONFIGS = Configs(max_token_count, max_buffer_size, max_token_size)

    start_time = timer()
    # Attempt to connect to the Redis server
    try:
        client = redis.StrictRedis(host=host, port=port, password=password)
    except redis.exceptions.ConnectionError as e:
        print("Could not connect to Redis server.")
        raise e

    # Attempt to verify that the RedisGraph module is loaded
    try:
        module_list = client.execute_command("MODULE LIST")
        if not any(b'graph' in module_description for module_description in module_list):
            print("RedisGraph module not loaded on connected server.")
            exit(1)
    except redis.exceptions.ResponseError:
        # Ignore the check if the connected server does not support the "MODULE LIST" command
        pass

    QUERY_BUF = QueryBuffer(graph, client)

    # Create a node dictionary if we're building relations and as such require unique identifiers
    if relations:
        NODE_DICT = {}
    else:
        NODE_DICT = None

    process_entity_csvs(Label, nodes)

    if relations:
        process_entity_csvs(RelationType, relations)

    # Send all remaining tokens to Redis
    QUERY_BUF.send_buffer()

    end_time = timer()
    QUERY_BUF.report_completion(end_time - start_time)


if __name__ == '__main__':
    bulk_insert()
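
# Example invocation (hypothetical file names, assuming a local Redis server with RedisGraph loaded):
#   python bulk_insert.py socialgraph -n person.csv -r knows.csv --host 127.0.0.1 --port 6379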