Skip to content

Commit a76fcb1

Browse files
committed
Find proper types for ComStmtBulkExecuteRequest
1 parent 88adbf9 commit a76fcb1

1 file changed

Lines changed: 35 additions & 9 deletions

File tree

src/packets/mod.rs

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3179,28 +3179,54 @@ impl<'a> ComStmtBulkExecuteRequest<'a> {
31793179
let first = values.first().ok_or(BulkExecuteRequestError::NoParams)?;
31803180
let arity = first.len();
31813181

3182-
let types = if bulk_flags.contains(StmtBulkExecuteFlags::SEND_TYPES_TO_SERVER) {
3183-
Seq::new(
3184-
first
3185-
.iter()
3186-
.map(StmtBulkExecuteParamType::from_value)
3187-
.collect::<Vec<_>>(),
3188-
)
3182+
let mut types = if bulk_flags.contains(StmtBulkExecuteFlags::SEND_TYPES_TO_SERVER) {
3183+
first
3184+
.iter()
3185+
.map(StmtBulkExecuteParamType::from_value)
3186+
.collect::<Vec<_>>()
31893187
} else {
3190-
Seq::empty()
3188+
Vec::default()
31913189
};
31923190

31933191
for values in &values {
31943192
if values.len() != arity {
31953193
return Err(BulkExecuteRequestError::MixedArity);
31963194
}
3195+
let row_types = values
3196+
.iter()
3197+
.map(StmtBulkExecuteParamType::from_value)
3198+
.collect::<Vec<_>>();
3199+
3200+
// The point here is to find proper type for every param, i.e the type
3201+
// that covers param values in all rows.
3202+
// E.g. ColumnType::MYSQL_TYPE_NULL is a proper type only if all the
3203+
// values for the given param are NULLs
3204+
for (left, right) in types.iter_mut().zip(&row_types) {
3205+
if left != right {
3206+
if left.column_type() == ColumnType::MYSQL_TYPE_NULL {
3207+
*left = *right;
3208+
} else if left.column_type() == right.column_type() {
3209+
*left = StmtBulkExecuteParamType::new(
3210+
right.column_type(),
3211+
// if flag is required by a single param value,
3212+
// then it must be given for the whole batch
3213+
left.flags().union(right.flags()),
3214+
)
3215+
} else {
3216+
// TODO: Values of different types are given for the same parameter
3217+
// within the batch. Not sure if server will always error here.
3218+
// The error:
3219+
// ERROR 1210 (HY000): Incorrect arguments to mysqld_stmt_bulk_execute
3220+
}
3221+
}
3222+
}
31973223
}
31983224

31993225
Ok(Self {
32003226
header: ConstU8::new(),
32013227
stmt_id: RawInt::new(stmt_id),
32023228
bulk_flags: Const::new(bulk_flags),
3203-
types,
3229+
types: Seq::new(types),
32043230
values: StmtBulkExecuteParamValues::new(values.into_iter().flatten()),
32053231
})
32063232
}

0 commit comments

Comments
 (0)