From 8a9f8c4f37b8a0502de515f2e7e07f6fa2c487eb Mon Sep 17 00:00:00 2001 From: gabrieldemarmiesse Date: Mon, 11 Mar 2024 10:51:06 +0000 Subject: [PATCH] Allowed floats in timedelta --- stdlib_extensions/builtins/__init__.mojo | 2 +- stdlib_extensions/builtins/_math.mojo | 5 + stdlib_extensions/datetime/_timedelta.mojo | 219 +++++++++++++-------- 3 files changed, 146 insertions(+), 80 deletions(-) diff --git a/stdlib_extensions/builtins/__init__.mojo b/stdlib_extensions/builtins/__init__.mojo index fdd96bf..7ceb4f8 100644 --- a/stdlib_extensions/builtins/__init__.mojo +++ b/stdlib_extensions/builtins/__init__.mojo @@ -3,7 +3,7 @@ from ._bytes import bytes, to_bytes from ..syscalls.filesystem import read_from_stdin from ._hash import custom_hash from ._types import Optional -from ._math import divmod +from ._math import divmod, modf from ._custom_equality import ___eq__ diff --git a/stdlib_extensions/builtins/_math.mojo b/stdlib_extensions/builtins/_math.mojo index 0d53e03..8153d2c 100644 --- a/stdlib_extensions/builtins/_math.mojo +++ b/stdlib_extensions/builtins/_math.mojo @@ -4,3 +4,8 @@ fn divmod(a: Int, b: Int) -> Tuple[Int, Int]: fn divmod(a: Int64, b: Int64) -> Tuple[Int64, Int64]: return a // b, a % b + + +fn modf(x: Float64) -> Tuple[Float64, Float64]: + var floor = math.trunc(x) + return (x - floor, floor) diff --git a/stdlib_extensions/datetime/_timedelta.mojo b/stdlib_extensions/datetime/_timedelta.mojo index 736f03a..9933069 100644 --- a/stdlib_extensions/datetime/_timedelta.mojo +++ b/stdlib_extensions/datetime/_timedelta.mojo @@ -1,12 +1,68 @@ -from ..builtins import divmod, list +from ..builtins import divmod, list, modf from ..builtins.string import rjust, join from ..builtins._generic_list import _cmp_list from ..builtins import custom_hash from utils.variant import Variant from math import abs, round +from .._utils import custom_debug_assert # TODO: use this in the timedelta constructor -alias IntOrFloat = Variant[Int, Float64] +alias IntOrFloatVariant = Variant[Int, Float64] + + +struct IntOrFloat: + """We define this to be able to do operation without worrying too much about type conversions. + """ + + var value: IntOrFloatVariant + + fn __init__(inout self, value: IntOrFloatVariant): + self.value = value + + fn __init__(inout self, value: Int): + self.value = value + + fn __init__(inout self, value: Float64): + self.value = value + + fn __init__(inout self, value: Float32): + self.value = IntOrFloatVariant(value.cast[DType.float64]()) + + fn to_float(self) -> Float64: + if self.value.isa[Int](): + return Float64(self.value.get[Int]()[]) + else: + return self.value.get[Float64]()[] + + fn to_int(self) -> Int: + custom_debug_assert(self.value.isa[Int](), "We should have an int here") + return self.value.get[Int]()[] + + fn isfloat(self) -> Bool: + return self.value.isa[Float64]() + + fn isint(self) -> Bool: + return self.value.isa[Int]() + + fn __mul__(self, other: Int) -> IntOrFloat: + if self.value.isa[Int](): + return IntOrFloat(self.value.get[Int]()[] * other) + else: + return IntOrFloat(self.value.get[Float64]()[] * Float64(other)) + + fn __iadd__(inout self, other: IntOrFloat): + if self.value.isa[Int]() and other.value.isa[Int](): + self.value.set[Int](self.value.get[Int]()[] + other.value.get[Int]()[]) + else: + # we upgrade to float + self.value.set[Float64](self.to_float() + other.to_float()) + + fn __add__(self, other: IntOrFloat) -> IntOrFloat: + if self.value.isa[Int]() and other.value.isa[Int](): + return IntOrFloat(self.value.get[Int]()[] + other.value.get[Int]()[]) + else: + # we upgrade to float + return IntOrFloat(self.to_float() + other.to_float()) @value @@ -44,13 +100,13 @@ struct timedelta(CollectionElement, Stringable, Hashable): fn __init__( inout self, - owned days: Int = 0, - owned seconds: Int = 0, - owned microseconds: Int = 0, - milliseconds: Int = 0, - minutes: Int = 0, - hours: Int = 0, - weeks: Int = 0, + owned days: IntOrFloat = 0, + owned seconds: IntOrFloat = 0, + owned microseconds: IntOrFloat = 0, + milliseconds: IntOrFloat = 0, + minutes: IntOrFloat = 0, + hours: IntOrFloat = 0, + weeks: IntOrFloat = 0, ): # Doing this efficiently and accurately in C is going to be difficult # and error-prone, due to ubiquitous overflow possibilities, and that @@ -64,9 +120,9 @@ struct timedelta(CollectionElement, Stringable, Hashable): # Final values, all integer. # s and us fit in 32-bit signed ints; d isn't bounded. - var d = 0 - var s = 0 - var us = 0 + var d: Int = 0 + var s: Int = 0 + var us: Int = 0 # Normalize everything to days, seconds, microseconds. days += weeks * 7 @@ -75,83 +131,88 @@ struct timedelta(CollectionElement, Stringable, Hashable): # Get rid of all fractions, and normalize s and us. # Take a deep breath . - # if isinstance(days, float): - # dayfrac, days = _math.modf(days) - # daysecondsfrac, daysecondswhole = _math.modf(dayfrac * (24.*3600.)) - # assert daysecondswhole == int(daysecondswhole) # can't overflow - # s = int(daysecondswhole) - # assert days == int(days) - # d = int(days) - # else: - var daysecondsfrac = 0.0 - d = days - # TODO: manage floats - - # assert isinstance(daysecondsfrac, float) - # assert abs(daysecondsfrac) <= 1.0 - # assert isinstance(d, int) - # assert abs(s) <= 24 * 3600 - # days isn't referenced again before redefinition + var daysecondsfrac: Float64 + if days.isfloat(): + var dayfrac: Float64 + var days_as_int_still_float: Float64 + dayfrac, days_as_int_still_float = modf(days.to_float()) + var daysecondswhole: Float64 + daysecondsfrac, daysecondswhole = modf(dayfrac * (24.0 * 3600.0)) + s = int(daysecondswhole) + d = int(days_as_int_still_float) + else: + daysecondsfrac = 0.0 + d = days.to_int() - # if isinstance(seconds, float): - # secondsfrac, seconds = _math.modf(seconds) - # assert seconds == int(seconds) - # seconds = int(seconds) - # secondsfrac += daysecondsfrac - # assert abs(secondsfrac) <= 2.0 - # else: - var secondsfrac = daysecondsfrac - # TODO: Manage floats + custom_debug_assert(abs(daysecondsfrac) <= 1.0) + custom_debug_assert(abs(s) <= 24 * 3600) + # days isn't referenced again before redefinition + var secondsfrac: Float64 + + if seconds.isfloat(): + var seconds_as_int_still_float: Float64 + secondsfrac, seconds_as_int_still_float = modf(seconds.to_float()) + seconds = int(seconds_as_int_still_float) + secondsfrac += daysecondsfrac + custom_debug_assert(abs(secondsfrac) <= 2.0) + else: + secondsfrac = daysecondsfrac # daysecondsfrac isn't referenced again - # assert isinstance(secondsfrac, float) - # assert abs(secondsfrac) <= 2.0 - - # assert isinstance(seconds, int) - days, seconds = divmod(seconds, 24 * 3600) - d += days - s += int(seconds) # can't overflow - # assert isinstance(s, int) - # assert abs(s) <= 2 * 24 * 3600 + custom_debug_assert(abs(secondsfrac) <= 2.0) + + custom_debug_assert(seconds.isint()) + var additional_days: Int + var additional_seconds: Int + additional_days, additional_seconds = divmod(seconds.to_int(), 24 * 3600) + d += additional_days + s += additional_seconds # can't overflow + custom_debug_assert(abs(s) <= 2 * 24 * 3600) # seconds isn't referenced again before redefinition var usdouble = secondsfrac * 1e6 - # assert abs(usdouble) < 2.1e6 # exact value not critical + custom_debug_assert(abs(usdouble) < 2.1e6) # exact value not critical # secondsfrac isn't referenced again - - # if isinstance(microseconds, float): - # microseconds = round(microseconds + usdouble) - # seconds, microseconds = divmod(microseconds, 1000000) - # days, seconds = divmod(seconds, 24*3600) - # d += days - # s += seconds - # else: - microseconds = int(microseconds) - seconds, microseconds = divmod(microseconds, 1000000) - days, seconds = divmod(seconds, 24 * 3600) - d += days - s += seconds - microseconds = round(Float64(microseconds) + usdouble).to_int() + var additional_microseconds: Int + if microseconds.isfloat(): + var microseconds_as_int = int(round(microseconds.to_float() + usdouble)) + var additional_seconds: Int + var additional_days: Int + additional_seconds, additional_microseconds = divmod( + microseconds_as_int, 1000000 + ) + additional_days, additional_seconds = divmod(additional_seconds, 24 * 3600) + d += additional_days + s += additional_seconds + else: + var additional_seconds: Int + var additional_days: Int + additional_microseconds = microseconds.to_int() + additional_seconds, additional_microseconds = divmod( + additional_microseconds, 1000000 + ) + additional_days, additional_seconds = divmod(additional_seconds, 24 * 3600) + d += additional_days + s += additional_seconds + additional_microseconds = round(additional_microseconds + usdouble).to_int() # TODO: Manage floats - # assert isinstance(s, int) - # assert isinstance(microseconds, int) - # assert abs(s) <= 3 * 24 * 3600 - # assert abs(microseconds) < 3.1e6 + custom_debug_assert(abs(s) <= 3 * 24 * 3600) + custom_debug_assert(abs(additional_microseconds) < 3100000) # Just a little bit of carrying possible for microseconds and seconds. - seconds, us = divmod(microseconds, 1000000) - s += seconds - days, s = divmod(s, 24 * 3600) - d += days - - # assert isinstance(d, int) - # assert isinstance(s, int) and 0 <= s < 24 * 3600 - # assert isinstance(us, int) and 0 <= us < 1000000 - - if abs(d) > 999999999: - pass - # raise OverflowError("timedelta # of days is too large: %d" % d) - + var additional_seconds2: Int + additional_seconds2, us = divmod(additional_microseconds, 1000000) + s += additional_seconds2 + var additional_days2: Int + additional_days2, s = divmod(s, 24 * 3600) + d += additional_days2 + + custom_debug_assert(0 <= s < 24 * 3600) + custom_debug_assert(0 <= us < 1000000) + + custom_debug_assert( + abs(d) < 999999999, "timedelta # of days is too large: " + str(d) + ) self.days = d self.seconds = s self.microseconds = us