Skip to content

Commit 74b1b37

Browse files
committed
SDK-23: check if developer added custom map method
1 parent d36db5e commit 74b1b37

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

splunklib/searchcommands/internals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,7 @@ def write_record(self, record):
554554

555555
def write_records(self, records):
556556
self._ensure_validity()
557-
records = list(records)
557+
records = [] if records is NotImplemented else list(records)
558558
write_record = self._write_record
559559
for record in records:
560560
write_record(record)

splunklib/searchcommands/reporting_command.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -77,21 +77,26 @@ def map(self, records):
7777
"""
7878
return NotImplemented
7979

80-
def prepare(self):
81-
82-
phase = self.phase
80+
def _has_custom_method(self, method_name):
81+
method = getattr(self.__class__, method_name, None)
82+
base_method = getattr(ReportingCommand, method_name, None)
83+
return callable(method) and (method is not base_method)
8384

84-
if phase == 'map':
85-
# noinspection PyUnresolvedReferences
86-
self._configuration = self.map.ConfigurationSettings(self)
85+
def prepare(self):
86+
if self.phase == 'map':
87+
if self._has_custom_method('map'):
88+
phase_method = getattr(self.__class__, 'map')
89+
self._configuration = phase_method.ConfigurationSettings(self)
90+
else:
91+
self._configuration = self.ConfigurationSettings(self)
8792
return
8893

89-
if phase == 'reduce':
94+
if self.phase == 'reduce':
9095
streaming_preop = chain((self.name, 'phase="map"', str(self._options)), self.fieldnames)
9196
self._configuration.streaming_preop = ' '.join(streaming_preop)
9297
return
9398

94-
raise RuntimeError(f'Unrecognized reporting command phase: {json_encode_string(str(phase))}')
99+
raise RuntimeError(f'Unrecognized reporting command phase: {json_encode_string(str(self.phase))}')
95100

96101
def reduce(self, records):
97102
""" Override this method to produce a reporting data structure.

0 commit comments

Comments
 (0)