Skip to content
This repository was archived by the owner on May 29, 2024. It is now read-only.

Commit 8a9f8c4

Browse files
Allowed floats in timedelta
1 parent 3dd32ff commit 8a9f8c4

File tree

3 files changed

+146
-80
lines changed

3 files changed

+146
-80
lines changed

stdlib_extensions/builtins/__init__.mojo

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ from ._bytes import bytes, to_bytes
33
from ..syscalls.filesystem import read_from_stdin
44
from ._hash import custom_hash
55
from ._types import Optional
6-
from ._math import divmod
6+
from ._math import divmod, modf
77
from ._custom_equality import ___eq__
88

99

stdlib_extensions/builtins/_math.mojo

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,8 @@ fn divmod(a: Int, b: Int) -> Tuple[Int, Int]:
44

55
fn divmod(a: Int64, b: Int64) -> Tuple[Int64, Int64]:
66
return a // b, a % b
7+
8+
9+
fn modf(x: Float64) -> Tuple[Float64, Float64]:
10+
var floor = math.trunc(x)
11+
return (x - floor, floor)

stdlib_extensions/datetime/_timedelta.mojo

Lines changed: 140 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,68 @@
1-
from ..builtins import divmod, list
1+
from ..builtins import divmod, list, modf
22
from ..builtins.string import rjust, join
33
from ..builtins._generic_list import _cmp_list
44
from ..builtins import custom_hash
55
from utils.variant import Variant
66
from math import abs, round
7+
from .._utils import custom_debug_assert
78

89
# TODO: use this in the timedelta constructor
9-
alias IntOrFloat = Variant[Int, Float64]
10+
alias IntOrFloatVariant = Variant[Int, Float64]
11+
12+
13+
struct IntOrFloat:
14+
"""We define this to be able to do operation without worrying too much about type conversions.
15+
"""
16+
17+
var value: IntOrFloatVariant
18+
19+
fn __init__(inout self, value: IntOrFloatVariant):
20+
self.value = value
21+
22+
fn __init__(inout self, value: Int):
23+
self.value = value
24+
25+
fn __init__(inout self, value: Float64):
26+
self.value = value
27+
28+
fn __init__(inout self, value: Float32):
29+
self.value = IntOrFloatVariant(value.cast[DType.float64]())
30+
31+
fn to_float(self) -> Float64:
32+
if self.value.isa[Int]():
33+
return Float64(self.value.get[Int]()[])
34+
else:
35+
return self.value.get[Float64]()[]
36+
37+
fn to_int(self) -> Int:
38+
custom_debug_assert(self.value.isa[Int](), "We should have an int here")
39+
return self.value.get[Int]()[]
40+
41+
fn isfloat(self) -> Bool:
42+
return self.value.isa[Float64]()
43+
44+
fn isint(self) -> Bool:
45+
return self.value.isa[Int]()
46+
47+
fn __mul__(self, other: Int) -> IntOrFloat:
48+
if self.value.isa[Int]():
49+
return IntOrFloat(self.value.get[Int]()[] * other)
50+
else:
51+
return IntOrFloat(self.value.get[Float64]()[] * Float64(other))
52+
53+
fn __iadd__(inout self, other: IntOrFloat):
54+
if self.value.isa[Int]() and other.value.isa[Int]():
55+
self.value.set[Int](self.value.get[Int]()[] + other.value.get[Int]()[])
56+
else:
57+
# we upgrade to float
58+
self.value.set[Float64](self.to_float() + other.to_float())
59+
60+
fn __add__(self, other: IntOrFloat) -> IntOrFloat:
61+
if self.value.isa[Int]() and other.value.isa[Int]():
62+
return IntOrFloat(self.value.get[Int]()[] + other.value.get[Int]()[])
63+
else:
64+
# we upgrade to float
65+
return IntOrFloat(self.to_float() + other.to_float())
1066

1167

1268
@value
@@ -44,13 +100,13 @@ struct timedelta(CollectionElement, Stringable, Hashable):
44100

