|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
| 3 | +import os |
3 | 4 | from datetime import datetime
|
4 | 5 | from typing import Any
|
| 6 | +from unittest import mock |
5 | 7 |
|
6 | 8 | import polars
|
7 | 9 | import pytest
|
@@ -71,63 +73,62 @@ def test_q1(df_raw: Any) -> None:
|
71 | 73 | compare_dicts(result, expected)
|
72 | 74 |
|
73 | 75 |
|
74 |
| -# @pytest.mark.parametrize( |
75 |
| -# "df_raw", |
76 |
| -# [ |
77 |
| -# (polars.read_parquet("tests/data/lineitem.parquet").to_pandas()), |
78 |
| -# ], |
79 |
| -# ) |
80 |
| -# @mock.patch.dict(os.environ, {"NARWHALS_FORCE_GENERIC": "1"}) |
81 |
| -# def test_q1_w_pandas_agg_generic_path(df_raw: Any) -> None: |
82 |
| -# var_1 = datetime(1998, 9, 2) |
83 |
| -# df = translate_frame(df_raw, is_lazy=True) |
84 |
| -# pl = get_namespace(df) |
85 |
| -# query_result = ( |
86 |
| -# df.filter(pl.col("l_shipdate") <= var_1) |
87 |
| -# .group_by(["l_returnflag", "l_linestatus"]) |
88 |
| -# .agg( |
89 |
| -# [ |
90 |
| -# pl.sum("l_quantity").alias("sum_qty"), |
91 |
| -# pl.sum("l_extendedprice").alias("sum_base_price"), |
92 |
| -# (pl.col("l_extendedprice") * (1 - pl.col("l_discount"))) |
93 |
| -# .sum() |
94 |
| -# .alias("sum_disc_price"), |
95 |
| -# ( |
96 |
| -# pl.col("l_extendedprice") |
97 |
| -# * (1.0 - pl.col("l_discount")) |
98 |
| -# * (1.0 + pl.col("l_tax")) |
99 |
| -# ) |
100 |
| -# .sum() |
101 |
| -# .alias("sum_charge"), |
102 |
| -# pl.mean("l_quantity").alias("avg_qty"), |
103 |
| -# pl.mean("l_extendedprice").alias("avg_price"), |
104 |
| -# pl.mean("l_discount").alias("avg_disc"), |
105 |
| -# pl.len().alias("count_order"), |
106 |
| -# ], |
107 |
| -# ) |
108 |
| -# .sort(["l_returnflag", "l_linestatus"]) |
109 |
| -# ) |
110 |
| -# result = query_result.collect().to_dict(as_series=False) |
111 |
| -# expected = { |
112 |
| -# "l_returnflag": ["A", "N", "N", "R"], |
113 |
| -# "l_linestatus": ["F", "F", "O", "F"], |
114 |
| -# "sum_qty": [2109.0, 29.0, 3682.0, 1876.0], |
115 |
| -# "sum_base_price": [3114026.44, 39824.83, 5517101.99, 2947892.16], |
116 |
| -# "sum_disc_price": [2954950.8082, 39028.3334, 5205468.4852, 2816542.4816999994], |
117 |
| -# "sum_charge": [ |
118 |
| -# 3092840.4194289995, |
119 |
| -# 39808.900068, |
120 |
| -# 5406966.873419, |
121 |
| -# 2935797.8313019997, |
122 |
| -# ], |
123 |
| -# "avg_qty": [27.75, 29.0, 25.047619047619047, 26.422535211267604], |
124 |
| -# "avg_price": [ |
125 |
| -# 40974.032105263155, |
126 |
| -# 39824.83, |
127 |
| -# 37531.30605442177, |
128 |
| -# 41519.607887323946, |
129 |
| -# ], |
130 |
| -# "avg_disc": [0.05039473684210526, 0.02, 0.05537414965986395, 0.04507042253521127], |
131 |
| -# "count_order": [76, 1, 147, 71], |
132 |
| -# } |
133 |
| -# compare_dicts(result, expected) |
| 76 | +@pytest.mark.parametrize( |
| 77 | + "df_raw", |
| 78 | + [ |
| 79 | + (polars.read_parquet("tests/data/lineitem.parquet").to_pandas()), |
| 80 | + ], |
| 81 | +) |
| 82 | +@mock.patch.dict(os.environ, {"NARWHALS_FORCE_GENERIC": "1"}) |
| 83 | +def test_q1_w_pandas_agg_generic_path(df_raw: Any) -> None: |
| 84 | + var_1 = datetime(1998, 9, 2) |
| 85 | + df = nw.DataFrame(df_raw, is_lazy=True) |
| 86 | + query_result = ( |
| 87 | + df.filter(nw.col("l_shipdate") <= var_1) |
| 88 | + .group_by(["l_returnflag", "l_linestatus"]) |
| 89 | + .agg( |
| 90 | + [ |
| 91 | + nw.sum("l_quantity").alias("sum_qty"), |
| 92 | + nw.sum("l_extendedprice").alias("sum_base_price"), |
| 93 | + (nw.col("l_extendedprice") * (1 - nw.col("l_discount"))) |
| 94 | + .sum() |
| 95 | + .alias("sum_disc_price"), |
| 96 | + ( |
| 97 | + nw.col("l_extendedprice") |
| 98 | + * (1.0 - nw.col("l_discount")) |
| 99 | + * (1.0 + nw.col("l_tax")) |
| 100 | + ) |
| 101 | + .sum() |
| 102 | + .alias("sum_charge"), |
| 103 | + nw.mean("l_quantity").alias("avg_qty"), |
| 104 | + nw.mean("l_extendedprice").alias("avg_price"), |
| 105 | + nw.mean("l_discount").alias("avg_disc"), |
| 106 | + nw.len().alias("count_order"), |
| 107 | + ], |
| 108 | + ) |
| 109 | + .sort(["l_returnflag", "l_linestatus"]) |
| 110 | + ) |
| 111 | + result = query_result.collect().to_dict(as_series=False) |
| 112 | + expected = { |
| 113 | + "l_returnflag": ["A", "N", "N", "R"], |
| 114 | + "l_linestatus": ["F", "F", "O", "F"], |
| 115 | + "sum_qty": [2109.0, 29.0, 3682.0, 1876.0], |
| 116 | + "sum_base_price": [3114026.44, 39824.83, 5517101.99, 2947892.16], |
| 117 | + "sum_disc_price": [2954950.8082, 39028.3334, 5205468.4852, 2816542.4816999994], |
| 118 | + "sum_charge": [ |
| 119 | + 3092840.4194289995, |
| 120 | + 39808.900068, |
| 121 | + 5406966.873419, |
| 122 | + 2935797.8313019997, |
| 123 | + ], |
| 124 | + "avg_qty": [27.75, 29.0, 25.047619047619047, 26.422535211267604], |
| 125 | + "avg_price": [ |
| 126 | + 40974.032105263155, |
| 127 | + 39824.83, |
| 128 | + 37531.30605442177, |
| 129 | + 41519.607887323946, |
| 130 | + ], |
| 131 | + "avg_disc": [0.05039473684210526, 0.02, 0.05537414965986395, 0.04507042253521127], |
| 132 | + "count_order": [76, 1, 147, 71], |
| 133 | + } |
| 134 | + compare_dicts(result, expected) |
0 commit comments