Skip to content

Commit

Permalink
Allow selecting serialization strategy per writer and optimize wr…
Browse files Browse the repository at this point in the history
…iting length-delimited fields, as well as loop-unrolling some functions.
  • Loading branch information
andersfugmann committed Jan 7, 2024
1 parent eef2b43 commit 3b11845
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 148 deletions.
64 changes: 31 additions & 33 deletions bench/bench.ml
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,32 @@ module type Plugin_impl = sig
val show: t -> string
val equal: t -> t -> bool
val to_proto: t -> Ocaml_protoc_plugin.Writer.t
val to_proto': Ocaml_protoc_plugin.Writer.t -> t -> Ocaml_protoc_plugin.Writer.t
val from_proto_exn: Ocaml_protoc_plugin.Reader.t -> t
end
end

let make_tests (type v) (module Protoc: Protoc_impl) (module Plugin: Plugin_impl with type M.t = v) v_plugin =
let contents = Plugin.M.to_proto v_plugin in
let data = contents |> Ocaml_protoc_plugin.Writer.contents in
(* We need to reconstruct the data, as we might lose precision when using floats (32 bit) compared to doubles (64 bit) *)

(* Verify *)
let verify_identity ~mode data =
let writer = Plugin.M.to_proto' (Ocaml_protoc_plugin.Writer.init ~mode ()) data in
let data' = Plugin.M.from_proto_exn (Ocaml_protoc_plugin.Reader.create (Ocaml_protoc_plugin.Writer.contents writer)) in
let () = match Plugin.M.equal data data' with
| true -> ()
| false ->
eprintf "Orig: %s\n" (Plugin.M.show data);
eprintf "New: %s\n" (Plugin.M.show data');
failwith "Data not the same"
in
Ocaml_protoc_plugin.Writer.contents writer |> String.length,
Ocaml_protoc_plugin.Writer.unused_space writer
in
let size_normal, unused_normal = verify_identity ~mode:Ocaml_protoc_plugin.Writer.Balanced v_plugin in
let size_speed, unused_speed = verify_identity ~mode:Ocaml_protoc_plugin.Writer.Speed v_plugin in
let size_space, unused_space = verify_identity ~mode:Ocaml_protoc_plugin.Writer.Space v_plugin in
let data = Plugin.M.to_proto' (Ocaml_protoc_plugin.Writer.init ()) v_plugin |> Ocaml_protoc_plugin.Writer.contents in
let v_plugin = Plugin.M.from_proto_exn (Ocaml_protoc_plugin.Reader.create data) in
(* Assert decoding works *)
let v_protoc = Protoc.decode_pb_m (Pbrt.Decoder.of_string data) in
let protoc_encoder = Pbrt.Encoder.create () in
let () = Protoc.encode_pb_m v_protoc protoc_encoder in
Expand All @@ -36,13 +52,17 @@ let make_tests (type v) (module Protoc: Protoc_impl) (module Plugin: Plugin_impl
eprintf "New: %s\n" (Plugin.M.show v_plugin');
failwith "Data not the same"
in
printf "%16s: Data length: %5d /%5d (%b). Waste: %5d\n%!" (Plugin.M.name' ()) (String.length data) (String.length data_protoc) (Poly.equal v_plugin v_plugin') (Ocaml_protoc_plugin.Writer.unused contents);
printf "%-16s: %5d+%-5d(B) / %5d+%-5d(S) / %5d+%-5d(Sp) - %5d\n%!" (Plugin.M.name' ())
size_normal unused_normal size_speed unused_speed size_space unused_space (String.length data_protoc);


let open Bechamel in
let test_encode =
Test.make_grouped ~name:"Encode"
[
Test.make ~name:"Plugin" (Staged.stage @@ fun () -> Plugin.M.to_proto v_plugin);
Test.make ~name:"Plugin balanced" (Staged.stage @@ fun () -> Plugin.M.to_proto' Ocaml_protoc_plugin.Writer.(init ~mode:Balanced ()) v_plugin);
Test.make ~name:"Plugin speed" (Staged.stage @@ fun () -> Plugin.M.to_proto' Ocaml_protoc_plugin.Writer.(init ~mode:Speed ()) v_plugin);
Test.make ~name:"Plugin space" (Staged.stage @@ fun () -> Plugin.M.to_proto' Ocaml_protoc_plugin.Writer.(init ~mode:Space ()) v_plugin);
Test.make ~name:"Protoc" (Staged.stage @@ fun () -> Protoc.encode_pb_m v_protoc (Pbrt.Encoder.create ()))
]
in
Expand Down Expand Up @@ -106,16 +126,15 @@ let create_test_data ~depth () =
in
create_btree depth ()


let benchmark tests =
let open Bechamel in
let instances = Bechamel_perf.Instance.[ cpu_clock ] in
let cfg = Benchmark.cfg ~stabilize:true ~compaction:true () in
let cfg = Benchmark.cfg ~limit:2000 ~quota:(Time.second 5.0) ~kde:(Some 1000) ~stabilize:true ~compaction:false () in
Benchmark.all cfg instances tests

let analyze results =
let open Bechamel in
let ols = Analyze.ols ~bootstrap:0 ~r_square:true
let ols = Analyze.ols ~bootstrap:10 ~r_square:false
~predictors:[| Measure.run |] in
let results = Analyze.all ols Bechamel_perf.Instance.cpu_clock results in
Analyze.merge ols [ Bechamel_perf.Instance.cpu_clock ] [ results ]
Expand All @@ -141,38 +160,17 @@ let print_bench_results results =
img (window, results) |> eol |> output_image


(** [test_unroll ()] builds a list of Bechamel test groups comparing the
    varint writers exposed by [Ocaml_protoc_plugin.Writer].

    One group is produced per input value. The inputs are [1 lsl (idx * 7)]
    for [idx] in [0..8], i.e. a single set bit at positions 0, 7, ..., 56.
    Presumably this makes each value need one more varint byte than the
    previous one -- confirm against the writer implementation.

    Each group holds four tests: the "unrolled" and "reference" writers, for
    both the unboxed ([int]) and the [int64] entry points. *)
let test_unroll () =
let open Bechamel in
(* Nine values, each with exactly one bit set, 7 bits apart. *)
let values = List.init 9 ~f:(fun idx -> Int64.shift_left 1L (idx*7)) in
(* A single 10-byte scratch buffer is shared by every test; each write starts at offset 0. *)
let buffer = Bytes.create 10 in
List.mapi ~f:(fun index vl ->
(* [v] is the same value as an [int], used by the unboxed writers. *)
let v = Int64.to_int_exn vl in
Test.make_grouped ~name:(Printf.sprintf "bits %d" (index*7)) [
Test.make ~name:"Varint unboxed unrolled" (Staged.stage @@ fun () ->
Ocaml_protoc_plugin.Writer.write_varint_unboxed buffer ~offset:0 v |> ignore);
Test.make ~name:"Varint unboxed reference" (Staged.stage @@ fun () ->
Ocaml_protoc_plugin.Writer.write_varint_unboxed_reference buffer ~offset:0 v |> ignore);

Test.make ~name:"Varint unrolled" (Staged.stage @@ fun () ->
Ocaml_protoc_plugin.Writer.write_varint buffer ~offset:0 vl |> ignore);
Test.make ~name:"Varint reference" (Staged.stage @@ fun () ->
Ocaml_protoc_plugin.Writer.write_varint_reference buffer ~offset:0 vl |> ignore);

]) values


let _ =
let v_plugin = create_test_data ~depth:2 () |> Option.value_exn in
test_unroll () @
[ make_tests (module Protoc.Bench) (module Plugin.Bench) v_plugin;
make_tests (module Protoc.Int64) (module Plugin.Int64) 27;
make_tests (module Protoc.Float) (module Plugin.Float) 27.0001;
make_tests (module Protoc.String) (module Plugin.String) "Benchmark";
make_tests (module Protoc.Enum) (module Plugin.Enum) Plugin.Enum.Enum.ED;

random_list ~len:100 ~f:(fun () -> Random.int 1000) () |> make_tests (module Protoc.Int64_list) (module Plugin.Int64_list);
random_list ~len:100 ~f:(fun () -> Random.float 1000.0) () |> make_tests (module Protoc.Float_list) (module Plugin.Float_list);
random_list ~len:100 ~f:random_string () |> make_tests (module Protoc.String_list) (module Plugin.String_list);
List.init 1000 ~f:(fun i -> i) |> make_tests (module Protoc.Int64_list) (module Plugin.Int64_list);
List.init 1000 ~f:(fun i -> Float.of_int i) |> make_tests (module Protoc.Float_list) (module Plugin.Float_list);
List.init 1000 ~f:(fun _ -> random_string ()) |> make_tests (module Protoc.String_list) (module Plugin.String_list);
(* random_list ~len:100 ~f:(fun () -> Plugin.Enum_list.Enum.ED) () |> make_tests (module Protoc.Enum_list) (module Plugin.Enum_list); *)
]
|> List.rev |> List.iter ~f:(fun test ->
Expand Down
2 changes: 1 addition & 1 deletion bench/float.proto
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
syntax = "proto3";

message M {
float i = 1;
double i = 1;
}
1 change: 1 addition & 0 deletions bench/protoc/dune
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@

(library
(name protoc)
(ocamlopt_flags :standard \ -unboxed-types)
(libraries pbrt))
63 changes: 25 additions & 38 deletions src/ocaml_protoc_plugin/serialize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,6 @@ module S = Spec.Serialize
module C = S.C
open S

(** [size_of_field spec] returns a function computing a byte size for a value
    of the type described by [spec].

    - 64-bit fixed-width kinds ([Double], [Fixed64], ...) always yield [8].
    - 32-bit fixed-width kinds ([Float], [Fixed32], ...) always yield [4].
    - Integer kinds delegate to [Writer.varint_size] (presumably the number
      of bytes of the varint encoding -- confirm in [Writer]); boxed
      [int64]/[int32] values are converted to [int] first.
    - [Bool] computes the size of the varint [1] once, when the spec is
      applied, and returns it for every value.
    - [String] and [Bytes] yield the size of the length prefix plus the
      length of the payload.

    The match is on [spec] alone, so the [Enum] and [Message] arms raise as
    soon as the spec is supplied, before any value is given.

    NOTE(review): the [SInt64]/[SInt32] arms size the raw value, not a
    zigzag-transformed one -- confirm this matches what the writer emits.

    @raise Failure for [Enum _] and [Message _] specs. *)
let rec size_of_field: type a. a spec -> a -> int = function
(* We could just assume 10 bytes for a varint to speed it up *)
| Double | Fixed64 | SFixed64 | Fixed64_int | SFixed64_int -> fun _ -> 8
| Float | Fixed32 | SFixed32 | Fixed32_int | SFixed32_int -> fun _ -> 4
| Int64 -> fun v -> Writer.varint_size (Int64.to_int v)
| UInt64 -> fun v -> Writer.varint_size (Int64.to_int v)
| SInt64 -> fun v -> Writer.varint_size (Int64.to_int v)

| Int32 -> fun v -> Writer.varint_size (Int32.to_int v)
| UInt32 -> fun v -> Writer.varint_size (Int32.to_int v)
| SInt32 -> fun v -> Writer.varint_size (Int32.to_int v)

(* Values already represented as [int] need no conversion. *)
| Int64_int -> Writer.varint_size
| UInt64_int -> Writer.varint_size
| Int32_int -> Writer.varint_size
| UInt32_int -> Writer.varint_size
| SInt64_int -> Writer.varint_size
| SInt32_int -> Writer.varint_size

(* The size is computed once here and reused for both [true] and [false]. *)
| Bool -> let size = size_of_field Int64_int 1 in fun _ -> size
(* Length prefix (sized as a varint) plus the payload bytes. *)
| String -> let size = size_of_field Int64_int in fun v -> let length = String.length v in (size length) + length
| Bytes -> let size = size_of_field Int64_int in fun v -> let length = Bytes.length v in (size length) + length
| Enum _ -> failwith "Enums must be converted to varint"
| Message _ -> failwith "Message sizes should not be pre-computed as we have a continuation and don't need to preallocate"

let field_type: type a. a spec -> int = function
| Int64 | UInt64 | SInt64 | Int32 | UInt32 | SInt32
| Int64_int | UInt64_int | Int32_int | UInt32_int | SInt64_int | SInt32_int
Expand Down Expand Up @@ -74,7 +49,6 @@ let write_varint_unboxed ~f v =
let writer = Writer.write_varint_unboxed in
Writer.write_value ~size ~writer v

(* Can only write a string *)
let write_string ~f v =
let v = f v in
let write_length = write_varint_unboxed ~f:String.length v in
Expand All @@ -83,9 +57,6 @@ let write_string ~f v =
write_length t;
Writer.write_value ~size:(String.length v) ~writer:write_string v t

let write_message ~f v writer =
Writer.write_length_delimited_value ~write:f v writer

(** Identity function. *)
let id x = x
(** [a @@ b] is left-to-right function composition: it returns a function
    that applies [a] first and then [b] to the result.
    Note that this definition shadows any [(@@)] operator already in scope
    for the remainder of the file. *)
let (@@) a b = fun v -> b (a v)

Expand Down Expand Up @@ -117,12 +88,28 @@ let write_value : type a. a spec -> a -> Writer.t -> unit = function
| String -> write_string ~f:id
| Bytes -> write_string ~f:Bytes.unsafe_to_string
| Enum f -> write_varint_unboxed ~f
| Message to_proto -> write_message ~f:(fun v writer -> to_proto writer v |> ignore)
| Message to_proto ->
(*
fun v writer ->
let cont = Writer.write_length_delimited_value_cont writer in
let _ = to_proto writer v in
cont ()
*)
Writer.write_length_delimited_value ~write:to_proto

(** [write_value_const spec v] pre-serializes the constant [v] and returns a
    writer function that emits the already-encoded bytes.

    The encoding is done once, as soon as [spec] and [v] are supplied: [v] is
    written into a fresh scratch [Writer.t] and the resulting bytes are
    captured as a string. The returned function writes that string through
    [Writer.write_value] with [Writer.write_string]. This pays off when the
    value is known in advance and the returned function is applied many
    times. *)
let write_value_const : type a. a spec -> a -> Writer.t -> unit = fun spec v ->
  let encode = write_value spec in
  (* Scratch writer used only to obtain the encoded form of [v]. *)
  let scratch = Writer.init () in
  let () = encode v scratch in
  let encoded = Writer.contents scratch in
  Writer.write_value ~size:(String.length encoded) ~writer:Writer.write_string encoded

let write_field_header: 'a spec -> int -> Writer.t -> unit = fun spec index ->
let field_type = field_type spec in
let header = (index lsl 3) + field_type in
write_value Int64_int header
write_value_const Int64_int header

let write_field: type a. a spec -> int -> a -> Writer.t -> unit = fun spec index ->
let write_field_header = write_field_header spec index in
Expand All @@ -137,11 +124,9 @@ let is_scalar: type a. a spec -> bool = function
| Message _ -> false
| _ -> true

(* Try to remove the fold et al. *)
let rec write: type a. a compound -> Writer.t -> a -> unit = function
| Repeated (index, spec, Packed) when is_scalar spec -> begin
let write = write_value spec in
let write vs writer = List.iter ~f:(fun v -> write v writer) vs in
let write writer vs = List.iter ~f:(fun v -> write_value spec v writer) vs in
let write_header = write_field_header String index in
fun writer vs ->
match vs with
Expand Down Expand Up @@ -183,7 +168,6 @@ let rec write: type a. a compound -> Writer.t -> a -> unit = function
write (Basic (index, spec, None)) writer v
end

(** Allow emitted code to present a protobuf specification. *)
let rec serialize : type a. (a, Writer.t) compound_list -> Writer.t -> a = function
| Nil -> fun writer -> writer
| Cons (compound, rest) ->
Expand All @@ -198,12 +182,15 @@ let in_extension_ranges extension_ranges index =

let serialize extension_ranges spec =
let serialize = serialize spec in
fun extensions writer ->
List.iter ~f:(function
match extension_ranges with
| [] -> fun _ -> serialize
| extension_ranges ->
fun extensions writer ->
List.iter ~f:(function
| (index, field) when in_extension_ranges extension_ranges index -> Writer.write_field writer index field
| _ -> ()
) extensions;
serialize writer
serialize writer


let%expect_test "zigzag encoding" =
Expand Down
Loading

0 comments on commit 3b11845

Please sign in to comment.