Skip to content

Commit 1f7355e

Browse files
committed
Add an example for dividing a group into N sub-groups with equal or almost equal number of rows.
1 parent 3da2c00 commit 1f7355e

1 file changed

Lines changed: 141 additions & 0 deletions

File tree

pyspark_cookbook.org

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3514,6 +3514,147 @@ New group IDs:
35143514
|group_B|sUfvt|5 |[3, 2] |2 |[3, 5] |5 |[1, 0] |2 |group_B2 |
35153515
:end:
35163516

3517+
** To divide a group into N sub-groups with equal or almost equal number of rows
3518+
#+BEGIN_SRC python :post pretty2orgtbl(data=*this*)
3519+
import string

from pyspark.sql import SparkSession
from pyspark.sql.window import Window
import pyspark.sql.functions as F
import pyspark.sql.types as T
3523+
3524+
def divide_into_similar_groups(n, k):
    """
    Divide an integer n into k non-negative integer terms that sum to n,
    with all terms equal or differing by at most 1.

    E.g. n=8, k=2 -> [4, 4]; n=5, k=2 -> [2, 3].

    Parameters
    ----------
    n : int
        Total to split (e.g. the number of rows in a group); non-negative.
    k : int
        Number of terms (sub-groups) to produce; must be positive.

    Returns
    -------
    list[int]
        k integers, smaller terms first, summing exactly to n.
    """
    # divmod gives the exact floor quotient and remainder for any int size;
    # int(n / k) would go through float division and lose precision for
    # n larger than 2**53.
    t, r = divmod(n, k)
    # (k - r) terms of size t, then r terms of size t + 1: sums to t*k + r = n.
    return [t] * (k - r) + [t + 1] * r
3532+
3533+
3534+
def _divide_terms(n, k):
    """Plain-Python body forwarded to the Spark UDF registered below."""
    return divide_into_similar_groups(n, k)


# Column-level version of divide_into_similar_groups, usable in withColumn etc.
udf_divide_into_similar_groups = F.udf(
    _divide_terms, returnType=T.ArrayType(T.IntegerType())
)
3537+
3538+
3539+
spark = SparkSession.builder.appName("LabeledRandomIDs").getOrCreate()

# Character set for random ID generation
CHAR_SET = string.ascii_letters + string.digits
CHAR_LIST = list(CHAR_SET)
ID_LENGTH = 5

# Demo data: two groups with differing row counts (8 and 5 rows).
data = [("group_A",) for _ in range(8)] + [("group_B",) for _ in range(5)]
df = spark.createDataFrame(data, ["group"])

# Add random 5-character ID by selecting characters from CHAR_LIST.
# element_at is 1-based, hence the "+ 1" after the 0-based random index.
for i in range(ID_LENGTH):
    df = df.withColumn(
        f"char_{i}",
        F.element_at(
            F.array([F.lit(c) for c in CHAR_LIST]),
            (F.rand() * len(CHAR_LIST)).cast("int") + 1
        )
    )

# Concatenate characters into a single string ID
df = df.withColumn("id", F.concat_ws("", *[F.col(f"char_{i}") for i in range(ID_LENGTH)]))

txt = "Initial groups and IDs:"
<<txtblk("txt")>>print(txt)
df = df.select("group", "id")
df.show(truncate=False)

max_num_sub_groups = 2
txt = f"Maximal number of sub-groups desired: {max_num_sub_groups}"
<<txtblk("txt")>>print(txt)

# Per-group statistics: row count, the sub-group sizes ("terms", e.g.
# n_ids=5, k=2 -> [2, 3]) and their running sums ("cum_sum" -> [2, 5]).
df_stat = df.groupBy("group").agg(F.count("id").alias("n_ids"))
df_stat = df_stat.withColumn("terms", udf_divide_into_similar_groups(F.col("n_ids"), F.lit(max_num_sub_groups)))
df_stat = df_stat.withColumn("n_terms", F.size("terms"))
empty_int_array = F.array().cast(T.ArrayType(T.IntegerType()))
# Running sum over "terms" via F.aggregate: append each term plus the last
# element of the accumulator so far (coalesce to 0 while it is still empty).
df_stat = df_stat.withColumn(
    "cum_sum",
    F.aggregate(
        F.col("terms"),
        empty_int_array,
        lambda acc, x: F.concat(
            acc,
            F.array(x + F.coalesce(F.element_at(acc, -1), F.lit(0).cast("int")))
        )
    )
)
txt = "Number of IDs per group:"
<<txtblk("txt")>>print(txt)
df_stat.orderBy("group").show(truncate=False)

