Skip to content

Commit 064d7f6

Browse files
authored
Merge pull request #87 from thatstoasty/udp
UDP Socket support
2 parents b41e88e + bd14fc5 commit 064d7f6

File tree

10 files changed

+685
-32
lines changed

10 files changed

+685
-32
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@ jobs:
1717
magic run test
1818
magic run integration_tests_py
1919
magic run integration_tests_external
20+
magic run integration_tests_udp

.gitignore

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@ install_id
1414
# Rattler
1515
output
1616

17+
# integration tests
18+
udp_client.DSYM
19+
udp_server.DSYM
20+
__pycache__
21+
1722
# misc
1823
.vscode
19-
20-
__pycache__

lightbug_http/libc.mojo

Lines changed: 269 additions & 1 deletion
Large diffs are not rendered by default.

lightbug_http/net.mojo

Lines changed: 254 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ from lightbug_http.libc import (
2020
AF_INET,
2121
AF_INET6,
2222
SOCK_STREAM,
23+
SOCK_DGRAM,
2324
SOL_SOCKET,
2425
SO_REUSEADDR,
2526
SO_REUSEPORT,
@@ -71,7 +72,7 @@ trait Connection(Movable):
7172
fn teardown(mut self) raises:
7273
...
7374

74-
fn local_addr(mut self) -> TCPAddr:
75+
fn local_addr(self) -> TCPAddr:
7576
...
7677

7778
fn remote_addr(self) -> TCPAddr:
@@ -135,12 +136,8 @@ struct ListenConfig:
135136
fn __init__(out self, keep_alive: Duration = default_tcp_keep_alive):
136137
self._keep_alive = keep_alive
137138

138-
fn listen[network: NetworkType, address_family: Int = AF_INET](mut self, address: String) raises -> NoTLSListener:
139+
fn listen[address_family: Int = AF_INET](mut self, address: String) raises -> NoTLSListener:
139140
constrained[address_family in [AF_INET, AF_INET6], "Address family must be either AF_INET or AF_INET6."]()
140-
constrained[
141-
network in NetworkType.SUPPORTED_TYPES,
142-
"Unsupported network type for internet address resolution. Unix addresses are not supported yet.",
143-
]()
144141
var local = parse_address(address)
145142
var addr = TCPAddr(local[0], local[1])
146143
var socket: Socket[TCPAddr]
@@ -196,18 +193,18 @@ struct ListenConfig:
196193
return listener^
197194

198195

199-
struct TCPConnection(Connection):
196+
struct TCPConnection:
200197
var socket: Socket[TCPAddr]
201198

202-
fn __init__(inout self, owned socket: Socket[TCPAddr]):
199+
fn __init__(out self, owned socket: Socket[TCPAddr]):
203200
self.socket = socket^
204201

205-
fn __moveinit__(inout self, owned existing: Self):
202+
fn __moveinit__(out self, owned existing: Self):
206203
self.socket = existing.socket^
207204

208205
fn read(self, mut buf: Bytes) raises -> Int:
209206
try:
210-
return self.socket.receive_into(buf)
207+
return self.socket.receive(buf)
211208
except e:
212209
if str(e) == "EOF":
213210
raise e
@@ -237,13 +234,101 @@ struct TCPConnection(Connection):
237234
fn is_closed(self) -> Bool:
238235
return self.socket._closed
239236

240-
fn local_addr(mut self) -> TCPAddr:
237+
# TODO: Switch to property or return ref when trait supports attributes.
238+
fn local_addr(self) -> TCPAddr:
241239
return self.socket.local_address()
242240

243241
fn remote_addr(self) -> TCPAddr:
244242
return self.socket.remote_address()
245243

246244

245+
struct UDPConnection:
246+
var socket: Socket[UDPAddr]
247+
248+
fn __init__(out self, owned socket: Socket[UDPAddr]):
249+
self.socket = socket^
250+
251+
fn __moveinit__(out self, owned existing: Self):
252+
self.socket = existing.socket^
253+
254+
fn read_from(mut self, size: Int = default_buffer_size) raises -> (Bytes, String, UInt16):
255+
"""Reads data from the underlying file descriptor.
256+
257+
Args:
258+
size: The size of the buffer to read data into.
259+
260+
Returns:
261+
The number of bytes read, or an error if one occurred.
262+
263+
Raises:
264+
Error: If an error occurred while reading data.
265+
"""
266+
return self.socket.receive_from(size)
267+
268+
fn read_from(mut self, mut dest: Bytes) raises -> (UInt, String, UInt16):
269+
"""Reads data from the underlying file descriptor.
270+
271+
Args:
272+
dest: The buffer to read data into.
273+
274+
Returns:
275+
The number of bytes read, or an error if one occurred.
276+
277+
Raises:
278+
Error: If an error occurred while reading data.
279+
"""
280+
return self.socket.receive_from(dest)
281+
282+
fn write_to(mut self, src: Span[Byte], address: UDPAddr) raises -> Int:
283+
"""Writes data to the underlying file descriptor.
284+
285+
Args:
286+
src: The buffer to read data into.
287+
address: The remote peer address.
288+
289+
Returns:
290+
The number of bytes written, or an error if one occurred.
291+
292+
Raises:
293+
Error: If an error occurred while writing data.
294+
"""
295+
return self.socket.send_to(src, address.ip, address.port)
296+
297+
fn write_to(mut self, src: Span[Byte], host: String, port: UInt16) raises -> Int:
298+
"""Writes data to the underlying file descriptor.
299+
300+
Args:
301+
src: The buffer to read data into.
302+
host: The remote peer address in IPv4 format.
303+
port: The remote peer port.
304+
305+
Returns:
306+
The number of bytes written, or an error if one occurred.
307+
308+
Raises:
309+
Error: If an error occurred while writing data.
310+
"""
311+
return self.socket.send_to(src, host, port)
312+
313+
fn close(mut self) raises:
314+
self.socket.close()
315+
316+
fn shutdown(mut self) raises:
317+
self.socket.shutdown()
318+
319+
fn teardown(mut self) raises:
320+
self.socket.teardown()
321+
322+
fn is_closed(self) -> Bool:
323+
return self.socket._closed
324+
325+
fn local_addr(self) -> ref [self.socket._local_address] UDPAddr:
326+
return self.socket.local_address()
327+
328+
fn remote_addr(self) -> ref [self.socket._remote_address] UDPAddr:
329+
return self.socket.remote_address()
330+
331+
247332
@value
248333
@register_passable("trivial")
249334
struct addrinfo_macos(AnAddrInfo):
@@ -261,12 +346,19 @@ struct addrinfo_macos(AnAddrInfo):
261346
var ai_addr: UnsafePointer[sockaddr]
262347
var ai_next: OpaquePointer
263348

264-
fn __init__(out self, ai_flags: c_int = 0, ai_family: c_int = 0, ai_socktype: c_int = 0, ai_protocol: c_int = 0):
265-
self.ai_flags = 0
266-
self.ai_family = 0
267-
self.ai_socktype = 0
268-
self.ai_protocol = 0
269-
self.ai_addrlen = 0
349+
fn __init__(
350+
out self,
351+
ai_flags: c_int = 0,
352+
ai_family: c_int = 0,
353+
ai_socktype: c_int = 0,
354+
ai_protocol: c_int = 0,
355+
ai_addrlen: socklen_t = 0,
356+
):
357+
self.ai_flags = ai_flags
358+
self.ai_family = ai_family
359+
self.ai_socktype = ai_socktype
360+
self.ai_protocol = ai_protocol
361+
self.ai_addrlen = ai_addrlen
270362
self.ai_canonname = UnsafePointer[c_char]()
271363
self.ai_addr = UnsafePointer[sockaddr]()
272364
self.ai_next = OpaquePointer()
@@ -314,12 +406,19 @@ struct addrinfo_unix(AnAddrInfo):
314406
var ai_canonname: UnsafePointer[c_char]
315407
var ai_next: OpaquePointer
316408

317-
fn __init__(out self, ai_flags: c_int = 0, ai_family: c_int = 0, ai_socktype: c_int = 0, ai_protocol: c_int = 0):
409+
fn __init__(
410+
out self,
411+
ai_flags: c_int = 0,
412+
ai_family: c_int = 0,
413+
ai_socktype: c_int = 0,
414+
ai_protocol: c_int = 0,
415+
ai_addrlen: socklen_t = 0,
416+
):
318417
self.ai_flags = ai_flags
319418
self.ai_family = ai_family
320419
self.ai_socktype = ai_socktype
321420
self.ai_protocol = ai_protocol
322-
self.ai_addrlen = 0
421+
self.ai_addrlen = ai_addrlen
323422
self.ai_addr = UnsafePointer[sockaddr]()
324423
self.ai_canonname = UnsafePointer[c_char]()
325424
self.ai_next = OpaquePointer()
@@ -395,10 +494,10 @@ struct TCPAddr(Addr):
395494
fn network(self) -> String:
396495
return NetworkType.tcp.value
397496

398-
fn __eq__(self, other: TCPAddr) -> Bool:
497+
fn __eq__(self, other: Self) -> Bool:
399498
return self.ip == other.ip and self.port == other.port and self.zone == other.zone
400499

401-
fn __ne__(self, other: TCPAddr) -> Bool:
500+
fn __ne__(self, other: Self) -> Bool:
402501
return not self == other
403502

404503
fn __str__(self) -> String:
@@ -413,6 +512,140 @@ struct TCPAddr(Addr):
413512
writer.write("TCPAddr(", "ip=", repr(self.ip), ", port=", str(self.port), ", zone=", repr(self.zone), ")")
414513

415514

515+
@value
516+
struct UDPAddr(Addr):
517+
alias _type = "UDPAddr"
518+
var ip: String
519+
var port: UInt16
520+
var zone: String # IPv6 addressing zone
521+
522+
fn __init__(out self):
523+
self.ip = "127.0.0.1"
524+
self.port = 8000
525+
self.zone = ""
526+
527+
fn __init__(out self, ip: String = "127.0.0.1", port: UInt16 = 8000):
528+
self.ip = ip
529+
self.port = port
530+
self.zone = ""
531+
532+
fn network(self) -> String:
533+
return NetworkType.udp.value
534+
535+
fn __eq__(self, other: Self) -> Bool:
536+
return self.ip == other.ip and self.port == other.port and self.zone == other.zone
537+
538+
fn __ne__(self, other: Self) -> Bool:
539+
return not self == other
540+
541+
fn __str__(self) -> String:
542+
if self.zone != "":
543+
return join_host_port(self.ip + "%" + self.zone, str(self.port))
544+
return join_host_port(self.ip, str(self.port))
545+
546+
fn __repr__(self) -> String:
547+
return String.write(self)
548+
549+
fn write_to[W: Writer, //](self, mut writer: W):
550+
writer.write("UDPAddr(", "ip=", repr(self.ip), ", port=", str(self.port), ", zone=", repr(self.zone), ")")
551+
552+
553+
fn listen_udp(local_address: UDPAddr) raises -> UDPConnection:
554+
"""Creates a new UDP listener.
555+
556+
Args:
557+
local_address: The local address to listen on.
558+
559+
Returns:
560+
A UDP connection.
561+
562+
Raises:
563+
Error: If the address is invalid or failed to bind the socket.
564+
"""
565+
socket = Socket[UDPAddr](socket_type=SOCK_DGRAM)
566+
socket.bind(local_address.ip, local_address.port)
567+
return UDPConnection(socket^)
568+
569+
570+
fn listen_udp(local_address: String) raises -> UDPConnection:
571+
"""Creates a new UDP listener.
572+
573+
Args:
574+
local_address: The address to listen on. The format is "host:port".
575+
576+
Returns:
577+
A UDP connection.
578+
579+
Raises:
580+
Error: If the address is invalid or failed to bind the socket.
581+
"""
582+
var address = parse_address(local_address)
583+
return listen_udp(UDPAddr(address[0], address[1]))
584+
585+
586+
fn listen_udp(host: String, port: UInt16) raises -> UDPConnection:
587+
"""Creates a new UDP listener.
588+
589+
Args:
590+
host: The address to listen on in ipv4 format.
591+
port: The port number.
592+
593+
Returns:
594+
A UDP connection.
595+
596+
Raises:
597+
Error: If the address is invalid or failed to bind the socket.
598+
"""
599+
return listen_udp(UDPAddr(host, port))
600+
601+
602+
fn dial_udp(local_address: UDPAddr) raises -> UDPConnection:
603+
"""Connects to the address on the named network. The network must be "udp", "udp4", or "udp6".
604+
605+
Args:
606+
local_address: The local address.
607+
608+
Returns:
609+
The UDP connection.
610+
611+
Raises:
612+
Error: If the network type is not supported or failed to connect to the address.
613+
"""
614+
return UDPConnection(Socket(local_address=local_address, socket_type=SOCK_DGRAM))
615+
616+
617+
fn dial_udp(local_address: String) raises -> UDPConnection:
618+
"""Connects to the address on the named network. The network must be "udp", "udp4", or "udp6".
619+
620+
Args:
621+
local_address: The local address.
622+
623+
Returns:
624+
The UDP connection.
625+
626+
Raises:
627+
Error: If the network type is not supported or failed to connect to the address.
628+
"""
629+
var address = parse_address(local_address)
630+
return dial_udp(UDPAddr(address[0], address[1]))
631+
632+
633+
fn dial_udp(host: String, port: UInt16) raises -> UDPConnection:
634+
"""Connects to the address on the named network. The network must be "udp", "udp4", or "udp6".
635+
636+
Args:
637+
host: The host to connect to.
638+
port: The port to connect on.
639+
640+
Returns:
641+
The UDP connection.
642+
643+
Raises:
644+
Error: If the network type is not supported or failed to connect to the address.
645+
"""
646+
return dial_udp(UDPAddr(host, port))
647+
648+
416649
# TODO: Support IPv6 long form.
417650
fn join_host_port(host: String, port: String) -> String:
418651
if host.find(":") != -1: # must be IPv6 literal

lightbug_http/server.mojo

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ from lightbug_http.io.sync import Duration
33
from lightbug_http.io.bytes import Bytes, bytes
44
from lightbug_http.strings import NetworkType
55
from lightbug_http.utils import ByteReader, logger
6-
from lightbug_http.net import NoTLSListener, default_buffer_size, TCPConnection, ListenConfig, TCPAddr
6+
from lightbug_http.net import NoTLSListener, default_buffer_size, TCPConnection, ListenConfig
77
from lightbug_http.socket import Socket
88
from lightbug_http.http import HTTPRequest, encode
99
from lightbug_http.http.common_response import InternalError
@@ -92,7 +92,7 @@ struct Server(Movable):
9292
handler: An object that handles incoming HTTP requests.
9393
"""
9494
var net = ListenConfig()
95-
var listener = net.listen[NetworkType.tcp4](address)
95+
var listener = net.listen(address)
9696
self.set_address(address)
9797
self.serve(listener^, handler)
9898

0 commit comments

Comments
 (0)