Skip to content

Commit 5c5b9c1

Browse files
authored
Merge pull request #738 from ClickHouse/revert-736-fix_polars
Revert "fix Polars queries"
2 parents d0348de + 0ddb250 commit 5c5b9c1

2 files changed

Lines changed: 32 additions & 53 deletions

File tree

polars-dataframe/query.py

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,16 @@
1212

1313
# 0: No., 1: SQL, 2: Polars
1414
queries = [
15-
("Q0", "SELECT COUNT(*) FROM hits;", lambda x: x.select(pl.len()).collect().item()),
15+
("Q0", "SELECT COUNT(*) FROM hits;", lambda x: x.select(pl.len()).collect().height),
1616
(
1717
"Q1",
1818
"SELECT COUNT(*) FROM hits WHERE AdvEngineID <> 0;",
19-
lambda x: x.filter(pl.col("AdvEngineID") != 0).select(pl.len()).collect().item(),
19+
lambda x: x.select(pl.col("AdvEngineID").filter(pl.col("AdvEngineID") != 0).count()).collect().height,
2020
),
2121
(
2222
"Q2",
2323
"SELECT SUM(AdvEngineID), COUNT(*), AVG(ResolutionWidth) FROM hits;",
24-
lambda x: x.select(a_sum=pl.col("AdvEngineID").sum(), count=pl.len(), a_mean=pl.col("ResolutionWidth").mean()).collect().rows()[0],
24+
lambda x: x.select(a_sum=pl.col("AdvEngineID").sum(), count=pl.len(), a_mean=pl.col("AdvEngineID").mean()).collect().rows()[0],
2525
),
2626
(
2727
"Q3",
@@ -55,8 +55,8 @@
5555
"Q8",
5656
"SELECT RegionID, COUNT(DISTINCT UserID) AS u FROM hits GROUP BY RegionID ORDER BY u DESC LIMIT 10;",
5757
lambda x: x.group_by("RegionID")
58-
.agg(pl.col("UserID").n_unique().alias("u"))
59-
.sort("u", descending=True)
58+
.agg(pl.col("UserID").n_unique())
59+
.sort("UserID", descending=True)
6060
.head(10).collect(),
6161
),
6262
(
@@ -66,12 +66,11 @@
6666
.agg(
6767
[
6868
pl.sum("AdvEngineID").alias("AdvEngineID_sum"),
69-
pl.len().alias("count"),
7069
pl.mean("ResolutionWidth").alias("ResolutionWidth_mean"),
7170
pl.col("UserID").n_unique().alias("UserID_nunique"),
7271
]
7372
)
74-
.sort("count", descending=True)
73+
.sort("AdvEngineID_sum", descending=True)
7574
.head(10).collect(),
7675
),
7776
(
@@ -244,7 +243,7 @@
244243
x.filter(pl.col("Referer") != "")
245244
.with_columns(
246245
pl.col("Referer")
247-
.str.extract("(?-u)^https?://(?:www\\.)?([^/]+)/.*$")
246+
.str.extract(r"^https?://(?:www\\.)?([^/]+)/.*$")
248247
.alias("k")
249248
)
250249
.group_by("k")
@@ -263,7 +262,7 @@
263262
(
264263
"Q29",
265264
"SELECT SUM(ResolutionWidth), SUM(ResolutionWidth + 1), SUM(ResolutionWidth + 2), SUM(ResolutionWidth + 3), SUM(ResolutionWidth + 4), SUM(ResolutionWidth + 5), SUM(ResolutionWidth + 6), SUM(ResolutionWidth + 7), SUM(ResolutionWidth + 8), SUM(ResolutionWidth + 9), SUM(ResolutionWidth + 10), SUM(ResolutionWidth + 11), SUM(ResolutionWidth + 12), SUM(ResolutionWidth + 13), SUM(ResolutionWidth + 14), SUM(ResolutionWidth + 15), SUM(ResolutionWidth + 16), SUM(ResolutionWidth + 17), SUM(ResolutionWidth + 18), SUM(ResolutionWidth + 19), SUM(ResolutionWidth + 20), SUM(ResolutionWidth + 21), SUM(ResolutionWidth + 22), SUM(ResolutionWidth + 23), SUM(ResolutionWidth + 24), SUM(ResolutionWidth + 25), SUM(ResolutionWidth + 26), SUM(ResolutionWidth + 27), SUM(ResolutionWidth + 28), SUM(ResolutionWidth + 29), SUM(ResolutionWidth + 30), SUM(ResolutionWidth + 31), SUM(ResolutionWidth + 32), SUM(ResolutionWidth + 33), SUM(ResolutionWidth + 34), SUM(ResolutionWidth + 35), SUM(ResolutionWidth + 36), SUM(ResolutionWidth + 37), SUM(ResolutionWidth + 38), SUM(ResolutionWidth + 39), SUM(ResolutionWidth + 40), SUM(ResolutionWidth + 41), SUM(ResolutionWidth + 42), SUM(ResolutionWidth + 43), SUM(ResolutionWidth + 44), SUM(ResolutionWidth + 45), SUM(ResolutionWidth + 46), SUM(ResolutionWidth + 47), SUM(ResolutionWidth + 48), SUM(ResolutionWidth + 49), SUM(ResolutionWidth + 50), SUM(ResolutionWidth + 51), SUM(ResolutionWidth + 52), SUM(ResolutionWidth + 53), SUM(ResolutionWidth + 54), SUM(ResolutionWidth + 55), SUM(ResolutionWidth + 56), SUM(ResolutionWidth + 57), SUM(ResolutionWidth + 58), SUM(ResolutionWidth + 59), SUM(ResolutionWidth + 60), SUM(ResolutionWidth + 61), SUM(ResolutionWidth + 62), SUM(ResolutionWidth + 63), SUM(ResolutionWidth + 64), SUM(ResolutionWidth + 65), SUM(ResolutionWidth + 66), SUM(ResolutionWidth + 67), SUM(ResolutionWidth + 68), SUM(ResolutionWidth + 69), SUM(ResolutionWidth + 70), SUM(ResolutionWidth + 71), SUM(ResolutionWidth + 72), SUM(ResolutionWidth + 73), SUM(ResolutionWidth + 74), SUM(ResolutionWidth + 75), SUM(ResolutionWidth + 76), SUM(ResolutionWidth + 77), SUM(ResolutionWidth + 78), SUM(ResolutionWidth + 79), SUM(ResolutionWidth + 80), SUM(ResolutionWidth + 81), SUM(ResolutionWidth + 82), SUM(ResolutionWidth + 83), SUM(ResolutionWidth + 84), SUM(ResolutionWidth + 85), SUM(ResolutionWidth + 86), SUM(ResolutionWidth + 87), SUM(ResolutionWidth + 88), SUM(ResolutionWidth + 89) FROM hits;",
266-
lambda x: x.select([(pl.col("ResolutionWidth") + i).sum().alias(f"c_{i}") for i in range(90)]).collect(),
265+
lambda x: x.select(pl.sum_horizontal([pl.col("ResolutionWidth").shift(i) for i in range(1, 90)])).collect(),
267266
),
268267
(
269268
"Q30",
@@ -323,21 +322,16 @@
323322
lambda x: x.group_by("URL")
324323
.agg(pl.len().alias("c"))
325324
.sort("c", descending=True)
326-
.with_columns(pl.lit(1).alias("1"))
327325
.head(10).collect(),
328326
),
329327
(
330328
"Q35",
331329
"SELECT ClientIP, ClientIP - 1, ClientIP - 2, ClientIP - 3, COUNT(*) AS c FROM hits GROUP BY ClientIP, ClientIP - 1, ClientIP - 2, ClientIP - 3 ORDER BY c DESC LIMIT 10;",
332-
lambda x: x.group_by("ClientIP")
330+
lambda x: x.with_columns([pl.col("ClientIP")])
331+
.group_by(["ClientIP"])
333332
.agg(pl.len().alias("c"))
334-
.with_columns([
335-
(pl.col("ClientIP") - 1).alias("ClientIP_minus_1"),
336-
(pl.col("ClientIP") - 2).alias("ClientIP_minus_2"),
337-
(pl.col("ClientIP") - 3).alias("ClientIP_minus_3")
338-
])
339333
.sort("c", descending=True)
340-
.head(10).collect()
334+
.head(10).collect(),
341335
),
342336
(
343337
"Q36",
@@ -396,18 +390,15 @@
396390
& (pl.col("EventDate") <= date(2013, 7, 31))
397391
& (pl.col("IsRefresh") == 0)
398392
)
399-
.with_columns(
400-
pl.when(pl.col("SearchEngineID").eq(0) & pl.col("AdvEngineID").eq(0))
401-
.then(pl.col("Referer"))
402-
.otherwise(pl.lit(""))
403-
.alias("Src"),
404-
)
405393
.group_by(
406394
[
407395
"TraficSourceID",
408396
"SearchEngineID",
409397
"AdvEngineID",
410-
"Src",
398+
pl.when(pl.col("SearchEngineID").eq(0) & pl.col("AdvEngineID").eq(0))
399+
.then(pl.col("Referer"))
400+
.otherwise(pl.lit(""))
401+
.alias("Src"),
411402
"URL",
412403
]
413404
)
@@ -457,9 +448,8 @@
457448
& (pl.col("IsRefresh") == 0)
458449
& (pl.col("DontCountHits") == 0)
459450
)
460-
.group_by(pl.col("EventTime").dt.truncate("1m").alias("M"))
451+
.group_by(pl.col("EventTime").dt.truncate("1m"))
461452
.agg(pl.len().alias("PageViews"))
462-
.sort("M")
463453
.slice(1000, 10).collect(),
464454
),
465455
]

