Skip to content

Commit 3a6ff04

Browse files
authored
Merge pull request #622 from splunk/SDK-23/fix-map-reporting-command
SDK-23: check if developer added custom map method
2 parents ce732fc + 42bf5e1 commit 3a6ff04

File tree

3 files changed

+53
-9
lines changed

3 files changed

+53
-9
lines changed

splunklib/searchcommands/internals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,7 @@ def write_record(self, record):
554554

555555
def write_records(self, records):
556556
self._ensure_validity()
557-
records = list(records)
557+
records = [] if records is NotImplemented else list(records)
558558
write_record = self._write_record
559559
for record in records:
560560
write_record(record)

splunklib/searchcommands/reporting_command.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -77,21 +77,26 @@ def map(self, records):
7777
"""
7878
return NotImplemented
7979

80-
def prepare(self):
81-
82-
phase = self.phase
80+
def _has_custom_method(self, method_name):
81+
method = getattr(self.__class__, method_name, None)
82+
base_method = getattr(ReportingCommand, method_name, None)
83+
return callable(method) and (method is not base_method)
8384

84-
if phase == 'map':
85-
# noinspection PyUnresolvedReferences
86-
self._configuration = self.map.ConfigurationSettings(self)
85+
def prepare(self):
86+
if self.phase == 'map':
87+
if self._has_custom_method('map'):
88+
phase_method = getattr(self.__class__, 'map')
89+
self._configuration = phase_method.ConfigurationSettings(self)
90+
else:
91+
self._configuration = self.ConfigurationSettings(self)
8792
return
8893

89-
if phase == 'reduce':
94+
if self.phase == 'reduce':
9095
streaming_preop = chain((self.name, 'phase="map"', str(self._options)), self.fieldnames)
9196
self._configuration.streaming_preop = ' '.join(streaming_preop)
9297
return
9398

94-
raise RuntimeError(f'Unrecognized reporting command phase: {json_encode_string(str(phase))}')
99+
raise RuntimeError(f'Unrecognized reporting command phase: {json_encode_string(str(self.phase))}')
95100

96101
def reduce(self, records):
97102
""" Override this method to produce a reporting data structure.

tests/searchcommands/test_reporting_command.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,42 @@ def reduce(self, records):
3232
data = list(data_chunk.data)
3333
assert len(data) == 1
3434
assert int(data[0]['sum']) == sum(range(0, 10))
35+
36+
37+
def test_simple_reporting_command_with_map():
38+
@searchcommands.Configuration()
39+
class MapAndReduceReportingCommand(searchcommands.ReportingCommand):
40+
def map(self, records):
41+
for record in records:
42+
record["value"] = str(int(record["value"]) * 2)
43+
yield record
44+
45+
def reduce(self, records):
46+
total = 0
47+
for record in records:
48+
total += int(record["value"])
49+
yield {"sum": total}
50+
51+
cmd = MapAndReduceReportingCommand()
52+
ifile = io.BytesIO()
53+
54+
input_data = [{"value": str(i)} for i in range(5)]
55+
56+
mapped_data = list(cmd.map(input_data))
57+
58+
ifile.write(chunky.build_getinfo_chunk())
59+
ifile.write(chunky.build_data_chunk(mapped_data))
60+
ifile.seek(0)
61+
62+
ofile = io.BytesIO()
63+
cmd._process_protocol_v2([], ifile, ofile)
64+
65+
ofile.seek(0)
66+
chunk_stream = chunky.ChunkedDataStream(ofile)
67+
chunk_stream.read_chunk()
68+
data_chunk = chunk_stream.read_chunk()
69+
assert data_chunk.meta['finished'] is True
70+
71+
result = list(data_chunk.data)
72+
expected_sum = sum(i * 2 for i in range(5))
73+
assert int(result[0]["sum"]) == expected_sum

0 commit comments

Comments
 (0)