45101
fn __init__(
46102
inout self,
47-
owned days: Int = 0,
48-
owned seconds: Int = 0,
49-
owned microseconds: Int = 0,
50-
milliseconds: Int = 0,
51-
minutes: Int = 0,
52-
hours: Int = 0,
53-
weeks: Int = 0,
103+
owned days: IntOrFloat = 0,
104+
owned seconds: IntOrFloat = 0,
105+
owned microseconds: IntOrFloat = 0,
106+
milliseconds: IntOrFloat = 0,
107+
minutes: IntOrFloat = 0,
108+
hours: IntOrFloat = 0,
109+
weeks: IntOrFloat = 0,
54110
):
55111
# Doing this efficiently and accurately in C is going to be difficult
56112
# and error-prone, due to ubiquitous overflow possibilities, and that
@@ -64,9 +120,9 @@ struct timedelta(CollectionElement, Stringable, Hashable):
64120

65121
# Final values, all integer.
66122
# s and us fit in 32-bit signed ints; d isn't bounded.
67-
var d = 0
68-
var s = 0
69-
var us = 0
123+
var d: Int = 0
124+
var s: Int = 0
125+
var us: Int = 0
70126

71127
# Normalize everything to days, seconds, microseconds.
72128
days += weeks * 7
@@ -75,83 +131,88 @@ struct timedelta(CollectionElement, Stringable, Hashable):
75131

76132
# Get rid of all fractions, and normalize s and us.
77133
# Take a deep breath <wink>.
78-
# if isinstance(days, float):
79-
# dayfrac, days = _math.modf(days)
80-
# daysecondsfrac, daysecondswhole = _math.modf(dayfrac * (24.*3600.))
81-
# assert daysecondswhole == int(daysecondswhole) # can't overflow
82-
# s = int(daysecondswhole)
83-
# assert days == int(days)
84-
# d = int(days)
85-
# else:
86-
var daysecondsfrac = 0.0
87-
d = days
88-
# TODO: manage floats
89-
90-
# assert isinstance(daysecondsfrac, float)
91-
# assert abs(daysecondsfrac) <= 1.0
92-
# assert isinstance(d, int)
93-
# assert abs(s) <= 24 * 3600
94-
# days isn't referenced again before redefinition
134+
var daysecondsfrac: Float64
135+
if days.isfloat():
136+
var dayfrac: Float64
137+
var days_as_int_still_float: Float64
138+
dayfrac, days_as_int_still_float = modf(days.to_float())
139+
var daysecondswhole: Float64
140+
daysecondsfrac, daysecondswhole = modf(dayfrac * (24.0 * 3600.0))
141+
s = int(daysecondswhole)
142+
d = int(days_as_int_still_float)
143+
else:
144+
daysecondsfrac = 0.0
145+
d = days.to_int()
95146

96-
# if isinstance(seconds, float):
97-
# secondsfrac, seconds = _math.modf(seconds)
98-
# assert seconds == int(seconds)
99-
# seconds = int(seconds)
100-
# secondsfrac += daysecondsfrac
101-
# assert abs(secondsfrac) <= 2.0
102-
# else:
103-
var secondsfrac = daysecondsfrac
104-
# TODO: Manage floats
147+
custom_debug_assert(abs(daysecondsfrac) <= 1.0)
148+
custom_debug_assert(abs(s) <= 24 * 3600)
149+
# days isn't referenced again before redefinition
150+
var secondsfrac: Float64
151+
152+
if seconds.isfloat():
153+
var seconds_as_int_still_float: Float64
154+
secondsfrac, seconds_as_int_still_float = modf(seconds.to_float())
155+
seconds = int(seconds_as_int_still_float)
156+
secondsfrac += daysecondsfrac
157+
custom_debug_assert(abs(secondsfrac) <= 2.0)
158+
else:
159+
secondsfrac = daysecondsfrac
105160