polars/query.py

Lines changed: 16 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,16 @@
1212

1313
# 0: No., 1: SQL, 2: Polars
1414
queries = [
15-
("Q0", "SELECT COUNT(*) FROM hits;", lambda x: x.select(pl.len()).collect().item()),
15+
("Q0", "SELECT COUNT(*) FROM hits;", lambda x: x.select(pl.len()).collect().height),
1616
(
1717
"Q1",
1818
"SELECT COUNT(*) FROM hits WHERE AdvEngineID <> 0;",
19-
lambda x: x.filter(pl.col("AdvEngineID") != 0).select(pl.len()).collect().item(),
19+
lambda x: x.select(pl.col("AdvEngineID").filter(pl.col("AdvEngineID") != 0).count()).collect().height,
2020
),
2121
(
2222
"Q2",
2323
"SELECT SUM(AdvEngineID), COUNT(*), AVG(ResolutionWidth) FROM hits;",
24-
lambda x: x.select(a_sum=pl.col("AdvEngineID").sum(), count=pl.len(), a_mean=pl.col("ResolutionWidth").mean()).collect().rows()[0],
24+
lambda x: x.select(a_sum=pl.col("AdvEngineID").sum(), count=pl.len(), a_mean=pl.col("AdvEngineID").mean()).collect().rows()[0],
2525
),
2626
(
2727
"Q3",
@@ -55,8 +55,8 @@
5555
"Q8",
5656
"SELECT RegionID, COUNT(DISTINCT UserID) AS u FROM hits GROUP BY RegionID ORDER BY u DESC LIMIT 10;",
5757
lambda x: x.group_by("RegionID")
58-
.agg(pl.col("UserID").n_unique().alias("u"))
59-
.sort("u", descending=True)
58+
.agg(pl.col("UserID").n_unique())
59+
.sort("UserID", descending=True)
6060
.head(10).collect(),
6161
),
6262
(
@@ -66,12 +66,11 @@
6666
.agg(
6767
[
6868
pl.sum("AdvEngineID").alias("AdvEngineID_sum"),
69-
pl.len().alias("count"),
7069
pl.mean("ResolutionWidth").alias("ResolutionWidth_mean"),
7170
pl.col("UserID").n_unique().alias("UserID_nunique"),
7271
]
7372
)
74-
.sort("count", descending=True)
73+
.sort("AdvEngineID_sum", descending=True)
7574
.head(10).collect(),
7675
),
7776
(
@@ -244,7 +243,7 @@
244243
x.filter(pl.col("Referer") != "")
245244
.with_columns(
246245
pl.col("Referer")
247-
.str.extract("(?-u)^https?://(?:www\\.)?([^/]+)/.*$")
246+
.str.extract(r"^https?://(?:www\\.)?([^/]+)/.*$")
248247
.alias("k")
249248
)
250249
.group_by("k")
@@ -263,7 +262,7 @@
263262
(
264263
"Q29",
265264
"SELECT SUM(ResolutionWidth), SUM(ResolutionWidth + 1), SUM(ResolutionWidth + 2), SUM(ResolutionWidth + 3), SUM(ResolutionWidth + 4), SUM(ResolutionWidth + 5), SUM(ResolutionWidth + 6), SUM(ResolutionWidth + 7), SUM(ResolutionWidth + 8), SUM(ResolutionWidth + 9), SUM(ResolutionWidth + 10), SUM(ResolutionWidth + 11), SUM(ResolutionWidth + 12), SUM(ResolutionWidth + 13), SUM(ResolutionWidth + 14), SUM(ResolutionWidth + 15), SUM(ResolutionWidth + 16), SUM(ResolutionWidth + 17), SUM(ResolutionWidth + 18), SUM(ResolutionWidth + 19), SUM(ResolutionWidth + 20), SUM(ResolutionWidth + 21), SUM(ResolutionWidth + 22), SUM(ResolutionWidth + 23), SUM(ResolutionWidth + 24), SUM(ResolutionWidth + 25), SUM(ResolutionWidth + 26), SUM(ResolutionWidth + 27), SUM(ResolutionWidth + 28), SUM(ResolutionWidth + 29), SUM(ResolutionWidth + 30), SUM(ResolutionWidth + 31), SUM(ResolutionWidth + 32), SUM(ResolutionWidth + 33), SUM(ResolutionWidth + 34), SUM(ResolutionWidth + 35), SUM(ResolutionWidth + 36), SUM(ResolutionWidth + 37), SUM(ResolutionWidth + 38), SUM(ResolutionWidth + 39), SUM(ResolutionWidth + 40), SUM(ResolutionWidth + 41), SUM(ResolutionWidth + 42), SUM(ResolutionWidth + 43), SUM(ResolutionWidth + 44), SUM(ResolutionWidth + 45), SUM(ResolutionWidth + 46), SUM(ResolutionWidth + 47), SUM(ResolutionWidth + 48), SUM(ResolutionWidth + 49), SUM(ResolutionWidth + 50), SUM(ResolutionWidth + 51), SUM(ResolutionWidth + 52), SUM(ResolutionWidth + 53), SUM(ResolutionWidth + 54), SUM(ResolutionWidth + 55), SUM(ResolutionWidth + 56), SUM(ResolutionWidth + 57), SUM(ResolutionWidth + 58), SUM(ResolutionWidth + 59), SUM(ResolutionWidth + 60), SUM(ResolutionWidth + 61), SUM(ResolutionWidth + 62), SUM(ResolutionWidth + 63), SUM(ResolutionWidth + 64), SUM(ResolutionWidth + 65), SUM(ResolutionWidth + 66), SUM(ResolutionWidth + 67), SUM(ResolutionWidth + 68), SUM(ResolutionWidth + 69), SUM(ResolutionWidth + 70), SUM(ResolutionWidth + 71), SUM(ResolutionWidth + 72), SUM(ResolutionWidth + 73), SUM(ResolutionWidth + 74), SUM(ResolutionWidth + 75), SUM(ResolutionWidth + 76), SUM(ResolutionWidth + 77), SUM(ResolutionWidth + 78), SUM(ResolutionWidth + 79), SUM(ResolutionWidth + 80), SUM(ResolutionWidth + 81), SUM(ResolutionWidth + 82), SUM(ResolutionWidth + 83), SUM(ResolutionWidth + 84), SUM(ResolutionWidth + 85), SUM(ResolutionWidth + 86), SUM(ResolutionWidth + 87), SUM(ResolutionWidth + 88), SUM(ResolutionWidth + 89) FROM hits;",
266-
lambda x: x.select([(pl.col("ResolutionWidth") + i).sum().alias(f"c_{i}") for i in range(90)]).collect(),
265+
lambda x: x.select(pl.sum_horizontal([pl.col("ResolutionWidth").shift(i) for i in range(1, 90)])).collect(),
267266
),
268267
(
269268
"Q30",
@@ -323,21 +322,16 @@
323322
lambda x: x.group_by("URL")
324323
.agg(pl.len().alias("c"))
325324
.sort("c", descending=True)
326-
.with_columns(pl.lit(1).alias("1"))
327325
.head(10).collect(),
328326
),
329327
(
330328
"Q35",
331329
"SELECT ClientIP, ClientIP - 1, ClientIP - 2, ClientIP - 3, COUNT(*) AS c FROM hits GROUP BY ClientIP, ClientIP - 1, ClientIP - 2, ClientIP - 3 ORDER BY c DESC LIMIT 10;",
332-
lambda x: x.group_by("ClientIP")
330+
lambda x: x.with_columns([pl.col("ClientIP")])
331+
.group_by(["ClientIP"])
333332
.agg(pl.len().alias("c"))
334-
.with_columns([
335-
(pl.col("ClientIP") - 1).alias("ClientIP_minus_1"),
336-
(pl.col("ClientIP") - 2).alias("ClientIP_minus_2"),
337-
(pl.col("ClientIP") - 3).alias("ClientIP_minus_3")
338-
])
339333
.sort("c", descending=True)
340-
.head(10).collect()
334+
.head(10).collect(),
341335
),
342336
(
343337
"Q36",
@@ -396,18 +390,15 @@
396390
& (pl.col("EventDate") <= date(2013, 7, 31))
397391
& (pl.col("IsRefresh") == 0)
398392
)
399-
.with_columns(
400-
pl.when(pl.col("SearchEngineID").eq(0) & pl.col("AdvEngineID").eq(0))
401-
.then(pl.col("Referer"))
402-
.otherwise(pl.lit(""))
403-
.alias("Src"),
404-
)
405393
.group_by(
406394
[
407395
"TraficSourceID",
408396
"SearchEngineID",
409397
"AdvEngineID",
410-
"Src",
398+
pl.when(pl.col("SearchEngineID").eq(0) & pl.col("AdvEngineID").eq(0))
399+
.then(pl.col("Referer"))
400+
.otherwise(pl.lit(""))
401+
.alias("Src"),
411402
"URL",
412403
]
413404
)
@@ -457,9 +448,8 @@
457448
& (pl.col("IsRefresh") == 0)
458449
& (pl.col("DontCountHits") == 0)
459450
)
460-
.group_by(pl.col("EventTime").dt.truncate("1m").alias("M"))
451+
.group_by(pl.col("EventTime").dt.truncate("1m"))
461452
.agg(pl.len().alias("PageViews"))
462-
.sort("M")
463453
.slice(1000, 10).collect(),
464454
),
465455
]
@@ -478,7 +468,6 @@ def run_timings(lf: pl.LazyFrame) -> None:
478468
times.append(round(end - start, 3))
479469
print(f"{times},")
480470

481-
482471
data_size = os.path.getsize("hits.parquet")
483472

484473
# Run from Parquet

0 commit comments

Comments
 (0)