From 599feadad39546aa72693964c7581006922682ff Mon Sep 17 00:00:00 2001 From: akiroz Date: Sun, 8 May 2022 17:20:33 +0800 Subject: [PATCH] port to zig 0.9 --- README.md | 2 +- build.zig | 41 +++++++++++------ src/client.zig | 27 ++++++----- src/config.zig | 8 ++-- src/driver.zig | 32 ++++++------- src/mqtt.zig | 121 +++++++++++++++++++++++++++---------------------- src/server.zig | 63 +++++++++++++------------ 7 files changed, 160 insertions(+), 134 deletions(-) diff --git a/README.md b/README.md index c60a2cf..d02f6b5 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Zika -![](https://img.shields.io/badge/zig-0.8.1-informational) +![](https://img.shields.io/badge/zig-0.9.1-informational) IP Tunneling over MQTT diff --git a/build.zig b/build.zig index beba4b9..d986ff2 100644 --- a/build.zig +++ b/build.zig @@ -1,17 +1,32 @@ const builtin = @import("builtin"); -const Builder = @import("std").build.Builder; +const std = @import("std"); +const LibExeObjStep = std.build.LibExeObjStep; -pub fn build(b: *Builder) void { - const server = b.addExecutable("zika-server", "src/server.zig"); - const client = b.addExecutable("zika-client", "src/client.zig"); - server.addIncludeDir("/usr/local/include"); - client.addIncludeDir("/usr/local/include"); - server.linkSystemLibrary("mosquitto"); - client.linkSystemLibrary("mosquitto"); - if (builtin.target.isDarwin()) { - server.linkSystemLibrary("pcap"); - client.linkSystemLibrary("pcap"); +fn commonOpts(exe: *LibExeObjStep) *LibExeObjStep { + exe.setBuildMode(exe.builder.standardReleaseOptions()); + if (builtin.target.isDarwin()) { // macOS + if (builtin.cpu.arch == .aarch64) { // Apple + exe.addIncludeDir("/opt/homebrew/include"); + exe.addLibPath("/opt/homebrew/lib"); + } else { // Intel + exe.addIncludeDir("/usr/local/include"); + } + exe.linkSystemLibrary("pcap"); + } else { // Linux + exe.addIncludeDir("/usr/include"); + exe.addLibPath("/usr/lib"); + exe.linkLibC(); + // Support down to Ubuntu 18 Bionic + exe.setTarget(.{ .glibc_version = .{ .major = 2, .minor = 27, .patch = 0 } }); } - //server.install(); - client.install(); + exe.linkSystemLibrary("mosquitto"); + return exe; } + +pub fn build(b: *std.build.Builder) void { + const server = commonOpts(b.addExecutable("zika-server", "src/server.zig")); + server.install(); + + const client = commonOpts(b.addExecutable("zika-client", "src/client.zig")); + client.install(); +} \ No newline at end of file diff --git a/src/client.zig b/src/client.zig index b48c5dc..332ba26 100644 --- a/src/client.zig +++ b/src/client.zig @@ -24,7 +24,7 @@ pub fn Tunnel(comptime T: type) type { const Conf = config.ClientTunnel; conf: Conf, - alloc: *Allocator, + alloc: Allocator, ifce: *NetInterface(T), mqtt: *Mqtt(T), @@ -33,7 +33,7 @@ pub fn Tunnel(comptime T: type) type { up_topic_cstr: [:0]u8, dn_topic_cstr: [:0]u8, - pub fn create(alloc: *Allocator, ifce: *NetInterface(T), broker: *Mqtt(T), conf: Conf) !*Self { + pub fn create(alloc: Allocator, ifce: *NetInterface(T), broker: *Mqtt(T), conf: Conf) !*Self { const self = try alloc.create(Self); self.conf = conf; self.alloc = alloc; @@ -83,6 +83,7 @@ pub const Client = struct { }; arena: ArenaAllocator, + alloc: Allocator, ifce: ?*NetInterface(*Self), mqtt: ?*Mqtt(*Self), tunnels: []*Tunnel(*Self), @@ -91,27 +92,27 @@ pub const Client = struct { disconnect_callback: ?ConnectHandler(*Mqtt(*Self)), message_hook: ?MessageHook(*Mqtt(*Self)), - pub fn init(parent_alloc: *Allocator, conf: * const Config) !*Self { + pub fn init(parent_alloc: Allocator, conf: * const Config) !*Self { const client_conf = conf.client orelse { std.log.err("missing client config", .{}); return Error.ConfigMissing; }; - var arena = ArenaAllocator.init(parent_alloc); - const self = try arena.allocator.create(Self); - self.arena = arena; + const self = try parent_alloc.create(Self); + self.arena = ArenaAllocator.init(parent_alloc); + self.alloc = self.arena.allocator(); self.ifce = null; self.mqtt = null; errdefer self.deinit(); - const alloc = &self.arena.allocator; std.log.info("== Client Config =================================", .{}); - self.ifce = try NetInterface(*Self).init(alloc, conf, self, @ptrCast(driver.PacketHandler(*Self), &up)); + self.ifce = try NetInterface(*Self).init(self.alloc, conf, self, @ptrCast(driver.PacketHandler(*Self), &up)); const max_subs = client_conf.tunnels.len; - self.mqtt = try Mqtt(*Self).init(alloc, conf, self, @ptrCast(mqtt.PacketHandler(*Self), &down), max_subs); - self.tunnels = try alloc.alloc(*Tunnel(*Self), client_conf.tunnels.len); + self.mqtt = try Mqtt(*Self).init(self.alloc, conf, self, @ptrCast(mqtt.PacketHandler(*Self), &down), max_subs); + std.log.info("Tunnels: {d}", .{client_conf.tunnels.len}); + self.tunnels = try self.alloc.alloc(*Tunnel(*Self), client_conf.tunnels.len); for (client_conf.tunnels) |tunnel, idx| { self.tunnels[idx] = try Tunnel(*Self).create( - alloc, + self.alloc, self.ifce orelse unreachable, self.mqtt orelse unreachable, .{ @@ -183,9 +184,11 @@ pub const Client = struct { pub fn main() !void { var gpa = std.heap.GeneralPurposeAllocator(.{}){}; - const alloc = &gpa.allocator; + var alloc = gpa.allocator(); const conf_path = try std.fs.cwd().realpathAlloc(alloc, "zika_config.json"); const conf = try config.get(alloc, conf_path); const client = try Client.init(alloc, &conf); + defer client.deinit(); + defer alloc.destroy(client); try client.run(); } diff --git a/src/config.zig b/src/config.zig index f75431c..7065918 100644 --- a/src/config.zig +++ b/src/config.zig @@ -9,10 +9,10 @@ pub const MqttOptions = struct { username: ?[]const u8 = null, password: ?[]const u8 = null, - ca: ?[]const u8 = null, + ca_file: ?[]const u8 = null, tls_insecure: bool = false, - key: ?[]const u8 = null, - cert: ?[]const u8 = null, + key_file: ?[]const u8 = null, + cert_file: ?[]const u8 = null, }; pub const MqttBroker = struct { @@ -64,7 +64,7 @@ pub const Config = struct { client: ?ClientConfig = null, }; -pub fn get(alloc: *Allocator, file: []u8) !Config { +pub fn get(alloc: Allocator, file: []u8) !Config { const cfg_file = try std.fs.openFileAbsolute(file, .{ .read = true }); defer cfg_file.close(); diff --git a/src/driver.zig b/src/driver.zig index 3836989..c717864 100644 --- a/src/driver.zig +++ b/src/driver.zig @@ -10,7 +10,6 @@ const Net = @cImport({ const Pcap = @cImport(@cInclude("pcap/pcap.h")); const Allocator = std.mem.Allocator; -const ArenaAllocator = std.heap.ArenaAllocator; const Ip4Address = std.net.Ip4Address; const Address = std.net.Address; pub fn PacketHandler(comptime T: type) type { @@ -31,13 +30,13 @@ fn PcapDriver(comptime T: type) type { NotInitialized, }; - alloc: *Allocator, + alloc: Allocator, user: T, handler: PacketHandler(T), pcap: ?*Pcap.pcap_t, pcap_header: [4]u8, - pub fn init(alloc: *Allocator, conf: *const Config, user: T, handler: PacketHandler(T)) !*Self { + pub fn init(alloc: Allocator, conf: *const Config, user: T, handler: PacketHandler(T)) !*Self { const pcap_conf = conf.driver.pcap orelse { std.log.err("missing pcap driver config", .{}); return Error.ConfigMissing; @@ -99,7 +98,7 @@ fn PcapDriver(comptime T: type) type { } fn recv(self_ptr: [*c]u8, hdr: [*c]const Pcap.pcap_pkthdr, pkt: [*c]const u8) callconv(.C) void { - const self = @intToPtr(*Self, @ptrToInt(self_ptr)); + const self = @ptrCast(*Self, @alignCast(@alignOf(*Self), self_ptr.?)); self.handler.*(self.user, pkt[4..hdr.*.len]); } @@ -135,7 +134,7 @@ fn TunDriver(comptime T: type) type { tun: ?std.fs.File = null, buf: []u8, - pub fn init(alloc: *Allocator, conf: * const Config, user: T, handler: PacketHandler(T)) !*Self { + pub fn init(alloc: Allocator, conf: *const Config, user: T, handler: PacketHandler(T)) !*Self { const tun_conf = conf.driver.tun orelse { std.log.err("missing tun driver config", .{}); return Error.ConfigMissing; @@ -243,16 +242,14 @@ pub fn NetInterface(comptime T: type) type { return struct { const Self = @This(); - - arena: ArenaAllocator, + alloc: Allocator, local_ip: u32, driver: ?*Driver(T) = null, - pub fn init(alloc: *Allocator, conf: * const Config, user: T, handler: PacketHandler(T)) !*Self { - var arena = ArenaAllocator.init(alloc); - const self = try arena.allocator.create(Self); - self.arena = arena; + pub fn init(alloc: Allocator, conf: *const Config, user: T, handler: PacketHandler(T)) !*Self { + const self = try alloc.create(Self); errdefer self.deinit(); + self.alloc = alloc; self.local_ip = (try Ip4Address.parse(conf.driver.local_addr, 0)).sa.addr; self.driver = try Driver(T).init(alloc, conf, user, handler); std.log.info("Local IP: {s}", .{conf.driver.local_addr}); @@ -261,7 +258,6 @@ pub fn NetInterface(comptime T: type) type { pub fn deinit(self: *Self) void { self.driver.?.deinit(); - self.arena.deinit(); } pub fn run(self: *Self) !void { @@ -274,7 +270,7 @@ pub fn NetInterface(comptime T: type) type { hdr.src = src; hdr.dst = self.local_ip; hdr.cksum = 0; // Zero before recalc - hdr.cksum = try cksum(&self.arena.allocator, pkt[0..payload_offset], 0); + hdr.cksum = try self.cksum(pkt[0..payload_offset], 0); switch (hdr.proto) { 6, 17 => { // TCP / UDP var pseudo_buf = std.mem.zeroes([@sizeOf(PseudoHeader)/2]u16); @@ -288,7 +284,7 @@ pub fn NetInterface(comptime T: type) type { const cksum_offset = payload_offset + if (hdr.proto == 6) @as(usize, 16) else 6; const cksum_slice = pkt[cksum_offset..cksum_offset+2]; std.mem.set(u8, cksum_slice, 0); // Zero before recalc - var sum = try cksum(&self.arena.allocator, pkt[payload_offset..], pseudo_sum); + var sum = try self.cksum(pkt[payload_offset..], pseudo_sum); std.mem.copy(u8, cksum_slice, @ptrCast(*[2]u8, &sum)); }, else => {}, // No special handling @@ -296,10 +292,10 @@ pub fn NetInterface(comptime T: type) type { try self.driver.?.write(pkt); } - fn cksum(alloc: *Allocator, buf: []u8, carry: u32) !u16 { + fn cksum(self: *Self, buf: []u8, carry: u32) !u16 { var sum: u32 = carry; - const buf2 = try alloc.alloc(u16, buf.len/2 + 1); - defer alloc.free(buf2); + const buf2 = try self.alloc.alloc(u16, buf.len/2 + 1); + defer self.alloc.free(buf2); std.mem.set(u16, buf2, 0); const buf2_ptr = @ptrCast([*]u8, buf2); std.mem.copy(u8, buf2_ptr[0..buf2.len*2], buf); @@ -307,7 +303,7 @@ pub fn NetInterface(comptime T: type) type { while (sum > 0xffff) { sum = (sum & 0xffff) + (sum >> 16); } - return @intCast(u16, sum ^ 0xffff); + return @truncate(u16, sum ^ 0xffff); } }; } diff --git a/src/mqtt.zig b/src/mqtt.zig index f7add30..56f57df 100644 --- a/src/mqtt.zig +++ b/src/mqtt.zig @@ -11,7 +11,7 @@ const ArenaAllocator = std.heap.ArenaAllocator; const Config = config.Config; pub fn PacketHandler(comptime T: type) type { - // user, message + // user, topic, payload return *const fn (T, []const u8, []const u8) void; } @@ -36,9 +36,10 @@ pub fn Client(comptime T: type) type { SubscribeFailed, }; - alloc: *Allocator, + arena: ArenaAllocator, + alloc: Allocator, mosq: *Mosq.mosquitto, - mosq_thread: ?*std.Thread, + mosq_thread: ?std.Thread, conf: Conf, host_cstr: []const u8, @@ -51,19 +52,20 @@ pub fn Client(comptime T: type) type { disconnect_callback: ?ConnectHandler(T), subscribe_cond: std.Thread.Condition, - pub fn init(alloc: *Allocator, conf: Conf, user: T, handler: PacketHandler(T), max_subs: usize) !*Self { - const self = try alloc.create(Self); - self.alloc = alloc; + pub fn init(parent_alloc: Allocator, conf: Conf, user: T, handler: PacketHandler(T), max_subs: usize) !*Self { + const self = try parent_alloc.create(Self); + self.arena = ArenaAllocator.init(parent_alloc); + self.alloc = self.arena.allocator(); self.conf = conf; - self.host_cstr = try std.cstr.addNullByte(alloc, conf.host); + self.host_cstr = try std.cstr.addNullByte(self.alloc, conf.host); self.subs_count = 0; - self.subs = try alloc.alloc([:0]u8, max_subs); + self.subs = try self.alloc.alloc([:0]u8, max_subs); self.user = user; self.msg_callback = handler; self.connect_count = 0; self.connect_callback = null; self.disconnect_callback = null; - self.subscribe_cond = std.Thread.Condition {}; + self.subscribe_cond = std.Thread.Condition{}; self.mosq = Mosq.mosquitto_new(null, true, self) orelse { return Error.CreateFailed; }; @@ -77,21 +79,26 @@ pub fn Client(comptime T: type) type { } if (opts.username) |username| { - const username_str = try std.cstr.addNullByte(alloc, username); - const password_str: [*c]const u8 = if (opts.password) |password| (try std.cstr.addNullByte(alloc, password)) else null; - rc = Mosq.mosquitto_username_pw_set(self.mosq, username_str, password_str); + rc = Mosq.mosquitto_username_pw_set( + self.mosq, + (try std.cstr.addNullByte(self.alloc, username)).ptr, + if (opts.password) |p| (try std.cstr.addNullByte(self.alloc, p)).ptr else null + ); if (rc != Mosq.MOSQ_ERR_SUCCESS) { std.log.err("mosquitto_username_pw_set: {s}", .{Mosq.mosquitto_strerror(rc)}); return Error.OptionError; } } - if (opts.ca != null or opts.cert != null) { - const ca_path: [*c]const u8 = if (opts.ca) |ca| (try std.cstr.addNullByte(alloc, ca)) else null; - const key_path: [*c]const u8 = if (opts.key) |key| (try std.cstr.addNullByte(alloc, key)) else null; - const cert_path: [*c]const u8 = if (opts.cert) |cert| (try std.cstr.addNullByte(alloc, cert)) else null; - - rc = Mosq.mosquitto_tls_set(self.mosq, ca_path, null, cert_path, key_path, null); + if (opts.ca_file != null or opts.cert_file != null) { + rc = Mosq.mosquitto_tls_set( + self.mosq, + if (opts.ca_file) |ca| (try std.cstr.addNullByte(self.alloc, ca)).ptr else null, + null, // capath + if (opts.cert_file) |cert| (try std.cstr.addNullByte(self.alloc, cert)).ptr else null, + if (opts.key_file) |key| (try std.cstr.addNullByte(self.alloc, key)).ptr else null, + null // pw_callback + ); if (rc != Mosq.MOSQ_ERR_SUCCESS) { std.log.err("mosquitto_tls_set: {s}", .{Mosq.mosquitto_strerror(rc)}); return Error.OptionError; @@ -118,13 +125,14 @@ pub fn Client(comptime T: type) type { pub fn deinit(self: *Self) void { _ = Mosq.mosquitto_disconnect(self.mosq); - if(self.mosq_thread) |t| std.Thread.wait(t); + if (self.mosq_thread) |t| t.join(); Mosq.mosquitto_destroy(self.mosq); + self.arena.deinit(); } pub fn connect(self: *Self) !void { _ = Mosq.mosquitto_threaded_set(self.mosq, true); - self.mosq_thread = try std.Thread.spawn(thread_main, self); + self.mosq_thread = try std.Thread.spawn(.{}, thread_main, .{ self }); const keepalive = self.conf.opts.keepalive_interval; const rc = Mosq.mosquitto_connect_async(self.mosq, @ptrCast([*c]const u8, self.host_cstr), self.conf.port, keepalive); @@ -136,86 +144,89 @@ pub fn Client(comptime T: type) type { fn thread_main(self: *Self) !void { const keepalive = self.conf.opts.keepalive_interval; - const rc = Mosq.mosquitto_loop_forever(self.mosq, keepalive*1000, 1); + const rc = Mosq.mosquitto_loop_forever(self.mosq, keepalive * 1000, 1); if (rc != Mosq.MOSQ_ERR_SUCCESS) { std.log.err("mosquitto_loop_forever: {s}", .{Mosq.mosquitto_strerror(rc)}); return Error.ConnectFailed; } } - fn onConnect(_: ?*Mosq.mosquitto, self_ptr: ?*c_void, rc: c_int) callconv(.C) void { - const self = @intToPtr(*Self, @ptrToInt(self_ptr orelse unreachable)); + fn onConnect(mosq: ?*Mosq.mosquitto, self_ptr: ?*anyopaque, rc: c_int) callconv(.C) void { + _ = mosq; + const self = @ptrCast(*Self, @alignCast(@alignOf(*Self), self_ptr.?)); std.log.info("connect[{d}]: {s}", .{ self.conf.nth, Mosq.mosquitto_strerror(rc) }); - self.connect_count += 1; if (self.connect_callback) |cb| { cb.*(self.user, self.conf.nth, self.connect_count); } var i: usize = 0; - while(i < self.subs_count) : (i += 1) { - const sub_rc = Mosq.mosquitto_subscribe(self.mosq, null, self.subs[i].ptr, 0); - if(sub_rc != Mosq.MOSQ_ERR_SUCCESS and rc != Mosq.MOSQ_ERR_NO_CONN) { + while (i < self.subs_count) : (i += 1) { + const sub_rc = Mosq.mosquitto_subscribe(mosq, null, self.subs[i].ptr, 0); + if (sub_rc != Mosq.MOSQ_ERR_SUCCESS and rc != Mosq.MOSQ_ERR_NO_CONN) { std.log.err("mosquitto_subscribe[{d}]: {s}", .{ self.conf.nth, Mosq.mosquitto_strerror(sub_rc) }); } } } - fn onDisconnect(_: ?*Mosq.mosquitto, self_ptr: ?*c_void, rc: c_int) callconv(.C) void { - const self = @intToPtr(*Self, @ptrToInt(self_ptr orelse unreachable)); + fn onDisconnect(mosq: ?*Mosq.mosquitto, self_ptr: ?*anyopaque, rc: c_int) callconv(.C) void { + _ = mosq; + const self = @ptrCast(*Self, @alignCast(@alignOf(*Self), self_ptr.?)); std.log.info("disconnect[{d}]: {s}", .{ self.conf.nth, Mosq.mosquitto_strerror(rc) }); if (self.disconnect_callback) |cb| { cb.*(self.user, self.conf.nth, self.connect_count); } } - fn onSubscribe(_: ?*Mosq.mosquitto, self_ptr: ?*c_void, _mid: c_int, qos_len: c_int, qos_arr: [*c]const c_int) callconv(.C) void { - const self = @intToPtr(*Self, @ptrToInt(self_ptr orelse unreachable)); + fn onSubscribe(mosq: ?*Mosq.mosquitto, self_ptr: ?*anyopaque, mid: c_int, qos_len: c_int, qos_arr: [*c]const c_int) callconv(.C) void { + _ = mosq; + _ = mid; + const self = @ptrCast(*Self, @alignCast(@alignOf(*Self), self_ptr.?)); const qos = qos_arr[0..@intCast(usize, qos_len)]; - for(qos) |q| if(q != 0) std.log.warn("subscribe[{d}]: {d}", .{ self.conf.nth, 1 }); + for (qos) |q| if (q != 0) std.log.warn("subscribe[{d}]: {d}", .{ self.conf.nth, q }); self.subscribe_cond.broadcast(); } - fn onMessage(_: ?*Mosq.mosquitto, self_ptr: ?*c_void, msg: [*c]const Mosq.mosquitto_message) callconv(.C) void { + fn onMessage(mosq: ?*Mosq.mosquitto, self_ptr: ?*anyopaque, msg: [*c]const Mosq.mosquitto_message) callconv(.C) void { + _ = mosq; + const self = @ptrCast(*Self, @alignCast(@alignOf(*Self), self_ptr.?)); const topic = msg.*.topic[0..std.mem.len(msg.*.topic)]; // std.log.info("message: {s}", .{topic}); - const self = @intToPtr(*Self, @ptrToInt(self_ptr orelse unreachable)); - const len = @intCast(usize, msg.*.payloadlen); - const payload = @ptrCast([*c]const u8, msg.*.payload)[0..len]; + const payload = @ptrCast([*]u8, msg.*.payload.?)[0..@intCast(usize, msg.*.payloadlen)]; self.msg_callback.*(self.user, topic, payload); } - fn onLog(_: ?*Mosq.mosquitto, _: ?*c_void, level: c_int, msg: [*c]const u8) callconv(.C) void { - std.log.info("mosq({d}): {s}", .{level, msg}); + fn onLog(mosq: *Mosq.mosquitto, self: *Self, level: c_int, msg: [*]const u8) callconv(.C) void { + _ = mosq; + _ = self; + std.log.info("mosq({d}): {s}", .{ level, msg }); } pub fn subscribe(self: *Self, topic: [:0]u8, persistent: bool) !void { - if(persistent) { + if (persistent) { if (self.subs_count < self.subs.len) { self.subs[self.subs_count] = topic; self.subs_count += 1; } else { - std.log.err("subscribe[{d}]: persistent subscription list is full", .{ self.conf.nth }); + std.log.err("subscribe[{d}]: persistent subscription list is full", .{self.conf.nth}); return Error.SubscribeFailed; } } else { const rc = Mosq.mosquitto_subscribe(self.mosq, null, topic.ptr, 0); - if(rc != Mosq.MOSQ_ERR_SUCCESS and rc != Mosq.MOSQ_ERR_NO_CONN) { + if (rc != Mosq.MOSQ_ERR_SUCCESS and rc != Mosq.MOSQ_ERR_NO_CONN) { std.log.err("mosquitto_subscribe[{d}]: {s}", .{ self.conf.nth, Mosq.mosquitto_strerror(rc) }); return Error.SubscribeFailed; } } } - // NOTE: topic must be null-terminated pub fn unsubscribe(self: *Self, topic: [:0]u8) void { const rc = Mosq.mosquitto_unsubscribe(self.mosq, null, topic.ptr); - if(rc != Mosq.MOSQ_ERR_SUCCESS) { + if (rc != Mosq.MOSQ_ERR_SUCCESS) { std.log.err("mosquitto_unsubscribe[{d}]: {s}", .{ self.conf.nth, Mosq.mosquitto_strerror(rc) }); } } - // NOTE: topic must be null-terminated pub fn publish(self: *Self, topic: [:0]u8, msg: []u8) bool { const rc = Mosq.mosquitto_publish(self.mosq, null, topic.ptr, @intCast(c_int, msg.len), msg.ptr, 0, false); return rc == Mosq.MOSQ_ERR_SUCCESS; @@ -226,22 +237,24 @@ pub fn Client(comptime T: type) type { pub fn Mqtt(comptime T: type) type { return struct { const Self = @This(); - - arena: ArenaAllocator, + alloc: Allocator, clients: []*Client(T), - pub fn init(alloc: *Allocator, conf: *const Config, user: T, handler: PacketHandler(T), max_subs: usize) !*Self { - var arena = ArenaAllocator.init(alloc); - const self = try arena.allocator.create(Self); + pub fn init(alloc: Allocator, conf: *const Config, user: T, handler: PacketHandler(T), max_subs: usize) !*Self { + const self = try alloc.create(Self); + self.alloc = alloc; errdefer self.deinit(); - self.arena = arena; - self.clients = try arena.allocator.alloc(*Client(T), conf.mqtt.brokers.len); - + self.clients = try alloc.alloc(*Client(T), conf.mqtt.brokers.len); for (conf.mqtt.brokers) |broker, idx| { const opts = broker.options orelse conf.mqtt.options; // idx, broker.host, broker.port, opts - self.clients[idx] = try Client(T).init(&arena.allocator, .{ .nth = idx, .host = broker.host, .port = broker.port, .opts = opts }, user, handler, max_subs); + self.clients[idx] = try Client(T).init(alloc, .{ + .nth = idx, + .host = broker.host, + .port = broker.port, + .opts = opts + }, user, handler, max_subs); } return self; @@ -249,7 +262,7 @@ pub fn Mqtt(comptime T: type) type { pub fn deinit(self: *Self) void { for (self.clients) |client| client.deinit(); - self.arena.deinit(); + self.alloc.free(self.clients); } pub fn setConnectCallback(self: *Self, cb: ConnectHandler(T)) void { diff --git a/src/server.zig b/src/server.zig index d507706..da07080 100644 --- a/src/server.zig +++ b/src/server.zig @@ -19,64 +19,62 @@ pub const Server = struct { const IpCache = std.AutoHashMap(u128, u32); const IdCache = std.AutoHashMap(u32, u128); const TopicCache = std.AutoHashMap(u32, [:0]u8); - const Error = error { - ConfigMissing - }; + const Error = error{ConfigMissing}; arena: ArenaAllocator, + alloc: Allocator, ifce: ?*NetInterface(*Self), mqtt: ?*Mqtt(*Self), - + pool_start: Ip4Address, pool_end: Ip4Address, pool_next_alloc: u32, - + id_len: u8, ip_cache: IpCache, id_cache: IdCache, - + b64_len: usize, topic: []const u8, topic_cache: TopicCache, - pub fn init(parent_alloc: *Allocator, conf: *const Config) !*Self { + pub fn init(parent_alloc: Allocator, conf: *const Config) !*Self { const server_conf = conf.server orelse { std.log.err("missing server config", .{}); return Error.ConfigMissing; }; - var arena = ArenaAllocator.init(parent_alloc); - const self = try arena.allocator.create(Self); - self.arena = arena; + const self = try parent_alloc.create(Self); + self.arena = ArenaAllocator.init(parent_alloc); + self.alloc = self.arena.allocator(); self.ifce = null; self.mqtt = null; errdefer self.deinit(); - const alloc = &self.arena.allocator; - + self.pool_start = try Ip4Address.parse(server_conf.pool_start, 0); self.pool_end = try Ip4Address.parse(server_conf.pool_end, 0); self.pool_next_alloc = self.pool_start.sa.addr; self.id_len = server_conf.id_length; - self.ip_cache = IpCache.init(alloc); - self.id_cache = IdCache.init(alloc); + self.ip_cache = IpCache.init(self.alloc); + self.id_cache = IdCache.init(self.alloc); self.b64_len = Base64UrlEncoder.calcSize(self.id_len); self.topic = server_conf.topic; - self.topic_cache = TopicCache.init(alloc); + self.topic_cache = TopicCache.init(self.alloc); std.log.info("== Server Config =================================", .{}); std.log.info("ID Length: {d}", .{server_conf.id_length}); std.log.info("Topic: {s}", .{server_conf.topic}); - std.log.info("IP Pool: {s} - {s}", .{server_conf.pool_start, server_conf.pool_end}); - self.ifce = try NetInterface(*Self).init(alloc, conf, self, @ptrCast(driver.PacketHandler(*Self), &up)); - self.mqtt = try Mqtt(*Self).init(alloc, conf, self, @ptrCast(mqtt.PacketHandler(*Self), &down), 1); + std.log.info("IP Pool: {s} - {s}", .{ server_conf.pool_start, server_conf.pool_end }); + self.ifce = try NetInterface(*Self).init(self.alloc, conf, self, @ptrCast(driver.PacketHandler(*Self), &up)); + self.mqtt = try Mqtt(*Self).init(self.alloc, conf, self, @ptrCast(mqtt.PacketHandler(*Self), &down), 1); std.log.info("==================================================", .{}); return self; } pub fn deinit(self: *Self) void { - if(self.mqtt) |m| m.deinit(); - if(self.ifce) |i| i.deinit(); + if (self.mqtt) |m| m.deinit(); + if (self.ifce) |i| i.deinit(); self.arena.deinit(); } @@ -86,27 +84,26 @@ pub const Server = struct { } fn allocIp(self: *Self, id: u128) !u32 { - const alloc = &self.arena.allocator; const next_addr = self.pool_next_alloc; self.pool_next_alloc += 1; - if(self.pool_next_alloc > self.pool_end.sa.addr) { + if (self.pool_next_alloc > self.pool_end.sa.addr) { self.pool_next_alloc = self.pool_start.sa.addr; } - if(self.id_cache.fetchRemove(next_addr)) |entry| { - alloc.free(self.b64_cache.fetchRemove(next_addr).?.value); + if (self.id_cache.fetchRemove(next_addr)) |entry| { + self.alloc.free(self.b64_cache.fetchRemove(next_addr).?.value); _ = self.ip_cache.remove(entry.value); } try self.ip_cache.put(id, next_addr); try self.id_cache.put(next_addr, id); - - var b64_id = try alloc.alloc(u8, self.b64_len); - defer alloc.free(b64_id); + + var b64_id = try self.alloc.alloc(u8, self.b64_len); + defer self.alloc.free(b64_id); var id_bytes = std.mem.toBytes(id); id_bytes.len = self.id_len; Base64UrlEncoder.encode(b64_id, id_bytes); - const topic = try std.fmt.allocPrint(alloc, "{s}/{s}", .{self.topic, b64_id}); + const topic = try std.fmt.allocPrint(self.alloc, "{s}/{s}", .{ self.topic, b64_id }); try self.topic_cache.put(next_addr, topic); return next_addr; @@ -114,7 +111,7 @@ pub const Server = struct { fn up(self: *Self, pkt: []u8) void { const hdr = @ptrCast(*IpHeader, pkt); - if(self.topic_cache.get(hdr.dst)) |topic| { + if (self.topic_cache.get(hdr.dst)) |topic| { self.mqtt.?.send(topic, pkt) catch |err| { std.log.warn("up: {s}", .{err}); }; @@ -122,22 +119,24 @@ pub const Server = struct { } fn down(self: *Self, topic: []const u8, msg: []u8) void { + _ = topic; var id: u128 = 0; std.mem.copy(u8, std.mem.toBytes(id)[0..], msg[0..self.id_len]); - if(self.ip_cache.get(id)) |addr| { + if (self.ip_cache.get(id)) |addr| { self.ifce.?.inject(addr, msg[self.id_len..]) catch |err| { std.log.warn("down: {s}", .{err}); }; } } - }; pub fn main() !void { var gpa = std.heap.GeneralPurposeAllocator(.{}){}; - const alloc = &gpa.allocator; + var alloc = gpa.allocator(); const conf_path = try std.fs.cwd().realpathAlloc(alloc, "zika_config.json"); const conf = try config.get(alloc, conf_path); const server = try Server.init(alloc, &conf); + defer server.deinit(); + defer alloc.destroy(server); try server.run(); }