diff --git a/xls/modules/zstd/BUILD b/xls/modules/zstd/BUILD index e91298e554..52c4e61ed8 100644 --- a/xls/modules/zstd/BUILD +++ b/xls/modules/zstd/BUILD @@ -21,11 +21,9 @@ load( "//xls/build_rules:xls_build_defs.bzl", "xls_benchmark_ir", "xls_benchmark_verilog", - "xls_dslx_ir", "xls_dslx_library", "xls_dslx_test", "xls_dslx_verilog", - "xls_ir_opt_ir", ) package( @@ -774,3 +772,126 @@ place_and_route( synthesized_rtl = ":repacketizer_synth_asap7", target_die_utilization_percentage = "10", ) + +xls_dslx_library( + name = "zstd_dec_dslx", + srcs = [ + "zstd_dec.x", + ], + deps = [ + ":block_dec_dslx", + ":block_header_dslx", + ":buffer_dslx", + ":common_dslx", + ":frame_header_dslx", + ":frame_header_test_dslx", + ":magic_dslx", + ":ram_printer_dslx", + ":repacketizer_dslx", + ":sequence_executor_dslx", + "//xls/examples:ram_dslx", + ], +) + +xls_dslx_verilog( + name = "zstd_dec_verilog", + codegen_args = { + "module_name": "ZstdDecoder", + "generator": "pipeline", + "delay_model": "asap7", + "ram_configurations": ",".join([ + "{ram_name}:1RW:{req}:{resp}:{wr_comp}:{latency}".format( + latency = 5, + ram_name = "ram{}".format(num), + req = "zstd_dec__req{}_s".format(num), + resp = "zstd_dec__resp{}_r".format(num), + wr_comp = "zstd_dec__wr_comp{}_r".format(num), + ) + for num in range(7) + ]), + "pipeline_stages": "10", + "reset": "rst", + "reset_data_path": "true", + "reset_active_low": "false", + "reset_asynchronous": "true", + "flop_inputs": "false", + "flop_single_value_channels": "false", + "flop_outputs": "false", + "worst_case_throughput": "1", + "use_system_verilog": "false", + }, + dslx_top = "ZstdDecoder", + library = ":zstd_dec_dslx", + # TODO: 2024-01-15: Workaround for https://github.com/google/xls/issues/869 + # Force proc inlining for IR optimization + opt_ir_args = { + "inline_procs": "true", + }, + verilog_file = "zstd_dec.v", +) + +cc_test( + name = "zstd_dec_cc_test", + srcs = [ + "zstd_dec_test.cc", + ], + data = [ + ":zstd_dec_verilog.ir", + ], + #shard_count = 50, + deps = [ + ":data_generator", + "//xls/common:xls_gunit_main", + "//xls/common/file:filesystem", + "//xls/common/file:get_runfile_path", + "//xls/common/status:matchers", + "//xls/interpreter:interpreter_proc_runtime", + "//xls/ir:events", + "//xls/ir:ir_parser", + "//xls/ir:value", + "@com_github_facebook_zstd//:zstd", + "@com_google_googletest//:gtest", + ], +) + +xls_benchmark_ir( + name = "zstd_dec_opt_ir_benchmark", + src = ":zstd_dec_verilog.opt.ir", + benchmark_ir_args = { + #TODO: rewrite ram in opt_ir step to perform valid IR benchmark + "pipeline_stages": "1", + "delay_model": "asap7", + }, +) + +verilog_library( + name = "zstd_dec_verilog_lib", + srcs = [ + ":zstd_dec.v", + ], +) + +synthesize_rtl( + name = "zstd_dec_synth_asap7", + standard_cells = "@org_theopenroadproject_asap7sc7p5t_28//:asap7-sc7p5t_rev28_rvt", + top_module = "ZstdDecoder", + deps = [ + ":zstd_dec_verilog_lib", + ], +) + +benchmark_synth( + name = "zstd_dec_benchmark_synth", + synth_target = ":zstd_dec_synth_asap7", +) + +place_and_route( + name = "zstd_dec_place_and_route", + clock_period = "750", + core_padding_microns = 2, + min_pin_distance = "0.5", + placement_density = "0.30", + skip_detailed_routing = True, + synthesized_rtl = ":zstd_dec_synth_asap7", + target_die_utilization_percentage = "10", +) diff --git a/xls/modules/zstd/common.x b/xls/modules/zstd/common.x index e15b311ae0..873254cbdf 100644 --- a/xls/modules/zstd/common.x +++ b/xls/modules/zstd/common.x @@ -19,6 +19,7 @@ pub const BLOCK_SIZE_WIDTH = u32:21; pub const HISTORY_BUFFER_SIZE_KB = u32:64; pub const OFFSET_WIDTH = u32:22; pub const LENGTH_WIDTH = u32:22; +pub const BUFFER_WIDTH = u32:128; pub type BlockData = bits[DATA_WIDTH]; pub type BlockPacketLength = u32; diff --git a/xls/modules/zstd/data_generator.cc b/xls/modules/zstd/data_generator.cc index 00cbb6b91b..75f18552f9 100644 --- a/xls/modules/zstd/data_generator.cc +++ b/xls/modules/zstd/data_generator.cc @@ -107,4 +107,22 @@ absl::StatusOr> GenerateFrameHeader(int seed, bool magic) { return raw_data; } +absl::StatusOr> GenerateFrame(int seed, BlockType btype) { + std::vector args; + args.push_back("-s" + std::to_string(seed)); + std::filesystem::path output_path = + std::filesystem::temp_directory_path() / + std::filesystem::path( + CreateNameForGeneratedFile(absl::MakeSpan(args), ".zstd", "fh")); + args.push_back("-p" + std::string(output_path)); + if (btype != BlockType::RANDOM) + args.push_back("--block-type=" + std::to_string(btype)); + if (btype == BlockType::RLE) args.push_back("--content-size"); + + XLS_ASSIGN_OR_RETURN(auto result, CallDecodecorpus(args)); + auto raw_data = ReadFileAsRawData(output_path); + std::remove(output_path.c_str()); + return raw_data; +} + } // namespace xls::zstd diff --git a/xls/modules/zstd/data_generator.h b/xls/modules/zstd/data_generator.h index 06462f4872..feb7c14b83 100644 --- a/xls/modules/zstd/data_generator.h +++ b/xls/modules/zstd/data_generator.h @@ -35,7 +35,15 @@ namespace xls::zstd { +enum BlockType { + RAW, + RLE, + COMPRESSED, + RANDOM, +}; + absl::StatusOr> GenerateFrameHeader(int seed, bool magic); +absl::StatusOr> GenerateFrame(int seed, BlockType btype); } // namespace xls::zstd diff --git a/xls/modules/zstd/dec_demux.x b/xls/modules/zstd/dec_demux.x index 63c0547a59..fbf5eff816 100644 --- a/xls/modules/zstd/dec_demux.x +++ b/xls/modules/zstd/dec_demux.x @@ -109,6 +109,7 @@ pub proc DecoderDemux { )} next (tok: token, state: DecoderDemuxState) { + trace_fmt!("DecDemux: next: state: {:#x}", state); let (tok, data) = recv_if(tok, input_r, !state.last_packet.last, ZERO_DATA); let (send_raw, send_rle, send_cmp, new_state) = match state.status { DecoderDemuxStatus::IDLE => @@ -166,7 +167,7 @@ pub proc DecoderDemux { }; let tok = send_if(tok, rle_s, send_rle, rle_data); let tok = send_if(tok, cmp_s, send_cmp, data_to_send); - if (new_state.send_data == new_state.byte_to_pass) { + let end_state = if (new_state.send_data == new_state.byte_to_pass) { let next_id = if (state.last_packet.last && state.last_packet.last_block) { u32: 0 } else { @@ -181,7 +182,10 @@ pub proc DecoderDemux { } } else { new_state - } + }; + trace_fmt!("DecDemux: next: end_state: {:#x}", end_state); + + end_state } } diff --git a/xls/modules/zstd/dec_mux.x b/xls/modules/zstd/dec_mux.x index 7e000d264b..4a56cb5c6e 100644 --- a/xls/modules/zstd/dec_mux.x +++ b/xls/modules/zstd/dec_mux.x @@ -62,6 +62,7 @@ pub proc DecoderMux { ) {(raw_r, rle_r, cmp_r, output_s)} next (tok: token, state: DecoderMuxState) { + trace_fmt!("DecMux: next: state: {:#x}", state); let (tok, raw_data, raw_data_valid) = recv_if_non_blocking( tok, raw_r, !state.raw_data_valid, zero!()); let state = if (raw_data_valid) { @@ -159,6 +160,8 @@ pub proc DecoderMux { if (do_send) { trace_fmt!("sent {:#x}", data_to_send); } else {()}; + trace_fmt!("DecMux: next: end_state: {:#x}", state); + state } } diff --git a/xls/modules/zstd/rle_block_dec.x b/xls/modules/zstd/rle_block_dec.x index 941073fd21..3e4d414d72 100644 --- a/xls/modules/zstd/rle_block_dec.x +++ b/xls/modules/zstd/rle_block_dec.x @@ -186,6 +186,7 @@ proc BatchPacker { init { (BatchState { prev_last: true, ..ZERO_BATCH_STATE }) } next(tok: token, state: BatchState) { + trace_fmt!("BatchPacker: next: state: {:#x}", state); let (tok, decoded_data) = recv(tok, rle_data_r); let symbols_in_batch = state.symbols_in_batch; @@ -218,13 +219,16 @@ proc BatchPacker { let new_symbols_in_batch = if do_send_batch { BlockPacketLength:0 } else { symbols_in_batch }; let new_batch = if do_send_batch { BlockData:0 } else { batch }; - BatchState { + let end_state = BatchState { batch: new_batch, symbols_in_batch: new_symbols_in_batch, prev_last: decoded_data.last, prev_last_block: sync_data.last_block, prev_id: sync_data.id - } + }; + trace_fmt!("BatchPacker: next: end_state: {:#x}", end_state); + + end_state } } diff --git a/xls/modules/zstd/zstd_dec.x b/xls/modules/zstd/zstd_dec.x new file mode 100644 index 0000000000..a7ee7f5948 --- /dev/null +++ b/xls/modules/zstd/zstd_dec.x @@ -0,0 +1,356 @@ +// Copyright 2023 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file contains work-in-progress ZSTD decoder implementation +// More information about ZSTD decoding can be found in: +// https://datatracker.ietf.org/doc/html/rfc8878 + +import std; +import xls.modules.zstd.block_header; +import xls.modules.zstd.block_dec; +import xls.modules.zstd.sequence_executor; +import xls.modules.zstd.buffer as buff; +import xls.modules.zstd.common; +import xls.modules.zstd.frame_header; +import xls.modules.zstd.frame_header_test; +import xls.modules.zstd.magic; +import xls.modules.zstd.repacketizer; +import xls.examples.ram; + +type Buffer = buff::Buffer; +type BlockDataPacket = common::BlockDataPacket; +type BlockData = common::BlockData; +type BlockSize = common::BlockSize; +type SequenceExecutorPacket = common::SequenceExecutorPacket; +type ZstdDecodedPacket = common::ZstdDecodedPacket; + +// TODO: all of this porboably should be in common.x +const TEST_WINDOW_LOG_MAX_LIBZSTD = frame_header_test::TEST_WINDOW_LOG_MAX_LIBZSTD; +type RWRamReq = sequence_executor::RWRamReq; +type RWRamResp = sequence_executor::RWRamResp; + +const BUFFER_WIDTH = common::BUFFER_WIDTH; +const DATA_WIDTH = common::DATA_WIDTH; +const ZERO_FRAME_HEADER = frame_header::ZERO_FRAME_HEADER; +const ZERO_BLOCK_HEADER = block_header::ZERO_BLOCK_HEADER; + +enum ZstdDecoderStatus : u8 { + DECODE_MAGIC_NUMBER = 0, + DECODE_FRAME_HEADER = 1, + DECODE_BLOCK_HEADER = 2, + FEED_BLOCK_DECODER = 3, + DECODE_CHECKSUM = 4, + ERROR = 255, +} + +struct ZstdDecoderState { + status: ZstdDecoderStatus, + buffer: Buffer, + frame_header: frame_header::FrameHeader, + block_size_bytes: BlockSize, + last: bool, + bytes_sent: BlockSize, +} + +const ZERO_DECODER_STATE = zero!(); + +fn decode_magic_number(state: ZstdDecoderState) -> (bool, BlockDataPacket, ZstdDecoderState) { + trace_fmt!("zstd_dec: decode_magic_number: DECODING NEW FRAME"); + trace_fmt!("zstd_dec: decode_magic_number: state: {:#x}", state); + trace_fmt!("zstd_dec: decode_magic_number: Decoding magic number"); + let magic_result = magic::parse_magic_number(state.buffer); + trace_fmt!("zstd_dec: decode_magic_number: magic_result: {:#x}", magic_result); + let new_state = match magic_result.status { + magic::MagicStatus::OK => ZstdDecoderState { + status: ZstdDecoderStatus::DECODE_FRAME_HEADER, + buffer: magic_result.buffer, + ..state + }, + magic::MagicStatus::CORRUPTED => ZstdDecoderState { + status: ZstdDecoderStatus::ERROR, + ..ZERO_DECODER_STATE + }, + magic::MagicStatus::NO_ENOUGH_DATA => state, + _ => state, + }; + trace_fmt!("zstd_dec: decode_magic_number: new_state: {:#x}", new_state); + + (false, zero!(), new_state) +} + +fn decode_frame_header(state: ZstdDecoderState) -> (bool, BlockDataPacket, ZstdDecoderState) { + trace_fmt!("zstd_dec: decode_frame_header: DECODING FRAME HEADER"); + trace_fmt!("zstd_dec: decode_frame_header: state: {:#x}", state); + let frame_header_result = frame_header::parse_frame_header(state.buffer); + trace_fmt!("zstd_dec: decode_frame_header: frame_header_result: {:#x}", frame_header_result); + let new_state = match frame_header_result.status { + frame_header::FrameHeaderStatus::OK => ZstdDecoderState { + status: ZstdDecoderStatus::DECODE_BLOCK_HEADER, + buffer: frame_header_result.buffer, + frame_header: frame_header_result.header, + ..state + }, + frame_header::FrameHeaderStatus::CORRUPTED => ZstdDecoderState { + status: ZstdDecoderStatus::ERROR, + ..ZERO_DECODER_STATE + }, + frame_header::FrameHeaderStatus::NO_ENOUGH_DATA => state, + frame_header::FrameHeaderStatus::UNSUPPORTED_WINDOW_SIZE => ZstdDecoderState { + status: ZstdDecoderStatus::ERROR, + ..ZERO_DECODER_STATE + }, + _ => state, + }; + trace_fmt!("zstd_dec: decode_frame_header: new_state: {:#x}", new_state); + + (false, zero!(), new_state) +} + +fn decode_block_header(state: ZstdDecoderState) -> (bool, BlockDataPacket, ZstdDecoderState) { + trace_fmt!("zstd_dec: decode_block_header: DECODING BLOCK HEADER"); + trace_fmt!("zstd_dec: decode_block_header: state: {:#x}", state); + let block_header_result = block_header::parse_block_header(state.buffer); + trace_fmt!("zstd_dec: decode_block_header: block_header_result: {:#x}", block_header_result); + let new_state = match block_header_result.status { + block_header::BlockHeaderStatus::OK => { + trace_fmt!("zstd_dec: BlockHeader: {:#x}", block_header_result.header); + match block_header_result.header.btype { + common::BlockType::RAW => ZstdDecoderState { + status: ZstdDecoderStatus::FEED_BLOCK_DECODER, + buffer: state.buffer, + block_size_bytes: block_header_result.header.size as BlockSize + BlockSize:3, + last: block_header_result.header.last, + bytes_sent: BlockSize:0, + ..state + }, + common::BlockType::RLE => ZstdDecoderState { + status: ZstdDecoderStatus::FEED_BLOCK_DECODER, + buffer: state.buffer, + block_size_bytes: BlockSize:4, + last: block_header_result.header.last, + bytes_sent: BlockSize:0, + ..state + }, + common::BlockType::COMPRESSED => ZstdDecoderState { + status: ZstdDecoderStatus::FEED_BLOCK_DECODER, + buffer: state.buffer, + block_size_bytes: block_header_result.header.size as BlockSize + BlockSize:3, + last: block_header_result.header.last, + bytes_sent: BlockSize:0, + ..state + }, + _ => { + fail!("impossible_case", state) + } + } + }, + block_header::BlockHeaderStatus::CORRUPTED => ZstdDecoderState { + status: ZstdDecoderStatus::ERROR, + ..ZERO_DECODER_STATE + }, + block_header::BlockHeaderStatus::NO_ENOUGH_DATA => state, + _ => state, + }; + trace_fmt!("zstd_dec: decode_block_header: new_state: {:#x}", new_state); + + (false, zero!(), new_state) +} + +fn feed_block_decoder(state: ZstdDecoderState) -> (bool, BlockDataPacket, ZstdDecoderState) { + trace_fmt!("zstd_dec: feed_block_decoder: state: {:#x}", state); + let remaining_bytes_to_send = state.block_size_bytes - state.bytes_sent; + trace_fmt!("zstd_dec: feed_block_decoder: remaining_bytes_to_send: {}", remaining_bytes_to_send); + let buffer_length_bytes = state.buffer.length >> 3; + trace_fmt!("zstd_dec: feed_block_decoder: buffer_length_bytes: {}", buffer_length_bytes); + let data_width_bytes = (DATA_WIDTH >> 3) as BlockSize; + trace_fmt!("zstd_dec: feed_block_decoder: data_width_bytes: {}", data_width_bytes); + let remaining_bytes_to_send_now = std::umin(remaining_bytes_to_send, data_width_bytes); + trace_fmt!("zstd_dec: feed_block_decoder: remaining_bytes_to_send_now: {}", remaining_bytes_to_send_now); + if (buffer_length_bytes >= remaining_bytes_to_send_now as u32) { + let remaining_bits_to_send_now = (remaining_bytes_to_send_now as u32) << 3; + trace_fmt!("zstd_dec: feed_block_decoder: remaining_bits_to_send_now: {}", remaining_bits_to_send_now); + let last_packet = (remaining_bytes_to_send == remaining_bytes_to_send_now); + trace_fmt!("zstd_dec: feed_block_decoder: last_packet: {}", last_packet); + let (buffer_result, data_to_send) = buff::buffer_pop_checked(state.buffer, remaining_bits_to_send_now); + match buffer_result.status { + buff::BufferStatus::OK => { + let decoder_channel_data = BlockDataPacket { + last: last_packet, + last_block: state.last, + id: u32:0, + data: data_to_send[0: DATA_WIDTH as s32], + length: remaining_bits_to_send_now, + }; + let new_fsm_status = if (last_packet && state.last) { + ZstdDecoderStatus::DECODE_CHECKSUM + } else if (last_packet) { + ZstdDecoderStatus::DECODE_BLOCK_HEADER + } else { + ZstdDecoderStatus::FEED_BLOCK_DECODER + }; + trace_fmt!("zstd_dec: feed_block_decoder: packet to decode: {:#x}", decoder_channel_data); + let new_state = (true, decoder_channel_data, ZstdDecoderState { + bytes_sent: state.bytes_sent + remaining_bytes_to_send_now, + buffer: buffer_result.buffer, + status: new_fsm_status, + ..state + }); + trace_fmt!("zstd_dec: feed_block_decoder: new_state: {:#x}", new_state); + new_state + }, + _ => { + fail!("should_not_happen_1", (false, zero!(), state)) + } + } + } else { + trace_fmt!("zstd_dec: feed_block_decoder: Not enough data for intermediate FEED_BLOCK_DECODER block dump"); + (false, zero!(), state) + } +} + +fn decode_checksum(state: ZstdDecoderState) -> (bool, BlockDataPacket, ZstdDecoderState) { + trace_fmt!("zstd_dec: decode_checksum: state: {:#x}", state); + // Pop fixed checksum size of 4 bytes + let (buffer_result, _) = buff::buffer_pop_checked(state.buffer, u32:32); + + let new_state = ZstdDecoderState { + status: ZstdDecoderStatus::DECODE_MAGIC_NUMBER, + buffer: buffer_result.buffer, + ..state + }; + trace_fmt!("zstd_dec: decode_checksum: new_state: {:#x}", new_state); + + (false, zero!(), new_state) +} + +pub proc ZstdDecoder { + input_r: chan in; + block_dec_in_s: chan out; + output_s: chan out; + req0_s: chan out; + req1_s: chan out; + req2_s: chan out; + req3_s: chan out; + req4_s: chan out; + req5_s: chan out; + req6_s: chan out; + req7_s: chan out; + resp0_r: chan in; + resp1_r: chan in; + resp2_r: chan in; + resp3_r: chan in; + resp4_r: chan in; + resp5_r: chan in; + resp6_r: chan in; + resp7_r: chan in; + wr_comp0_r: chan<()> in; + wr_comp1_r: chan<()> in; + wr_comp2_r: chan<()> in; + wr_comp3_r: chan<()> in; + wr_comp4_r: chan<()> in; + wr_comp5_r: chan<()> in; + wr_comp6_r: chan<()> in; + wr_comp7_r: chan<()> in; + + init {(ZERO_DECODER_STATE)} + + config ( + input_r: chan in, + output_s: chan out, + req0_s: chan out, + req1_s: chan out, + req2_s: chan out, + req3_s: chan out, + req4_s: chan out, + req5_s: chan out, + req6_s: chan out, + req7_s: chan out, + resp0_r: chan in, + resp1_r: chan in, + resp2_r: chan in, + resp3_r: chan in, + resp4_r: chan in, + resp5_r: chan in, + resp6_r: chan in, + resp7_r: chan in, + wr_comp0_r: chan<()> in, + wr_comp1_r: chan<()> in, + wr_comp2_r: chan<()> in, + wr_comp3_r: chan<()> in, + wr_comp4_r: chan<()> in, + wr_comp5_r: chan<()> in, + wr_comp6_r: chan<()> in, + wr_comp7_r: chan<()> in + ) { + let (block_dec_in_s, block_dec_in_r) = chan; + let (seq_exec_in_s, seq_exec_in_r) = chan; + let (repacketizer_in_s, repacketizer_in_r) = chan; + + spawn block_dec::BlockDecoder(block_dec_in_r, seq_exec_in_s); + + spawn sequence_executor::SequenceExecutor( + seq_exec_in_r, repacketizer_in_s, + req0_s, req1_s, req2_s, req3_s, + req4_s, req5_s, req6_s, req7_s, + resp0_r, resp1_r, resp2_r, resp3_r, + resp4_r, resp5_r, resp6_r, resp7_r, + wr_comp0_r, wr_comp1_r, wr_comp2_r, wr_comp3_r, + wr_comp4_r, wr_comp5_r, wr_comp6_r, wr_comp7_r + ); + + spawn repacketizer::Repacketizer(repacketizer_in_r, output_s); + + (input_r, block_dec_in_s, output_s, + req0_s, req1_s, req2_s, req3_s, + req4_s, req5_s, req6_s, req7_s, + resp0_r, resp1_r, resp2_r, resp3_r, + resp4_r, resp5_r, resp6_r, resp7_r, + wr_comp0_r, wr_comp1_r, wr_comp2_r, wr_comp3_r, + wr_comp4_r, wr_comp5_r, wr_comp6_r, wr_comp7_r) + } + + next (tok: token, state: ZstdDecoderState) { + trace_fmt!("zstd_dec: next(): state: {:#x}", state); + let can_fit = buff::buffer_can_fit(state.buffer, BlockData:0); + trace_fmt!("zstd_dec: next(): can_fit: {}", can_fit); + let (tok, data) = recv_if(tok, input_r, can_fit, BlockData:0); + let state = if (can_fit) { + let buffer = buff::buffer_append(state.buffer, data); + trace_fmt!("zstd_dec: next(): received more data: {:#x}", data); + ZstdDecoderState {buffer, ..state} + } else { + state + }; + trace_fmt!("zstd_dec: next(): state after receive: {:#x}", state); + + let (do_send, data_to_send, state) = match state.status { + ZstdDecoderStatus::DECODE_MAGIC_NUMBER => + decode_magic_number(state), + ZstdDecoderStatus::DECODE_FRAME_HEADER => + decode_frame_header(state), + ZstdDecoderStatus::DECODE_BLOCK_HEADER => + decode_block_header(state), + ZstdDecoderStatus::FEED_BLOCK_DECODER => + feed_block_decoder(state), + ZstdDecoderStatus::DECODE_CHECKSUM => + decode_checksum(state), + _ => (false, zero!(), state) + }; + + trace_fmt!("zstd_dec: next(): do_send: {:#x}, data_to_send: {:#x}, state: {:#x}", do_send, data_to_send, state); + let tok = send_if(tok, block_dec_in_s, do_send, data_to_send); + + state + } +} diff --git a/xls/modules/zstd/zstd_dec_test.cc b/xls/modules/zstd/zstd_dec_test.cc new file mode 100644 index 0000000000..b1388b6553 --- /dev/null +++ b/xls/modules/zstd/zstd_dec_test.cc @@ -0,0 +1,241 @@ +// Copyright 2020 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include +#include + +#include "gtest/gtest.h" +#include "xls/common/file/filesystem.h" +#include "xls/common/file/get_runfile_path.h" +#include "xls/common/status/matchers.h" +#include "xls/interpreter/interpreter_proc_runtime.h" +#include "xls/interpreter/serial_proc_runtime.h" +#include "xls/ir/events.h" +#include "xls/ir/ir_parser.h" +#include "xls/modules/zstd/data_generator.h" +#include "zstd.h" + +namespace xls { +namespace { + +class ZstdDecodedPacket { + public: + static std::optional MakeZstdDecodedPacket(Value packet) { + // Expect tuple + if (!packet.IsTuple()) return std::nullopt; + // Expect exactly 3 fields + if (packet.size() != 3) return std::nullopt; + for (int i = 0; i < 3; i++) { + // Expect fields to be Bits + if (!packet.element(i).IsBits()) return std::nullopt; + // All fields must fit in 64bits + if (!packet.element(i).bits().FitsInUint64()) return std::nullopt; + } + + std::vector data = packet.element(0).bits().ToBytes(); + absl::StatusOr len = packet.element(1).bits().ToUint64(); + if (!len.ok()) return std::nullopt; + uint64_t length = *len; + bool last = packet.element(2).bits().IsOne(); + + return ZstdDecodedPacket(data, length, last); + } + + std::vector GetData() { return data; } + + uint64_t GetLength() { return length; } + + bool IsLast() { return last; } + + const std::string PrintData() const { + std::stringstream s; + for (int j = 0; j < sizeof(uint64_t) && j < data.size(); j++) { + s << "0x" << std::setw(2) << std::setfill('0') << std::right << std::hex + << (unsigned int)data[j] << std::dec << ", "; + } + return s.str(); + } + + friend std::ostream& operator<<(std::ostream& os, + const ZstdDecodedPacket& packet) { + return os << "ZstdDecodedPacket { data: {" << packet.PrintData() + << "}, length: " << packet.length << " last: " << packet.last + << "}" << std::endl; + } + + private: + ZstdDecodedPacket(std::vector data, uint64_t length, bool last) + : data(data), length(length), last(last) {} + + std::vector data; + uint64_t length; + bool last; +}; + +class ZstdDecoderTest : public ::testing::Test { + public: + void SetUp() { + XLS_ASSERT_OK_AND_ASSIGN(std::filesystem::path ir_path, + xls::GetXlsRunfilePath(this->ir_file)); + XLS_ASSERT_OK_AND_ASSIGN(std::string ir_text, + xls::GetFileContents(ir_path)); + XLS_ASSERT_OK_AND_ASSIGN(this->package, xls::Parser::ParsePackage(ir_text)); + XLS_ASSERT_OK_AND_ASSIGN( + this->interpreter, + CreateInterpreterSerialProcRuntime(this->package.get())); + + auto& queue_manager = this->interpreter->queue_manager(); + XLS_ASSERT_OK_AND_ASSIGN(this->recv_queue, queue_manager.GetQueueByName( + this->recv_channel_name)); + XLS_ASSERT_OK_AND_ASSIGN(this->send_queue, queue_manager.GetQueueByName( + this->send_channel_name)); + } + + void PrintTraceMessages(std::string pname) { + XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, this->package->GetProc(pname)); + const InterpreterEvents& events = + this->interpreter->GetInterpreterEvents(proc); + + if (!events.trace_msgs.empty()) { + for (const auto& tm : events.trace_msgs) { + std::cout << "[TRACE] " << tm << std::endl; + } + } + } + + const char* proc_name = "__zstd_dec__ZstdDecoder_0_next"; + const char* recv_channel_name = "zstd_dec__output_s"; + const char* send_channel_name = "zstd_dec__input_r"; + + const char* ir_file = "xls/modules/zstd/zstd_dec_verilog.ir"; + + std::unique_ptr package; + std::unique_ptr interpreter; + ChannelQueue *recv_queue, *send_queue; + + void PrintVector(absl::Span vec) { + for (int i = 0; i < vec.size(); i += 8) { + std::cout << "0x" << std::hex << std::setw(3) << std::left << i + << std::dec << ": "; + for (int j = 0; j < sizeof(uint64_t) && (i + j) < vec.size(); j++) { + std::cout << std::setw(2) << std::hex << (unsigned int)vec[i + j] + << std::dec << " "; + } + std::cout << std::endl; + } + } + + void ParseAndCompareWithZstd(std::vector frame) { + size_t decompressed_size = + ZSTD_getFrameContentSize(frame.data(), frame.size()); + EXPECT_FALSE(ZSTD_isError(decompressed_size)); + + // Decompress the frame with libzstd + uint8_t* lib_decomp = new uint8_t[decompressed_size]; + size_t lib_decomp_size = ZSTD_decompress(lib_decomp, decompressed_size, + frame.data(), frame.size()); + EXPECT_FALSE(ZSTD_isError(lib_decomp_size)); + + std::vector sim_decomp; + size_t sim_decomp_size_words = + (decompressed_size + sizeof(uint64_t) - 1) / sizeof(uint64_t); + size_t sim_decomp_size_bytes = + (decompressed_size + sizeof(uint64_t) - 1) * sizeof(uint64_t); + sim_decomp.reserve(sim_decomp_size_bytes); + + // Send compressed frame to decoder simulation + for (int i = 0; i < frame.size(); i += 8) { + auto span = absl::MakeSpan(frame.data() + i, 8); + auto value = Value(Bits::FromBytes(span, 64)); + XLS_EXPECT_OK(this->send_queue->Write(value)); + XLS_EXPECT_OK(this->interpreter->Tick()); + } + PrintTraceMessages("__zstd_dec__ZstdDecoder_0_next"); + + // Tick decoder simulation until we get expected amount of output data + // batches on output channel queue + std::optional ticks_timeout = std::nullopt; + absl::flat_hash_map output_counts = { + {this->recv_queue->channel(), sim_decomp_size_words}}; + XLS_EXPECT_OK( + this->interpreter->TickUntilOutput(output_counts, ticks_timeout)); + + // Read decompressed data from output channel queue + for (int i = 0; i < sim_decomp_size_words; i++) { + auto read_value = this->recv_queue->Read(); + EXPECT_EQ(read_value.has_value(), true); + auto packet = + ZstdDecodedPacket::MakeZstdDecodedPacket(read_value.value()); + EXPECT_EQ(packet.has_value(), true); + auto word_vec = packet->GetData(); + auto valid_length = packet->GetLength() / CHAR_BIT; + std::copy(begin(word_vec), begin(word_vec) + valid_length, + back_inserter(sim_decomp)); + } + + EXPECT_EQ(lib_decomp_size, sim_decomp.size()); + for (int i = 0; i < lib_decomp_size; i++) { + EXPECT_EQ(lib_decomp[i], sim_decomp[i]); + } + } +}; + +/* TESTS */ + +TEST(ZstdLib, Version) { ASSERT_EQ(ZSTD_VERSION_STRING, "1.4.7"); } + +TEST_F(ZstdDecoderTest, ParseFrameWithRawBlocks) { + int seed = 3; // Arbitrary seed value for small ZSTD frame + auto frame = zstd::GenerateFrame(seed, zstd::BlockType::RAW); + EXPECT_TRUE(frame.ok()); + this->ParseAndCompareWithZstd(frame.value()); +} + +TEST_F(ZstdDecoderTest, ParseFrameWithRleBlocks) { + int seed = 3; // Arbitrary seed value for small ZSTD frame + auto frame = zstd::GenerateFrame(seed, zstd::BlockType::RLE); + EXPECT_TRUE(frame.ok()); + this->ParseAndCompareWithZstd(frame.value()); +} + +class ZstdDecoderSeededTest : public ZstdDecoderTest, + public ::testing::WithParamInterface { + public: + static const uint32_t random_frames_count = 50; +}; + +// Test `random_frames_count` instances of randomly generated valid +// frames, generated with `decodecorpus` tool. + +TEST_P(ZstdDecoderSeededTest, ParseMultipleFramesWithRawBlocks) { + auto seed = GetParam(); + auto frame = zstd::GenerateFrame(seed, zstd::BlockType::RAW); + EXPECT_TRUE(frame.ok()); + this->ParseAndCompareWithZstd(frame.value()); +} + +TEST_P(ZstdDecoderSeededTest, ParseMultipleFramesWithRleBlocks) { + auto seed = GetParam(); + auto frame = zstd::GenerateFrame(seed, zstd::BlockType::RLE); + EXPECT_TRUE(frame.ok()); + this->ParseAndCompareWithZstd(frame.value()); +} + +INSTANTIATE_TEST_SUITE_P( + ZstdDecoderSeededTest, ZstdDecoderSeededTest, + ::testing::Range(0, ZstdDecoderSeededTest::random_frames_count)); + +} // namespace +} // namespace xls