@@ -202,6 +202,47 @@ bool Sample::ArgsBatchEqual(const Sample& other) const {
202
202
return true ;
203
203
}
204
204
205
+ // Extract args batch from SampleInputsProto. If to be interpreted as
206
+ // proc_samples, also extract "ir_channel_names" (which must not be a nullptr
207
+ // then).
208
+ /* static */ absl::Status Sample::ExtractArgsBatch (
209
+ bool is_proc_samples, const testvector::SampleInputsProto& testvector,
210
+ std::vector<std::vector<InterpValue>>& args_batch,
211
+ std::vector<std::string>* ir_channel_names) {
212
+ // In the serialization channel inputs are grouped by channel, but the
213
+ // fuzzer expects inputs to be grouped by input number.
214
+ // TODO(meheff): Change the fuzzer to accept inputs grouped by channel. This
215
+ // would enable a different number of inputs per channel.
216
+ if (is_proc_samples) {
217
+ XLS_RET_CHECK (!testvector.has_function_args ()); // proc samples expected
218
+ XLS_RET_CHECK (ir_channel_names != nullptr );
219
+ for (const testvector::ChannelInputProto& channel_input :
220
+ testvector.channel_inputs ().inputs ()) {
221
+ ir_channel_names->push_back (channel_input.channel_name ());
222
+ for (int i = 0 ; i < channel_input.values ().size (); ++i) {
223
+ const std::string& value_str = channel_input.values (i);
224
+ XLS_ASSIGN_OR_RETURN (Value value, Parser::ParseTypedValue (value_str));
225
+ XLS_ASSIGN_OR_RETURN (InterpValue interp_value,
226
+ dslx::ValueToInterpValue (value));
227
+ if (args_batch.size () <= i) {
228
+ args_batch.resize (i + 1 );
229
+ }
230
+ args_batch[i].push_back (interp_value);
231
+ }
232
+ }
233
+ return absl::OkStatus ();
234
+ }
235
+
236
+ // Otherwise just extract function information.
237
+ XLS_RET_CHECK (!testvector.has_channel_inputs ()); // function samples expected
238
+ for (const std::string& arg : testvector.function_args ().args ()) {
239
+ XLS_ASSIGN_OR_RETURN (std::vector<InterpValue> args, dslx::ParseArgs (arg));
240
+ args_batch.push_back (args);
241
+ }
242
+
243
+ return absl::OkStatus ();
244
+ }
245
+
205
246
/* static */ absl::StatusOr<Sample> Sample::Deserialize (std::string_view s) {
206
247
bool in_config = false ;
207
248
std::vector<std::string_view> config_lines;
@@ -238,37 +279,39 @@ bool Sample::ArgsBatchEqual(const Sample& other) const {
238
279
XLS_ASSIGN_OR_RETURN (SampleOptions options,
239
280
SampleOptions::FromProto (proto.sample_options ()));
240
281
241
- std::string dslx_code = absl::StrJoin (dslx_lines, " \n " );
282
+ // Make sure we see the kind of inputs we expect.
283
+ XLS_RET_CHECK_EQ (proto.inputs ().has_function_args (),
284
+ options.IsFunctionSample ());
242
285
243
- // In the serialization channel inputs are grouped by channel, but the
244
- // fuzzer expects inputs to be grouped by input number.
245
- // TODO(meheff): Change the fuzzer to accept inputs grouped by channel. This
246
- // would enable a different number of inputs per channel.
247
- std::vector<std::string> ir_channel_names;
248
286
std::vector<std::vector<InterpValue>> args_batch;
249
- if (proto. sample_options (). sample_type () == fuzzer::SAMPLE_TYPE_PROC) {
250
- for ( const testvector::ChannelInputProto& channel_input :
251
- proto. inputs (). channel_inputs (). inputs ()) {
252
- ir_channel_names. push_back (channel_input. channel_name ());
253
- for ( int i = 0 ; i < channel_input. values (). size (); ++i) {
254
- const std::string& value_str = channel_input. values (i );
255
- XLS_ASSIGN_OR_RETURN (Value value, Parser::ParseTypedValue (value_str));
256
- XLS_ASSIGN_OR_RETURN (InterpValue interp_value,
257
- dslx::ValueToInterpValue (value));
258
- if (args_batch. size () <= i) {
259
- args_batch. resize (i + 1 );
260
- }
261
- args_batch[i]. push_back (interp_value);
262
- }
287
+ std::vector<std::string> ir_channel_names;
288
+ XLS_RETURN_IF_ERROR ( ExtractArgsBatch (options. IsProcSample (), proto. inputs (),
289
+ args_batch, &ir_channel_names));
290
+
291
+ std::string dslx_code = absl::StrJoin (dslx_lines, " \n " );
292
+ return Sample (dslx_code, options, args_batch, ir_channel_names );
293
+ }
294
+
295
+ absl::Status Sample::FillSampleInputs (
296
+ testvector::SampleInputsProto* proto) const {
297
+ if ( options (). IsFunctionSample ()) {
298
+ testvector::FunctionArgsProto* args_proto = proto-> mutable_function_args ();
299
+ for ( const std::vector<InterpValue>& args : args_batch_) {
300
+ args_proto-> add_args ( InterpValueListToString (args));
263
301
}
264
302
} else {
265
- XLS_RET_CHECK (proto.inputs ().has_function_args ());
266
- for (const std::string& arg : proto.inputs ().function_args ().args ()) {
267
- XLS_ASSIGN_OR_RETURN (std::vector<InterpValue> args, dslx::ParseArgs (arg));
268
- args_batch.push_back (args);
303
+ XLS_RET_CHECK (options ().IsProcSample ());
304
+ testvector::ChannelInputsProto* inputs_proto =
305
+ proto->mutable_channel_inputs ();
306
+ for (int64_t i = 0 ; i < ir_channel_names_.size (); ++i) {
307
+ testvector::ChannelInputProto* input_proto = inputs_proto->add_inputs ();
308
+ input_proto->set_channel_name (ir_channel_names_[i]);
309
+ for (const std::vector<InterpValue>& args : args_batch_) {
310
+ input_proto->add_values (ToArgString (args[i]));
311
+ }
269
312
}
270
313
}
271
- return Sample (dslx_code, options, args_batch, ir_channel_names );
314
+ return absl::OkStatus ( );
272
315
}
273
316
274
317
std::string Sample::Serialize (
@@ -285,24 +328,8 @@ std::string Sample::Serialize(
285
328
config.set_issue (std::string (" DO NOT " ) +
286
329
" SUBMIT Insert link to GitHub issue here." );
287
330
*config.mutable_sample_options () = options ().proto ();
288
- if (options ().IsFunctionSample ()) {
289
- testvector::FunctionArgsProto* args_proto =
290
- config.mutable_inputs ()->mutable_function_args ();
291
- for (const std::vector<InterpValue>& args : args_batch_) {
292
- args_proto->add_args (InterpValueListToString (args));
293
- }
294
- } else {
295
- CHECK (options ().IsProcSample ());
296
- testvector::ChannelInputsProto* inputs_proto =
297
- config.mutable_inputs ()->mutable_channel_inputs ();
298
- for (int64_t i = 0 ; i < ir_channel_names_.size (); ++i) {
299
- testvector::ChannelInputProto* input_proto = inputs_proto->add_inputs ();
300
- input_proto->set_channel_name (ir_channel_names_[i]);
301
- for (const std::vector<InterpValue>& args : args_batch_) {
302
- input_proto->add_values (ToArgString (args[i]));
303
- }
304
- }
305
- }
331
+ CHECK_OK (FillSampleInputs (config.mutable_inputs ()));
332
+
306
333
std::string config_text;
307
334
CHECK (google::protobuf::TextFormat::PrintToString (config, &config_text));
308
335
for (std::string_view line : absl::StrSplit (config_text, ' \n ' )) {
0 commit comments