diff --git a/chipsec/modules/tools/smm/smm_ptr.py b/chipsec/modules/tools/smm/smm_ptr.py index 6cb5ce2e9a..92441113f4 100644 --- a/chipsec/modules/tools/smm/smm_ptr.py +++ b/chipsec/modules/tools/smm/smm_ptr.py @@ -79,8 +79,8 @@ """ -import os import struct +import os import sys import time import math @@ -91,6 +91,7 @@ from chipsec.library.logger import print_buffer_bytes from chipsec.hal.interrupts import Interrupts from chipsec.library.exceptions import BadSMIDetected +from chipsec.helper.oshelper import OsHelper ################################################################# @@ -122,10 +123,10 @@ # Fuzz RCX as SMI subfunctions: from 0 to MAX_SMI_FUNCTIONS # False - better performance, True - smarter fuzzing FUZZ_SMI_FUNCTIONS_IN_ECX = True -MAX_SMI_FUNCTIONS = 0x100 +MAX_SMI_FUNCTIONS = 0x10 # Max value of the value written to SMI data port (0xB3) -MAX_SMI_DATA = 0x10 +MAX_SMI_DATA = 0x100 # Pass the pointer to SMI handlers in all general-purpose registers # rather than in one register @@ -136,7 +137,7 @@ # SMI handler may take a pointer/PA from (some offset of off) address passed in GPRs and write to it # Treat contents at physical address passed in GPRs as pointers and check contents at that pointer # If they changed, SMI handler might have modified them -#MODE_SECOND_ORDER_BUFFER = True +# MODE_SECOND_ORDER_BUFFER = True # Max offset of the pointer (physical address) # of the 2nd order buffer written in the memory buffer passed to SMI @@ -145,16 +146,14 @@ # very obscure option, don't even try to understand GPR_2ADDR = False -# Defines the time percentage increase at which the SMI call is considered to -# be long-running -OUTLIER_THRESHOLD = 10 +# Defines the threshold in standard deviations at which the SMI call is +# considered long-running OUTLIER_STD_DEV = 2 -SCAN_CALIB_SAMPLES = 50 -# Scan mode delay before SMI calls -SCAN_MODE_DELAY = 0.01 +# Number of samples used for initial calibration +SCAN_CALIB_SAMPLES = 50 -# MSR numbers +# SMI count MSR MSR_SMI_COUNT = 0x00000034 # @@ -197,47 +196,84 @@ def get_info(self): if self.code is None: return None else: - return f"duration {self.duration} code {self.code:02X} data {self.data:02X} ({gprs_info(self.gprs)})" + return f'duration {self.duration} code {self.code:02X} data {self.data:02X} ({gprs_info(self.gprs)})' -class scan_track: +class smi_stats: def __init__(self): self.clear() - self.hist_smi_duration = 0 - self.hist_smi_num = 0 - self.outliers_hist = 0 - self.records = {'deltas': [], 'times': []} - self.msr_count = self.get_msr_count() - self.stdev = 0 - self.stdev_hist = 0 + + def clear(self): + self.count = 0 + self.mean = 0 self.m2 = 0 - self.m2_hist = 0 - self.needs_calibration = True - self.calib_samples = 0 - self.first_measurement = True + self.stdev = 0 + self.outliers = 0 + + # + # Computes the standard deviation using the Welford's online algorithm + # + def update_stats(self, duration): + self.count += 1 + difference = duration - self.mean + self.mean += difference / self.count + self.m2 += difference * (duration - self.mean) + variance = self.m2 / self.count + self.stdev = math.sqrt(variance) + + def get_info(self): + info = f'average {round(self.mean)} stddev {self.stdev:.2f} checked {self.count}' + return info + + # + # Combines the statistics of the two data sets using parallel variance computation + # + def combine(self, partial): + self.outliers += partial.outliers + total_count = self.count + partial.count + difference = partial.mean - self.mean + self.mean = (self.mean * self.count + partial.mean * partial.count) / total_count + self.m2 += partial.m2 + difference**2 * self.count * partial.count / total_count + self.count = total_count + variance = self.m2 / self.count + self.stdev = math.sqrt(variance) - def get_msr_count(self): + +class scan_track: + def __init__(self): + self.current_smi_stats = smi_stats() + self.combined_smi_stats = smi_stats() + self.clear() + self.helper = OsHelper().get_default_helper() + self.helper.init() + self.smi_count = self.get_smi_count() + + def __del__(self): + self.helper.close() + + def get_smi_count(self): + count = -1 + # + # The SMI count is the same on all CPUs + # cpu = 0 - fd = os.open(f"/dev/cpu/{cpu}/msr", os.O_RDONLY) - os.lseek(fd, MSR_SMI_COUNT, os.SEEK_SET) - count = struct.unpack('Q', os.read(fd, 8))[0] - os.close(fd) + try: + count = self.helper.read_msr(cpu, MSR_SMI_COUNT) + count = count[1] << 32 | count[0] + except UnimplementedAPIError: + pass return count - def is_first_measurement(self): - is_first = self.first_measurement - if self.first_measurement: - self.first_measurement = False - return is_first - - def check_inc_msr(self): + def valid_smi_count(self): valid = False - count = self.get_msr_count() - if (count == self.msr_count + 1): + count = self.get_smi_count() + if (count == -1): + return True + elif (count == self.smi_count + 1): valid = True - self.msr_count = count + self.smi_count = count if not valid: - print("SMI contention detected", file=sys.stderr) + sys.stderr.write('SMI contention detected') return valid def find_address_in_regs(self, gprs): @@ -248,75 +284,51 @@ def find_address_in_regs(self, gprs): return key def clear(self): - self.max = smi_info(0) - self.min = smi_info(2**32-1) self.outlier = smi_info(0) - self.avg_smi_duration = 0 - self.avg_smi_num = 0 - self.outliers = 0 self.code = None - self.confirmed = False - self.records = {'deltas': [], 'times': []} - self.stdev = 0 - self.m2 = 0 + self.contents_changed = False self.needs_calibration = True self.calib_samples = 0 - self.first_measurement = True + self.current_smi_stats.clear() + self.records = {'deltas': [], 'times': []} - def add(self, duration, time, code, data, gprs, confirmed=False): + def add(self, duration, time, code, data, gprs, contents_changed=False): if not self.code: self.code = code outlier = self.is_outlier(duration) self.records['deltas'].append(duration) self.records['times'].append(time) - self.update_stdev(duration) if not outlier: - if duration > self.max.duration: - self.max.update(duration, code, data, gprs.copy()) - elif duration < self.min.duration: - self.min.update(duration, code, data, gprs.copy()) + self.current_smi_stats.update_stats(duration) elif self.is_slow_outlier(duration): - self.outliers += 1 - self.outliers_hist += 1 + self.current_smi_stats.outliers += 1 self.outlier.update(duration, code, data, gprs.copy()) - self.confirmed = confirmed - - def update_stdev(self, duration): - self.avg_smi_num += 1 - self.hist_smi_num += 1 - difference = duration - self.avg_smi_duration - difference_hist = duration - self.hist_smi_duration - self.avg_smi_duration += difference / self.avg_smi_num - self.hist_smi_duration += difference_hist / self.hist_smi_num - self.m2 += difference * (duration - self.avg_smi_duration) - self.m2_hist += difference_hist * (duration - self.hist_smi_duration) - variance = self.m2 / self.avg_smi_num - variance_hist = self.m2_hist / self.hist_smi_num - self.stdev = math.sqrt(variance) - self.stdev_hist = math.sqrt(variance_hist) + self.contents_changed = contents_changed def update_calibration(self, duration): if not self.needs_calibration: return - self.update_stdev(duration) + self.current_smi_stats.update_stats(duration) self.calib_samples += 1 if self.calib_samples >= SCAN_CALIB_SAMPLES: self.needs_calibration = False - print(f"Calibration done. stdev: {self.stdev}, mean: {self.avg_smi_duration}, samples: {self.calib_samples}") + print(f'Calibration done. stdev: {self.stdev}, mean: {self.avg_smi_duration}, samples: {self.calib_samples}') def is_slow_outlier(self, value): ret = False - if value > self.avg_smi_duration + OUTLIER_STD_DEV * self.stdev: + if value > self.current_smi_stats.mean + OUTLIER_STD_DEV * self.current_smi_stats.stdev: ret = True - if value > self.hist_smi_duration + OUTLIER_STD_DEV * self.stdev_hist: + if self.combined_smi_stats.count and \ + value > self.combined_smi_stats.mean + OUTLIER_STD_DEV * self.combined_smi_stats.stdev: ret = True return ret def is_fast_outlier(self, value): ret = False - if value < self.avg_smi_duration - OUTLIER_STD_DEV * self.stdev: + if value < self.current_smi_stats.mean - OUTLIER_STD_DEV * self.current_smi_stats.stdev: ret = True - if value < self.hist_smi_duration - OUTLIER_STD_DEV * self.stdev_hist: + if self.combined_smi_stats.count and \ + value < self.combined_smi_stats.mean - OUTLIER_STD_DEV * self.combined_smi_stats.stdev: ret = True return ret @@ -331,21 +343,20 @@ def is_outlier(self, value): return ret def skip(self): - #return self.outliers or self.confirmed + #return self.current_smi_stats.outliers or self.contents_changed return False def found_outlier(self): - return bool(self.outliers) + return bool(self.current_smi_stats.outliers) def get_total_outliers(self): - return self.outliers_hist + return self.combined_smi_stats.outliers def get_info(self): - avg = self.avg_smi_duration or self.hist_smi_duration - info = f"average {round(avg)} stdev {self.stdev} checked {self.avg_smi_num + self.outliers}" - if self.outliers: - info += f"\n Identified outlier: {self.outlier.get_info()}" - info += f"\nDeltas: {self.records}" + info = self.current_smi_stats.get_info() + if self.current_smi_stats.outliers: + info += f'\n Identified outlier: {self.outlier.get_info()}' + info += f'\nDeltas: {self.records}' return info def log_smi_result(self, logger): @@ -355,6 +366,9 @@ def log_smi_result(self, logger): else: logger.log(f'[*] {msg}') + def update_combined_stats(self): + self.combined_smi_stats.combine(self.current_smi_stats) + class smi_desc: def __init__(self): @@ -452,7 +466,7 @@ def check_memory(self, _addr, _smi_desc, fn, restore_contents=False): # Check if contents have changed at physical address passed in GPRs to SMI handler # If changed, SMI handler might have written to that address # - self.logger.log(" < Checking buffers") + self.logger.log(' < Checking buffers') expected_buf = FILL_BUFFER(self.fill_byte, self.fill_size, _smi_desc.ptr_in_buffer, _smi_desc.ptr, _smi_desc.ptr_offset, _smi_desc.sig, _smi_desc.sig_offset) buf = self.cs.mem.read_physical_mem(_addr, self.fill_size) @@ -526,11 +540,8 @@ def smi_fuzz_iter(self, thread_id, _addr, _smi_desc, fill_contents=True, restore self.send_smi(thread_id, _smi_desc.smi_code, _smi_desc.smi_data, _smi_desc.name, _smi_desc.desc, _rax, _rbx, _rcx, _rdx, _rsi, _rdi) else: while True: - #time.sleep(SCAN_MODE_DELAY) _, duration, start = self.send_smi_timed(thread_id, _smi_desc.smi_code, _smi_desc.smi_data, _smi_desc.name, _smi_desc.desc, _rax, _rbx, _rcx, _rdx, _rsi, _rdi) - #if scan.is_first_measurement(): - # continue - if not scan.check_inc_msr(): + if not scan.valid_smi_count(): continue if scan.needs_calibration: scan.update_calibration(duration) @@ -542,13 +553,8 @@ def smi_fuzz_iter(self, thread_id, _addr, _smi_desc, fill_contents=True, restore # if scan.is_outlier(duration): while True: - #print("Retrying...") - time.sleep(SCAN_MODE_DELAY) _, duration, start = self.send_smi_timed(thread_id, _smi_desc.smi_code, _smi_desc.smi_data, _smi_desc.name, _smi_desc.desc, _rax, _rbx, _rcx, _rdx, _rsi, _rdi) - if scan.is_outlier(duration): - print(f"Found outlier. Duration: {duration}, start: {start}") - #print(duration) - if scan.check_inc_msr(): + if scan.valid_smi_count(): break # # Check memory buffer if not in 'No Fill' mode @@ -622,9 +628,9 @@ def test_fuzz(self, thread_id, smic_start, smic_end, _addr, _addr1, scan_mode=Fa gprs_addr = {'rax': gpr_value, 'rbx': gpr_value, 'rcx': gpr_value, 'rdx': gpr_value, 'rsi': gpr_value, 'rdi': gpr_value} gprs_fill = {'rax': _FILL_VALUE_QWORD, 'rbx': _FILL_VALUE_QWORD, 'rcx': _FILL_VALUE_QWORD, 'rdx': _FILL_VALUE_QWORD, 'rsi': _FILL_VALUE_QWORD, 'rdi': _FILL_VALUE_QWORD} - self.logger.log("\n[*] >>> Fuzzing SMI handlers..") - self.logger.log("[*] AX in RAX will be overridden with values of SW SMI ports 0xB2/0xB3") - self.logger.log(" DX in RDX will be overridden with value 0x00B2") + self.logger.log('\n[*] >>> Fuzzing SMI handlers..') + self.logger.log('[*] AX in RAX will be overridden with values of SW SMI ports 0xB2/0xB3') + self.logger.log(' DX in RDX will be overridden with value 0x00B2') bad_ptr_cnt = 0 _smi_desc = smi_desc() @@ -711,21 +717,22 @@ def test_fuzz(self, thread_id, smic_start, smic_end, _addr, _addr1, scan_mode=Fa break if scan_mode: scan.log_smi_result(self.logger) + scan.update_combined_stats() scan.clear() return bad_ptr_cnt, scan def run(self, module_argv): - self.logger.start_test("A tool to test SMI handlers for pointer validation vulnerabilities") - self.logger.log("Usage: chipsec_main -m tools.smm.smm_ptr [ -a ,|,,
]") - self.logger.log(" mode SMI handlers testing mode") - self.logger.log(" = config use SMI configuration file ") - self.logger.log(" = fuzz fuzz all SMI handlers with code in the range ") - self.logger.log(" = fuzzmore fuzz mode + pass '2nd-order' pointers within buffer to SMI handlers") - self.logger.log(" = scan fuzz mode + time measurement to identify SMIs that trigger long-running code paths") - self.logger.log(" size size of the memory buffer (in Hex)") - self.logger.log(" address physical address of memory buffer to pass in GP regs to SMI handlers (in Hex)") - self.logger.log(" = smram pass address of SMRAM base (system may hang in this mode!)\n") + self.logger.start_test('A tool to test SMI handlers for pointer validation vulnerabilities') + self.logger.log('Usage: chipsec_main -m tools.smm.smm_ptr [ -a ,|,,
]') + self.logger.log(' mode SMI handlers testing mode') + self.logger.log(' = config use SMI configuration file ') + self.logger.log(' = fuzz fuzz all SMI handlers with code in the range ') + self.logger.log(' = fuzzmore fuzz mode + pass `2nd-order` pointers within buffer to SMI handlers') + self.logger.log(' = scan fuzz mode + time measurement to identify SMIs that trigger long-running code paths') + self.logger.log(' size size of the memory buffer (in Hex)') + self.logger.log(' address physical address of memory buffer to pass in GP regs to SMI handlers (in Hex)') + self.logger.log(' = smram pass address of SMRAM base (system may hang in this mode!)\n') test_mode = 'config' _smi_config_fname = 'chipsec/modules/tools/smm/smm_config.ini' @@ -772,9 +779,7 @@ def run(self, module_argv): (_, _addr1) = self.cs.mem.alloc_physical_mem(self.fill_size, _MAX_ALLOC_PA) self.logger.log(f'[*] Allocated 2nd buffer (address will be in the 1st buffer): 0x{_addr1:016X}') - # # @TODO: Need to check that SW/APMC SMI is enabled - # self.logger.log('\n[*] Configuration:') self.logger.log(f' SMI testing mode : {test_mode}') @@ -788,7 +793,7 @@ def run(self, module_argv): self.logger.log(f' Second buffer pointer : 0x{_addr1:016X} (address written to memory buffer)') self.logger.log(f' Number of bytes to fill : 0x{self.fill_size:X}') self.logger.log(f' Byte to fill with : 0x{ord(self.fill_byte):X}') - self.logger.log(f' Additional options (can be changed in the source code):f') + self.logger.log(' Additional options (can be changed in the source code):') self.logger.log(f' Fuzzing SMI functions in ECX? : {FUZZ_SMI_FUNCTIONS_IN_ECX:d}') self.logger.log(f' Max value of SMI function in ECX : 0x{MAX_SMI_FUNCTIONS:X}') self.logger.log(f' Max value of SMI data (B3) : 0x{MAX_SMI_DATA:X}') @@ -813,9 +818,9 @@ def run(self, module_argv): scan_mode = True scan = None bad_ptr_cnt, scan = self.test_fuzz(thread_id, smic_start, smic_end, _addr, _addr1, True) - except BadSMIDetected as msg: + except BadSMIDetected: bad_ptr_cnt = 1 - self.logger.log_important("Potentially bad SMI detected! Stopped fuzing (see FUZZ_BAIL_ON_1ST_DETECT option)") + self.logger.log_important('Potentially bad SMI detected! Stopped fuzing (see FUZZ_BAIL_ON_1ST_DETECT option)') if scan_mode and scan: self.logger.log_good(f'<<< Done: found {scan.get_total_outliers()} long-running SMIs') @@ -824,7 +829,7 @@ def run(self, module_argv): self.result.setStatusBit(self.result.status.POTENTIALLY_VULNERABLE) self.res = self.result.getReturnCode(ModuleResult.FAILED) else: - self.logger.log_good("<<< Done: didn't find unchecked input pointers in tested SMI handlers") + self.logger.log_good('<<< Done: did not find unchecked input pointers in tested SMI handlers') self.result.setStatusBit(self.result.status.SUCCESS) self.res = self.result.getReturnCode(ModuleResult.PASSED)