df = df.join(df_stat, on="group", how="left")

# Rank rows within each group by id, then find which cum_sum bucket the
# rank falls into: "loc" flags the cum_sum entries strictly below the rank,
# and array_position locates the first 0 (1-based) -> sub-group index.
# NOTE(review): dense_rank assumes the random ids are unique within a group;
# with duplicate ids, row_number over the same window would be safer — confirm.
w = Window.partitionBy("group").orderBy(F.asc("id"))
df = df.withColumn("rank", F.dense_rank().over(w))
df = df.withColumn("loc", F.transform("cum_sum", lambda x: (x < F.col("rank")).cast("int")))
df = df.withColumn("pos", F.array_position(F.col("loc"), 0))
df = df.withColumn("new_group", F.concat("group", F.col("pos").cast("string")))
txt = "New group IDs:"
<<txtblk("txt")>>print(txt)
df.show(truncate=False)
3600+
#+END_SRC
3601+
3602+
#+RESULTS:
3603+
:results:
3604+
#+begin_src text
3605+
Initial groups and IDs:
3606+
#+end_src
3607+
3608+
|group |id |
3609+
|-------+-----|
3610+
|group_A|GN0Ax|
3611+
|group_A|grLyR|
3612+
|group_A|eo47Y|
3613+
|group_A|Adm4Z|
3614+
|group_A|KCEUD|
3615+
|group_A|I1M9Z|
3616+
|group_A|nKb8g|
3617+
|group_A|3y7xZ|
3618+
|group_B|paho5|
3619+
|group_B|I6WmG|
3620+
|group_B|8IA5f|
3621+
|group_B|Pol35|
3622+
|group_B|RLxBg|
3623+
3624+
#+begin_src text
3625+
Maximal number of sub-groups desired: 2
3626+
#+end_src
3627+
3628+
#+begin_src text
3629+
Number of IDs per group:
3630+
#+end_src
3631+
3632+
|group |n_ids|terms |n_terms|cum_sum|
3633+
|-------+-----+------+-------+-------|
3634+
|group_A|8 |[4, 4]|2 |[4, 8] |
3635+
|group_B|5 |[2, 3]|2 |[2, 5] |
3636+
3637+
#+begin_src text
3638+
New group IDs:
3639+
#+end_src
3640+
3641+
|group |id |n_ids|terms |n_terms|cum_sum|rank|loc |pos|new_group|
3642+
|-------+-----+-----+------+-------+-------+----+------+---+---------|
3643+
|group_A|3y7xZ|8 |[4, 4]|2 |[4, 8] |1 |[0, 0]|1 |group_A1 |
3644+
|group_A|Adm4Z|8 |[4, 4]|2 |[4, 8] |2 |[0, 0]|1 |group_A1 |
3645+
|group_A|GN0Ax|8 |[4, 4]|2 |[4, 8] |3 |[0, 0]|1 |group_A1 |
3646+
|group_A|I1M9Z|8 |[4, 4]|2 |[4, 8] |4 |[0, 0]|1 |group_A1 |
3647+
|group_A|KCEUD|8 |[4, 4]|2 |[4, 8] |5 |[1, 0]|2 |group_A2 |
3648+
|group_A|eo47Y|8 |[4, 4]|2 |[4, 8] |6 |[1, 0]|2 |group_A2 |
3649+
|group_A|grLyR|8 |[4, 4]|2 |[4, 8] |7 |[1, 0]|2 |group_A2 |
3650+
|group_A|nKb8g|8 |[4, 4]|2 |[4, 8] |8 |[1, 0]|2 |group_A2 |
3651+
|group_B|8IA5f|5 |[2, 3]|2 |[2, 5] |1 |[0, 0]|1 |group_B1 |
3652+
|group_B|I6WmG|5 |[2, 3]|2 |[2, 5] |2 |[0, 0]|1 |group_B1 |
3653+
|group_B|Pol35|5 |[2, 3]|2 |[2, 5] |3 |[1, 0]|2 |group_B2 |
3654+
|group_B|RLxBg|5 |[2, 3]|2 |[2, 5] |4 |[1, 0]|2 |group_B2 |
3655+
|group_B|paho5|5 |[2, 3]|2 |[2, 5] |5 |[1, 0]|2 |group_B2 |
3656+
:end:
3657+
35173658
** To calculate set intersection between arrays in two consecutive rows in a window
35183659
#+BEGIN_SRC python :post pretty2orgtbl(data=*this*)
35193660
import pyspark.sql.functions as F

0 commit comments

Comments
 (0)