Skip to content

Commit 4c57774

Browse files
hovaescodamian3031
andauthored
Map INTERVAL types to Python types
Co-authored-by: Damian Owsianny <[email protected]>
1 parent bac6ae7 commit 4c57774

File tree

2 files changed

+115
-8
lines changed

2 files changed

+115
-8
lines changed

tests/integration/test_types_integration.py

+75-8
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from zoneinfo import ZoneInfo
77

88
import pytest
9+
from dateutil.relativedelta import relativedelta
910

1011
import trino
1112
from tests.integration.conftest import trino_version
@@ -733,14 +734,80 @@ def create_timezone(timezone_str: str) -> tzinfo:
733734
return ZoneInfo(timezone_str)
734735

735736

736-
def test_interval(trino_connection):
737-
SqlTest(trino_connection) \
738-
.add_field(sql="CAST(null AS INTERVAL YEAR TO MONTH)", python=None) \
739-
.add_field(sql="CAST(null AS INTERVAL DAY TO SECOND)", python=None) \
740-
.add_field(sql="INTERVAL '3' MONTH", python='0-3') \
741-
.add_field(sql="INTERVAL '2' DAY", python='2 00:00:00.000') \
742-
.add_field(sql="INTERVAL '-2' DAY", python='-2 00:00:00.000') \
743-
.execute()
737+
def test_interval_year_to_month(trino_connection):
738+
(
739+
SqlTest(trino_connection)
740+
.add_field(
741+
sql="CAST(null AS INTERVAL YEAR TO MONTH)",
742+
python=None)
743+
.add_field(
744+
sql="INTERVAL '10' YEAR",
745+
python=relativedelta(years=10))
746+
.add_field(
747+
sql="INTERVAL '-5' YEAR",
748+
python=relativedelta(years=-5))
749+
.add_field(
750+
sql="INTERVAL '3' MONTH",
751+
python=relativedelta(months=3))
752+
.add_field(
753+
sql="INTERVAL '-18' MONTH",
754+
python=relativedelta(years=-1, months=-6))
755+
.add_field(
756+
sql="INTERVAL '30' MONTH",
757+
python=relativedelta(years=2, months=6))
758+
# max supported INTERVAL in Trino
759+
.add_field(
760+
sql="INTERVAL '178956970-7' YEAR TO MONTH",
761+
python=relativedelta(years=178956970, months=7))
762+
# min supported INTERVAL in Trino
763+
.add_field(
764+
sql="INTERVAL '-178956970-8' YEAR TO MONTH",
765+
python=relativedelta(years=-178956970, months=-8))
766+
).execute()
767+
768+
769+
def test_interval_day_to_second(trino_connection):
770+
(
771+
SqlTest(trino_connection)
772+
.add_field(
773+
sql="CAST(null AS INTERVAL DAY TO SECOND)",
774+
python=None)
775+
.add_field(
776+
sql="INTERVAL '2' DAY",
777+
python=timedelta(days=2))
778+
.add_field(
779+
sql="INTERVAL '-2' DAY",
780+
python=timedelta(days=-2))
781+
.add_field(
782+
sql="INTERVAL '-2' SECOND",
783+
python=timedelta(seconds=-2))
784+
.add_field(
785+
sql="INTERVAL '1 11:11:11.116555' DAY TO SECOND",
786+
python=timedelta(days=1, seconds=40271, microseconds=116000))
787+
.add_field(
788+
sql="INTERVAL '-5 23:59:57.000' DAY TO SECOND",
789+
python=timedelta(days=-6, seconds=3))
790+
.add_field(
791+
sql="INTERVAL '12 10:45' DAY TO MINUTE",
792+
python=timedelta(days=12, seconds=38700))
793+
.add_field(
794+
sql="INTERVAL '45:32.123' MINUTE TO SECOND",
795+
python=timedelta(seconds=2732, microseconds=123000))
796+
.add_field(
797+
sql="INTERVAL '32.123' SECOND",
798+
python=timedelta(seconds=32, microseconds=123000))
799+
# max supported timedelta in Python
800+
.add_field(
801+
sql="INTERVAL '999999999 23:59:59.999' DAY TO SECOND",
802+
python=timedelta(days=999999999, hours=23, minutes=59, seconds=59, milliseconds=999))
803+
# min supported timedelta in Python
804+
.add_field(
805+
sql="INTERVAL '-999999999' DAY",
806+
python=timedelta(days=-999999999))
807+
).execute()
808+
809+
SqlExpectFailureTest(trino_connection).execute("INTERVAL '1000000000' DAY")
810+
SqlExpectFailureTest(trino_connection).execute("INTERVAL '-999999999 00:00:00.001' DAY TO SECOND")
744811

745812

746813
def test_array(trino_connection):

trino/mapper.py

+40
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar
99
from zoneinfo import ZoneInfo
1010

11+
from dateutil.relativedelta import relativedelta
12+
1113
import trino.exceptions
1214
from trino.types import (
1315
POWERS_OF_TEN,
@@ -167,6 +169,40 @@ def _fraction_to_decimal(fractional_str: str) -> Decimal:
167169
return Decimal(fractional_str or 0) / POWERS_OF_TEN[len(fractional_str)]
168170

169171

172+
class IntervalYearToMonthMapper(ValueMapper[relativedelta]):
173+
def map(self, value: Any) -> Optional[relativedelta]:
174+
if value is None:
175+
return None
176+
is_negative = value[0] == "-"
177+
years, months = (value[1:] if is_negative else value).split('-')
178+
years, months = int(years), int(months)
179+
if is_negative:
180+
years, months = -years, -months
181+
return relativedelta(years=years, months=months)
182+
183+
184+
class IntervalDayToSecondMapper(ValueMapper[timedelta]):
185+
def map(self, value: Any) -> Optional[timedelta]:
186+
if value is None:
187+
return None
188+
is_negative = value[0] == "-"
189+
days, time = (value[1:] if is_negative else value).split(' ')
190+
hours, minutes, seconds_milliseconds = time.split(':')
191+
seconds, milliseconds = seconds_milliseconds.split('.')
192+
days, hours, minutes, seconds, milliseconds = (int(days), int(hours), int(minutes), int(seconds),
193+
int(milliseconds))
194+
if is_negative:
195+
days, hours, minutes, seconds, milliseconds = -days, -hours, -minutes, -seconds, -milliseconds
196+
try:
197+
return timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds, milliseconds=milliseconds)
198+
except OverflowError as e:
199+
error_str = (
200+
f"Could not convert '{value}' into the associated python type, as the value "
201+
"exceeds the maximum or minimum limit."
202+
)
203+
raise trino.exceptions.TrinoDataError(error_str) from e
204+
205+
170206
class ArrayValueMapper(ValueMapper[List[Optional[Any]]]):
171207
def __init__(self, mapper: ValueMapper[Any]):
172208
self.mapper = mapper
@@ -271,6 +307,10 @@ def _create_value_mapper(self, column: Dict[str, Any]) -> ValueMapper[Any]:
271307
return TimestampValueMapper(self._get_precision(column))
272308
if col_type == 'timestamp with time zone':
273309
return TimestampWithTimeZoneValueMapper(self._get_precision(column))
310+
if col_type == 'interval year to month':
311+
return IntervalYearToMonthMapper()
312+
if col_type == 'interval day to second':
313+
return IntervalDayToSecondMapper()
274314

275315
# structural types
276316
if col_type == 'array':

0 commit comments

Comments
 (0)