Skip to content

Commit 5dca220

Browse files
committed
Get started on bfloat16 support
1 parent aee12e0 commit 5dca220

3 files changed

Lines changed: 58 additions & 0 deletions

File tree

R/codecs_array_bytes.R

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,16 @@ codec_bytes_decode <- function(
55
datatype,
66
endian
77
) {
8+
if (!is.list(datatype) && datatype$base_type == "bfloat") {
9+
return(.Call(
10+
"type_convert_bfloat",
11+
input,
12+
datatype$nbytes,
13+
chunk_dim,
14+
endian,
15+
PACKAGE = "Rarr"
16+
))
17+
}
818
convert_bytes_to_array(
919
input,
1020
datatype$base_type,

src/type_conversion.c

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#include "type_conversion.h"
2+
3+
SEXP type_convert_bfloat(SEXP input, SEXP _n_bytes, SEXP dims, SEXP endian) {
4+
5+
const int n_bytes = INTEGER(_n_bytes)[0];
6+
const R_xlen_t length = xlength(input);
7+
const bool big_endian = strcmp(CHAR(STRING_ELT(endian, 0)), "big") == 0 ? 1 : 0;
8+
const uint8_t* raw_buffer = (const uint8_t*) RAW(input);
9+
10+
SEXP data;
11+
const R_xlen_t data_length = length / n_bytes;
12+
R_xlen_t i;
13+
14+
// space for the converted output
15+
data = PROTECT(allocVector(REALSXP, data_length));
16+
double *p_data = REAL(data);
17+
18+
if(n_bytes != 2) {
19+
Rf_error("Only 2 byte bfloat16 is supported");
20+
}
21+
22+
for (i = 0; i < data_length; i++) {
23+
uint8_t b0 = raw_buffer[i * 2];
24+
uint8_t b1 = raw_buffer[i * 2 + 1];
25+
uint32_t f32_bits;
26+
if (big_endian) {
27+
f32_bits = (uint32_t)b0 << 24 | (uint32_t)b1 << 16;
28+
} else {
29+
f32_bits = (uint32_t)b0 | (uint32_t)b1 << 8;
30+
f32_bits <<= 16;
31+
}
32+
float f32_val;
33+
memcpy(&f32_val, &f32_bits, sizeof(float));
34+
// Magic scaling factor to round when converting bfloat16 to float32,
35+
// from https://stackoverflow.com/a/55290557
36+
p_data[i] = (double)f32_val * 1.001957f;
37+
}
38+
39+
if (!isNull(dims) && xlength(dims) > 0) {
40+
Rf_dimgets(data, dims);
41+
}
42+
43+
UNPROTECT(1);
44+
return(data);
45+
}

src/type_conversion.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
#include "Rarr.h"
2+
3+
SEXP type_convert_bfloat(SEXP input, SEXP _n_bytes, SEXP dims, SEXP endian);

0 commit comments

Comments
 (0)