106161
# daysecondsfrac isn't referenced again
107-
# assert isinstance(secondsfrac, float)
108-
# assert abs(secondsfrac) <= 2.0
109-
110-
# assert isinstance(seconds, int)
111-
days, seconds = divmod(seconds, 24 * 3600)
112-
d += days
113-
s += int(seconds) # can't overflow
114-
# assert isinstance(s, int)
115-
# assert abs(s) <= 2 * 24 * 3600
162+
custom_debug_assert(abs(secondsfrac) <= 2.0)
163+
164+
custom_debug_assert(seconds.isint())
165+
var additional_days: Int
166+
var additional_seconds: Int
167+
additional_days, additional_seconds = divmod(seconds.to_int(), 24 * 3600)
168+
d += additional_days
169+
s += additional_seconds # can't overflow
170+
custom_debug_assert(abs(s) <= 2 * 24 * 3600)
116171
# seconds isn't referenced again before redefinition
117172

118173
var usdouble = secondsfrac * 1e6
119-
# assert abs(usdouble) < 2.1e6 # exact value not critical
174+
custom_debug_assert(abs(usdouble) < 2.1e6) # exact value not critical
120175
# secondsfrac isn't referenced again
121-
122-
# if isinstance(microseconds, float):
123-
# microseconds = round(microseconds + usdouble)
124-
# seconds, microseconds = divmod(microseconds, 1000000)
125-
# days, seconds = divmod(seconds, 24*3600)
126-
# d += days
127-
# s += seconds
128-
# else:
129-
microseconds = int(microseconds)
130-
seconds, microseconds = divmod(microseconds, 1000000)
131-
days, seconds = divmod(seconds, 24 * 3600)
132-
d += days
133-
s += seconds
134-
microseconds = round(Float64(microseconds) + usdouble).to_int()
176+
var additional_microseconds: Int
177+
if microseconds.isfloat():
178+
var microseconds_as_int = int(round(microseconds.to_float() + usdouble))
179+
var additional_seconds: Int
180+
var additional_days: Int
181+
additional_seconds, additional_microseconds = divmod(
182+
microseconds_as_int, 1000000
183+
)
184+
additional_days, additional_seconds = divmod(additional_seconds, 24 * 3600)
185+
d += additional_days
186+
s += additional_seconds
187+
else:
188+
var additional_seconds: Int
189+
var additional_days: Int
190+
additional_microseconds = microseconds.to_int()
191+
additional_seconds, additional_microseconds = divmod(
192+
additional_microseconds, 1000000
193+
)
194+
additional_days, additional_seconds = divmod(additional_seconds, 24 * 3600)
195+
d += additional_days
196+
s += additional_seconds
197+
additional_microseconds = round(additional_microseconds + usdouble).to_int()
135198
# TODO: Manage floats
136-
# assert isinstance(s, int)
137-
# assert isinstance(microseconds, int)
138-
# assert abs(s) <= 3 * 24 * 3600
139-
# assert abs(microseconds) < 3.1e6
199+
custom_debug_assert(abs(s) <= 3 * 24 * 3600)
200+
custom_debug_assert(abs(additional_microseconds) < 3100000)
140201

141202
# Just a little bit of carrying possible for microseconds and seconds.
142-
seconds, us = divmod(microseconds, 1000000)
143-
s += seconds
144-
days, s = divmod(s, 24 * 3600)
145-
d += days
146-
147-
# assert isinstance(d, int)
148-
# assert isinstance(s, int) and 0 <= s < 24 * 3600
149-
# assert isinstance(us, int) and 0 <= us < 1000000
150-
151-
if abs(d) > 999999999:
152-
pass
153-
# raise OverflowError("timedelta # of days is too large: %d" % d)
154-
203+
var additional_seconds2: Int
204+
additional_seconds2, us = divmod(additional_microseconds, 1000000)
205+
s += additional_seconds2
206+
var additional_days2: Int
207+
additional_days2, s = divmod(s, 24 * 3600)
208+
d += additional_days2
209+
210+
custom_debug_assert(0 <= s < 24 * 3600)
211+
custom_debug_assert(0 <= us < 1000000)
212+
213+
custom_debug_assert(
214+
abs(d) < 999999999, "timedelta # of days is too large: " + str(d)
215+
)
155216
self.days = d
156217
self.seconds = s
157218
self.microseconds = us

0 commit comments

Comments
 (0)