ESIEE Paris — Data Engineering I — Assignment 3
Authors: DIALLO Samba & DIOP Mouhamed
Academic year: 2025–2026
Program: Data & Applications - Engineering - (FD)
Course: Data Engineering I
Learning goals
- Analyze with SQL and DataFrames.
- Implement two RDD variants of computing the mean.
- Implement RDD joins (shuffle and hash).
- Record and explain performance observations.
1. Setup
Download data files from the following URL: https://www.dropbox.com/scl/fi/7012u693u06dgj95mgq2a/retail_dw_20250826.tar.gz?rlkey=fxyozuoryn951gzwmli5xi2zd&dl=0
Unpack somewhere and define the data_path accordingly:
# Change to path on your local machine.
data_path = "/home/sable/devops_base/td2/retail_dw_20250826"
The following cell contains setup to measure wall clock time and memory usage. (Don’t worry about the details; just run the cell.)
!pip install -U numpy pandas pyarrow matplotlib scipy
import sys, subprocess
try:
    import psutil  # noqa: F401
except Exception:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "psutil"])
print("psutil is installed.")

from IPython import get_ipython
from IPython.core.magic import register_cell_magic
import time, os, platform

# Try to import optional modules
try:
    import psutil
except Exception:
    psutil = None
try:
    import resource  # not available on Windows
except Exception:
    resource = None

def _rss_bytes():
    """Resident Set Size in bytes (cross-platform via psutil if available)."""
    if psutil is not None:
        return psutil.Process(os.getpid()).memory_info().rss
    # Fallback: unknown RSS → 0
    return 0

def _peak_bytes():
    """
    Best-effort peak memory in bytes.
    - Windows: psutil peak working set (peak_wset)
    - Linux: resource.ru_maxrss (KB → bytes)
    - macOS: resource.ru_maxrss (bytes)
    Fallback to current RSS if unavailable.
    """
    sysname = platform.system()
    # Windows path: use psutil peak_wset if present
    if sysname == "Windows" and psutil is not None:
        mi = psutil.Process(os.getpid()).memory_info()
        peak = getattr(mi, "peak_wset", None)  # should be available on Windows
        if peak is not None:
            return int(peak)
        return int(mi.rss)
    # POSIX path: resource may be available
    if resource is not None:
        try:
            ru = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
            # On Linux ru_maxrss is in kilobytes; on macOS/BSD it is bytes
            if sysname == "Linux":
                return int(ru) * 1024
            else:
                return int(ru)
        except Exception:
            pass
    # Last resort
    return _rss_bytes()

@register_cell_magic
def timemem(line, cell):
    """
    Measure wall time and memory around the execution of this cell.
    """
    rss_before, peak_before = _rss_bytes(), _peak_bytes()
    t0 = time.perf_counter()
    result = get_ipython().run_cell(cell)
    wall = time.perf_counter() - t0
    mb = 1024 ** 2
    # Report layout matches the output printed under each timed cell
    print("======================================")
    print(f"Wall time: {wall:.3f} s")
    print(f"RSS Δ: {(_rss_bytes() - rss_before) / mb:+.2f} MB")
    print(f"Peak memory Δ: {(_peak_bytes() - peak_before) / mb:+.2f} MB (OS-dependent)")
    print("======================================")
    return result
%%timemem
# codecell_31a - Setup and Data Loading
import os
from pyspark.sql import SparkSession, functions as F
# Reuse the running Spark session (or start one); F is used by the DataFrame cells below
spark = SparkSession.builder.getOrCreate()
data_path = "/home/sable/devops_base/td2/retail_dw_20250826"
events_df = spark.read.parquet(os.path.join(data_path, "retail_dw_20250826_events"))
products_df = spark.read.parquet(os.path.join(data_path, "retail_dw_20250826_products"))
brands_df = spark.read.parquet(os.path.join(data_path, "retail_dw_20250826_brands"))
events_df.createOrReplaceTempView("events")
products_df.createOrReplaceTempView("products")
brands_df.createOrReplaceTempView("brands")
print(f"Spark version: {spark.version}")
print(f"Events count: {events_df.count()}")
print(f"Products count: {products_df.count()}")
print(f"Brands count: {brands_df.count()}")
Spark version: 4.0.1
Events count: 42351862
Products count: 166794
Brands count: 3444
======================================
Wall time: 0.775 s
RSS Δ: +0.01 MB
Peak memory Δ: +0.00 MB (OS-dependent)
======================================
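Outside a notebook, the before/after measurement that %%timemem performs can be approximated with a context manager. This is a sketch, not the magic defined above: the helper name `measure` is hypothetical, wall time comes from `time.perf_counter`, and memory comes from the standard-library `tracemalloc` module. That module only tracks Python-heap allocations, so unlike RSS it does not see memory held by the Spark JVM.

```python
import time
import tracemalloc
from contextlib import contextmanager


@contextmanager
def measure(label="block"):
    """Hypothetical helper: report wall time and peak traced Python allocations."""
    stats = {}
    tracemalloc.start()
    t0 = time.perf_counter()
    try:
        yield stats
    finally:
        stats["wall_s"] = time.perf_counter() - t0
        _, peak = tracemalloc.get_traced_memory()  # (current, peak) since start()
        tracemalloc.stop()
        stats["peak_mb"] = peak / 1024 ** 2
        print(f"[{label}] Wall time: {stats['wall_s']:.3f} s, "
              f"peak traced memory: {stats['peak_mb']:.2f} MB")


# Example: time a small driver-side computation
with measure("sum of squares") as stats:
    total = sum(i * i for i in range(100_000))
```

The dictionary yielded by `measure` is filled in when the `with` block exits, so the numbers can be logged alongside the results.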
2. Loading DataFrames
Let’s load the DataFrames and print out their schemas:
# Note that you should have defined data_path above
events_df = spark.read.parquet(os.path.join(data_path, "retail_dw_20250826_events"))
products_df = spark.read.parquet(os.path.join(data_path, "retail_dw_20250826_products"))
brands_df = spark.read.parquet(os.path.join(data_path, "retail_dw_20250826_brands"))
events_df.printSchema()
products_df.printSchema()
brands_df.printSchema()
root
|-- date_key: integer (nullable = true)
|-- user_key: integer (nullable = true)
|-- age_key: integer (nullable = true)
|-- product_key: integer (nullable = true)
|-- brand_key: integer (nullable = true)
|-- category_key: integer (nullable = true)
|-- session_id: string (nullable = true)
|-- event_time: timestamp (nullable = true)
|-- event_type: string (nullable = true)
|-- price: double (nullable = true)
root
|-- category_code: string (nullable = true)
|-- brand_code: string (nullable = true)
|-- product_id: integer (nullable = true)
|-- product_name: string (nullable = true)
|-- product_desc: string (nullable = true)
|-- brand_key: integer (nullable = true)
|-- category_key: integer (nullable = true)
|-- product_key: integer (nullable = true)
root
|-- brand_code: string (nullable = true)
|-- brand_desc: string (nullable = true)
|-- brand_key: integer (nullable = true)
How many rows are in each table?
print(f"Number of rows in events table: {events_df.count()}")
print(f"Number of rows in products table: {products_df.count()}")
print(f"Number of rows in brands table: {brands_df.count()}")
Number of rows in events table: 42351862
Number of rows in products table: 166794
Number of rows in brands table: 3444
We can register the DataFrames as tables and issue SQL queries:
events_df.createOrReplaceTempView("events")
products_df.createOrReplaceTempView("products")
brands_df.createOrReplaceTempView("brands")
spark.sql('select count(*) from events').show()
spark.sql('select count(*) from products').show()
spark.sql('select count(*) from brands').show()
+--------+
|count(1)|
+--------+
|42351862|
+--------+
+--------+
|count(1)|
+--------+
| 166794|
+--------+
+--------+
|count(1)|
+--------+
| 3444|
+--------+
As a sanity check, the corresponding values should match: counting the rows in the DataFrame vs. issuing an SQL query to count the number of rows.
3. Data Science
Answer Q1 to Q7 below with SQL queries and DataFrame manipulations.
3.1 Q1
For session_id 789d3699-028e-4367-b515-b82e2cb5225f, what was the purchase price?
Hint: We only care about purchase events.
Using DataFrames:
%%timemem
# codecell_31b (keep this id for tracking purposes)
results_df = (
    events_df
    .filter((F.col("session_id") == "789d3699-028e-4367-b515-b82e2cb5225f") &
            (F.col("event_type") == "purchase"))
    .orderBy(F.col("event_time").desc())
    .select("price")
    .limit(1)
)
results_df.show()
+------+
| price|
+------+
|100.39|
+------+
======================================
Wall time: 0.780 s
RSS Δ: +0.02 MB
Peak memory Δ: +0.00 MB (OS-dependent)
======================================
3.2 Q2
How many products are sold by the brand “sokolov”?
Using DataFrames:
%%timemem
# codecell_32b (keep this id for tracking purposes)
results_df = (
    products_df
    .filter(F.col("brand_code") == "sokolov")
    .select("product_id")
    .distinct()
    .agg(F.count("product_id").alias("num_products"))
)
results_df.show()
+------------+
|num_products|
+------------+
| 1601|
+------------+
======================================
Wall time: 0.131 s
RSS Δ: +0.00 MB
Peak memory Δ: +0.00 MB (OS-dependent)
======================================
3.3 Q3
What is the average purchase price of items purchased from the brand “febest”? (Report answer to two digits after the decimal point, i.e., XX.XX.)
Using DataFrames:
%%timemem
# codecell_33b (keep this id for tracking purposes)
results_df = (
    events_df
    .filter((F.col("event_type") == "purchase") & F.col("price").isNotNull())
    .join(products_df, on="product_key")
    .filter(F.col("brand_code") == "febest")
    .agg(F.round(F.avg("price"), 2).alias("avg_price"))
)
results_df.show()
+---------+
|avg_price|
+---------+
| 20.39|
+---------+
======================================
Wall time: 1.321 s
RSS Δ: +0.00 MB
Peak memory Δ: +0.00 MB (OS-dependent)
======================================
3.4 Q4
What is the average number of events per user? (Report answer to two digits after the decimal point, i.e., XX.XX.)
Using DataFrames:
%%timemem
# codecell_34b (keep this id for tracking purposes)
results_df = (
    events_df
    .groupBy("user_key")
    .count()
    .agg(F.round(F.avg("count"), 2).alias("avg_events_per_user"))
)
results_df.show()
+-------------------+
|avg_events_per_user|
+-------------------+
| 14.02|
+-------------------+
======================================
Wall time: 1.978 s
RSS Δ: +0.00 MB
Peak memory Δ: +0.00 MB (OS-dependent)
======================================
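The DataFrame above averages the per-user event counts. Every event falls into exactly one user_key group, so that average equals the total number of events divided by the number of groups. This gives a cheap cross-check; in Spark it would be events_df.count() / events_df.select("user_key").distinct().count(). The minimal pure-Python sketch below runs on hypothetical toy keys, not on the real data.

```python
from collections import Counter
from statistics import mean


def avg_events_per_user(user_keys):
    """Average of per-user event counts (mirrors groupBy('user_key').count() then avg)."""
    per_user = Counter(user_keys)  # user_key -> number of events
    return round(mean(per_user.values()), 2)


def avg_events_shortcut(user_keys):
    """Same quantity computed as total events / number of distinct user_key groups."""
    return round(len(user_keys) / len(set(user_keys)), 2)


toy_events = [1, 1, 1, 2, 2, 3, 4, 4, 4, 4]  # hypothetical user_key of each event
print(avg_events_per_user(toy_events), avg_events_shortcut(toy_events))
```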
3.5 Q5
What are the top 10 (product_name, brand_code) pairs in terms of revenue? We want the answer rows sorted by revenue in descending order.
Using DataFrames:
%%timemem
# codecell_35b (keep this id for tracking purposes)
results_df = (
    events_df
    .filter((F.col("event_type") == "purchase") & F.col("price").isNotNull())
    .join(products_df, on="product_key")
    .groupBy("product_name", "brand_code")
    .agg(F.round(F.sum("price"), 2).alias("total_revenue"))
    .orderBy(F.desc("total_revenue"))
    .limit(10)
)
results_df.show(truncate=False)
+------------+----------+--------------+
|product_name|brand_code|total_revenue |
+------------+----------+--------------+
|smartphone |apple |1.6711340803E8|
|smartphone |samsung |9.546627508E7 |
|smartphone |xiaomi |2.254972634E7 |
|NULL |NULL |1.673241267E7 |
|smartphone |huawei |1.363398709E7 |
|video.tv |samsung |1.220999247E7 |
|smartphone |NULL |1.199712625E7 |
|NULL |lucente |9556989.32 |
|notebook |acer |8963128.65 |
|clocks |apple |8622900.64 |
+------------+----------+--------------+
======================================
Wall time: 1.917 s
RSS Δ: +0.00 MB
Peak memory Δ: +0.00 MB (OS-dependent)
======================================
3.6 Q6
Tally up counts of events by hour. More precisely, we want a table with hours 0, 1, … 23 with the counts of events in that hour.
Using DataFrames:
%%timemem
# codecell_36b (keep this id for tracking purposes)
events_by_hour_df = (
    events_df
    .withColumn("hour", F.hour("event_time"))
    .groupBy("hour")
    .count()
    .orderBy("hour")
)
events_by_hour_df.show(24)
+----+-------+
|hour| count|
+----+-------+
| 0| 263808|
| 1| 223635|
| 2| 353509|
| 3| 623434|
| 4|1137209|
| 5|1605037|
| 6|1955461|
| 7|2131930|
| 8|2269469|
| 9|2332649|
| 10|2380185|
| 11|2335494|
| 12|2282992|
| 13|2181477|
| 14|2171196|
| 15|2407266|
| 16|2717710|
| 17|2988054|
| 18|3008559|
| 19|2631424|
| 20|1999466|
| 21|1244129|
| 22| 694728|
| 23| 413041|
+----+-------+
======================================
Wall time: 2.212 s
RSS Δ: +0.00 MB
Peak memory Δ: +0.00 MB (OS-dependent)
======================================
When you run the cell above, events_by_hour_df should be something like:
+----+-------+
|hour| count|
+----+-------+
| 0| ???|
| 1| ???|
...
| 23| ???|
+----+-------+
Now plot the above DataFrame using matplotlib.
Here we want a line graph, with hour on the x axis and count on the y axis.
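A line plot needs one point per hour, but the groupBy only produces hours that actually occur. It is therefore safer to fill any missing hour with zero before plotting. This is a minimal sketch, assuming the rows have been collected to the driver as (hour, count) pairs, for example with events_by_hour_df.collect(). The helper names `hourly_series` and `plot_hourly` are hypothetical, and matplotlib is imported only inside `plot_hourly`.

```python
def hourly_series(rows):
    """Return (hours, counts) for hours 0..23, using 0 for hours absent from rows.

    rows: iterable of (hour, count) pairs; collected Spark Rows unpack the same way.
    """
    counts_by_hour = {int(h): int(c) for h, c in rows if h is not None}
    hours = list(range(24))
    counts = [counts_by_hour.get(h, 0) for h in hours]
    return hours, counts


def plot_hourly(hours, counts):
    """Line graph with hour on the x axis and count on the y axis."""
    import matplotlib.pyplot as plt  # lazy import: hourly_series works without matplotlib
    plt.figure(figsize=(10, 5))
    plt.plot(hours, counts, marker="o")
    plt.xticks(hours)
    plt.xlabel("Hour of day")
    plt.ylabel("Number of events")
    plt.title("Events by hour")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()


# Partial input on purpose: hours 2..22 are missing and become 0
hours, counts = hourly_series([(0, 263808), (1, 223635), (23, 413041)])
```

In the notebook, `plot_hourly(*hourly_series(events_by_hour_df.collect()))` would draw the full chart.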
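3.7 Q7
What is the average purchase price by brand, for brands whose average purchase price is more than 10K? (Report answer to two digits after the decimal point, i.e., XX.XX, with rows sorted by average price in descending order.)
First, do it using SQL:
%%timemem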
# codecell_37a (keep this id for tracking purposes)
# Write your SQL below
sql_query = """
SELECT
p.brand_code,
ROUND(AVG(e.price), 2) as avg_price
FROM events e
JOIN products p ON e.product_key = p.product_key
WHERE e.event_type = 'purchase'
AND e.price IS NOT NULL
GROUP BY p.brand_code
HAVING AVG(e.price) > 10000
ORDER BY avg_price DESC
"""
results = spark.sql(sql_query)
results.show()
+----------+---------+
+----------+---------+
|brand_code|avg_price|
+----------+---------+
| adam| 58946.0|
| kona| 43759.0|
| yuandong| 35329.0|
| bentley| 23164.0|
| otex| 18633.14|
| suunto| 10732.82|
| stark| 10400.25|
+----------+---------+
======================================
Wall time: 2.709 s
RSS Δ: +0.00 MB
Peak memory Δ: +0.00 MB (OS-dependent)
======================================
Next, do it with DataFrames:
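%%timemem
# codecell_37c (keep this id for tracking purposes)
import matplotlib.pyplot as plt
from pyspark.sql import functions as F
# DataFrame version of the SQL query above (defines avg_price_by_brand_df before plotting)
avg_price_by_brand_df = (
    events_df
    .filter((F.col("event_type") == "purchase") & F.col("price").isNotNull())
    .join(products_df, on="product_key")
    .groupBy("brand_code")
    .agg(F.avg("price").alias("avg_price_raw"))
    .filter(F.col("avg_price_raw") > 10000)
    .select("brand_code", F.round("avg_price_raw", 2).alias("avg_price"))
    .orderBy(F.desc("avg_price"))
)
avg_price_by_brand_pdf = avg_price_by_brand_df.toPandas()
plt.figure(figsize=(12, 6))
plt.bar(avg_price_by_brand_pdf['brand_code'], avg_price_by_brand_pdf['avg_price'], color='steelblue')
plt.xlabel('Brand Code', fontsize=12)
plt.ylabel('Average Price', fontsize=12)
plt.title('Average Purchase Price by Brand (> 10K)', fontsize=14)
plt.xticks(rotation=45, ha='right')
plt.grid(True, alpha=0.3, axis='y')
plt.tight_layout()
plt.show()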
======================================
Wall time: 1.887 s
RSS Δ: +0.05 MB
Peak memory Δ: +0.00 MB (OS-dependent)
======================================
4. Load RDDs
The remaining exercises focus on RDD manipulations.
Let’s start by loading the RDDs.
# Get RDDs directly from DataFrames (with required repartitions)
# type: RDD[Row]
events_rdd = events_df.rdd.repartition(1000)
products_rdd = products_df.rdd.repartition(100)
brands_rdd = brands_df.rdd.repartition(100)
You’ll need Row, so let’s make sure we’ve imported it.
from pyspark.sql import Row
5. Implementations of Computing Averages
In this exercise, we implement two versions of “computing the mean” in Spark, version 1 and version 3, as described in the second lecture, Batch Processing I (use Ctrl+F to find the slides titled “Computing the Mean: Version 1” and “Computing the Mean: Version 3”).
To make the problem more tractable (i.e., to reduce the running times), let’s first do a bit of filtering of the events table.
We’ll do this using DataFrames, and then generate an RDD:
You can confirm that we’re working with a smaller dataset.
Compute the average purchase price by brand. We want the results sorted by the average purchase price from the largest to smallest value. As before, round to two digits after the decimal point. This is similar to Q7 above, except without the “more than 10K” condition.
Implement using the naive “version 1” algorithm, as described in the lectures:
- You must start with filtered_events_rdd.
- You must use groupByKey().
- Per “version 1”, your implementation must shuffle all values from the “mappers” to the “reducers”.
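To make the data movement of “version 1” concrete, here is a pure-Python model of the same dataflow, with lists standing in for RDD partitions. It is only a sketch under that assumption, not Spark code. The map phase emits one (brand_code, price) pair per row, the shuffle routes every value to the reducer that owns its key, and the reducer averages the full list. The helper names and toy rows are hypothetical.

```python
from collections import defaultdict


def shuffle_all_values(partitions, num_reducers):
    """Route every (key, value) pair to reducer hash(key) % num_reducers, like groupByKey."""
    reducers = [defaultdict(list) for _ in range(num_reducers)]
    moved = 0
    for partition in partitions:
        for key, value in partition:
            reducers[hash(key) % num_reducers][key].append(value)
            moved += 1  # every single value crosses the "network"
    return reducers, moved


def mean_v1(partitions, num_reducers=2):
    # Map phase: one (brand_code, price) pair per row
    mapped = [[(row["brand_code"], row["price"]) for row in part] for part in partitions]
    reducers, moved = shuffle_all_values(mapped, num_reducers)
    # Reduce phase: each reducer sees the complete list of prices for its keys
    means = {k: round(sum(v) / len(v), 2) for r in reducers for k, v in r.items()}
    return means, moved


toy_partitions = [
    [{"brand_code": "A", "price": 100.0}, {"brand_code": "A", "price": 200.0}],
    [{"brand_code": "B", "price": 300.0}, {"brand_code": "B", "price": 400.0},
     {"brand_code": "A", "price": 300.0}],
]
means, moved = mean_v1(toy_partitions)
print(means, moved)
```

The `moved` counter is the point of the exercise. Under version 1 it always equals the number of input rows, whatever the number of distinct brands.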
%%timemem
# codecell_5x1 (keep this id for tracking purposes)
average_revenue_per_brand_v1 = (
    filtered_events_rdd
    .map(lambda row: (row['brand_code'], row['price']))
    .groupByKey()
    .mapValues(lambda prices: sum(prices) / len(list(prices)))
    .map(lambda x: (x[0], round(x[1], 2)))
    .sortBy(lambda x: x[1], ascending=False)
)
average_revenue_per_brand_v1.take(10)
[('adam', 58946.0),
('kona', 43759.0),
('yuandong', 35329.0),
('bentley', 23164.0),
('otex', 18633.13),
('suunto', 10732.82),
('stark', 10400.25),
('zenmart', 9447.0),
('baltekstil', 8504.19),
('bugati', 8288.42)]
======================================
Wall time: 12.619 s
RSS Δ: +0.02 MB
Peak memory Δ: +0.00 MB (OS-dependent)
======================================
Compute the average purchase price by brand. We want the results sorted by the average purchase price from the largest to smallest value. As before, round to two digits after the decimal point. This is similar to Q7 above, except without the “more than 10K” condition.
Implement using the improved “version 3” algorithm, as described in the lectures:
- You must start with filtered_events_rdd.
- You must use reduceByKey().
- Per “version 3”, your implementation must emit (sum, count) pairs and take advantage of opportunities to perform aggregations.
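“Version 3” works because (sum, count) pairs can be merged associatively. Each partition can therefore pre-aggregate before the shuffle, which is what reduceByKey does with its map-side combine. The pure-Python model below uses hypothetical toy partitions as a stand-in for filtered_events_rdd. It also shows why averaging per-partition averages would give the wrong answer. Helper names are hypothetical.

```python
from collections import defaultdict


def combine_partition(partition):
    """Map-side combine: one (sum, count) pair per key per partition."""
    acc = {}
    for row in partition:
        s, c = acc.get(row["brand_code"], (0.0, 0))
        acc[row["brand_code"]] = (s + row["price"], c + 1)
    return acc


def merge(a, b):
    """Associative, commutative merge: the function passed to reduceByKey."""
    return (a[0] + b[0], a[1] + b[1])


def mean_v3(partitions):
    combined = [combine_partition(p) for p in partitions]
    shuffled = sum(len(c) for c in combined)  # pairs that cross the shuffle after combining
    totals = {}
    for part in combined:
        for key, pair in part.items():
            totals[key] = merge(totals[key], pair) if key in totals else pair
    return {k: round(s / c, 2) for k, (s, c) in totals.items()}, shuffled


def mean_of_means(partitions):
    """Incorrect shortcut: average the per-partition averages."""
    per_key = defaultdict(list)
    for part in partitions:
        for key, (s, c) in combine_partition(part).items():
            per_key[key].append(s / c)
    return {k: round(sum(v) / len(v), 2) for k, v in per_key.items()}


toy_partitions = [
    [{"brand_code": "A", "price": 100.0}, {"brand_code": "A", "price": 200.0}],
    [{"brand_code": "B", "price": 300.0}, {"brand_code": "B", "price": 400.0},
     {"brand_code": "A", "price": 300.0}],
]
means, shuffled = mean_v3(toy_partitions)
print(means, shuffled, mean_of_means(toy_partitions))
```

In PySpark the same shape appears in the self-check at the end of this notebook. Each row is mapped to (brand_code, (price, 1)), and reduceByKey then adds the pairs element-wise.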
6. Joins
%%timemem
shuffle_join_rdd = shuffle_join(brands_rdd, products_rdd, "brand_key", "brand_key")
shuffle_join_rdd.count()
115584
======================================
Wall time: 36.276 s
RSS Δ: +0.00 MB
Peak memory Δ: +0.00 MB (OS-dependent)
======================================
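A shuffle (reduce-side) join, which is what the name shuffle_join suggests, works in three steps. It hash-partitions both inputs on the join key, so matching rows land in the same partition, and then pairs them up locally. The pure-Python sketch below models that with lists as partitions and dicts as rows. The function name and data are hypothetical, and it illustrates the technique rather than reproducing the shuffle_join called above.

```python
from collections import defaultdict


def shuffle_join_sketch(left_parts, right_parts, left_key, right_key, num_partitions=4):
    """Inner equi-join: hash-partition both inputs on the key, then join inside each partition."""
    # One (left rows by key, right rows by key) pair of buckets per reduce partition
    buckets = [(defaultdict(list), defaultdict(list)) for _ in range(num_partitions)]
    # Shuffle phase: every row of BOTH inputs moves to the partition that owns its key
    for side, parts, key in ((0, left_parts, left_key), (1, right_parts, right_key)):
        for part in parts:
            for row in part:
                k = row[key]
                if k is None:  # SQL inner joins never match NULL keys
                    continue
                buckets[hash(k) % num_partitions][side][k].append(row)
    # Reduce phase: within a partition, pair each left row with every matching right row
    out = []
    for lefts, rights in buckets:
        for k, lrows in lefts.items():
            for lrow in lrows:
                for rrow in rights.get(k, []):
                    out.append({**lrow, **rrow})
    return out


brands = [[{"brand_key": 1, "brand_code": "A"}], [{"brand_key": 2, "brand_code": "B"}]]
products = [[{"brand_key": 1, "product_id": 10}, {"brand_key": 1, "product_id": 11}],
            [{"brand_key": 3, "product_id": 12}]]
joined = shuffle_join_sketch(brands, products, "brand_key", "brand_key")
print(sorted(r["product_id"] for r in joined))
```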
Add in the WHERE clause:
shuffle_join_results_rdd = shuffle_join_rdd.filter(lambda row: row["brand_key"] == 423)
shuffle_join_results_rdd.count()
2
If you look at the results, they’re a bit difficult to read… why don’t we just use Spark DataFrames for prettification?
df = spark.createDataFrame(shuffle_join_results_rdd.collect())
df.show()
+----------+--------------------+---------+-------------+----------+------------+--------------------+------------+-----------+
|brand_code| brand_desc|brand_key|category_code|product_id|product_name| product_desc|category_key|product_key|
+----------+--------------------+---------+-------------+----------+------------+--------------------+------------+-----------+
| blaupunkt|"Blaupunkt is a G...| 423| electronics| 1802099| video.tv|The video.tv is a...| 8| 4813|
| blaupunkt|"Blaupunkt is a G...| 423| electronics| 1802107| video.tv|The video.tv is a...| 8| 4821|
+----------+--------------------+---------+-------------+----------+------------+--------------------+------------+-----------+
%%timemem
replicated_hash_join_rdd = replicated_hash_join(brands_rdd, products_rdd, "brand_key", "brand_key")
replicated_hash_join_rdd.count()
115584
======================================
Wall time: 5.563 s
RSS Δ: +1.08 MB
Peak memory Δ: +0.00 MB (OS-dependent)
======================================
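A replicated (broadcast, map-side) hash join avoids the shuffle. The smaller input is loaded into an in-memory hash table keyed on the join column, and a copy is shipped to every task (in Spark, typically with sc.broadcast). Each partition of the larger input then probes that table independently. The pure-Python sketch below models this. The names are hypothetical, and it is not the replicated_hash_join called above.

```python
from collections import defaultdict


def build_hash_table(small_rows, key):
    """Build side: key -> list of rows (duplicate keys allowed), skipping NULL keys."""
    table = defaultdict(list)
    for row in small_rows:
        if row[key] is not None:
            table[row[key]].append(row)
    return dict(table)


def replicated_hash_join_sketch(small_rows, large_parts, small_key, large_key):
    table = build_hash_table(small_rows, small_key)  # conceptually replicated to every task
    out = []
    for part in large_parts:  # each partition is probed independently; nothing is shuffled
        for row in part:
            for match in table.get(row[large_key], []):
                out.append({**match, **row})
    return out


brands = [{"brand_key": 1, "brand_code": "A"}, {"brand_key": 2, "brand_code": "B"}]
products = [[{"brand_key": 1, "product_id": 10}, {"brand_key": 1, "product_id": 11}],
            [{"brand_key": 3, "product_id": 12}, {"brand_key": None, "product_id": 13}]]
joined = replicated_hash_join_sketch(brands, products, "brand_key", "brand_key")
print(sorted(r["product_id"] for r in joined))
```

This layout only pays off when the build side fits comfortably in memory on every worker. Here brands has 3,444 rows, which is consistent with the replicated hash join finishing in about 5.6 s versus about 36.3 s for the shuffle join above.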
Add in the WHERE clause:
replicated_hash_join_results_rdd = replicated_hash_join_rdd.filter(lambda row: row["brand_key"] == 423)
replicated_hash_join_results_rdd.count()
2
If you look at the results, they’re a bit difficult to read… why don’t we just use Spark DataFrames for prettification?
df = spark.createDataFrame(replicated_hash_join_results_rdd.collect())
df.show()
+----------+--------------------+---------+-------------+----------+------------+--------------------+------------+-----------+
|brand_code| brand_desc|brand_key|category_code|product_id|product_name| product_desc|category_key|product_key|
+----------+--------------------+---------+-------------+----------+------------+--------------------+------------+-----------+
| blaupunkt|"Blaupunkt is a G...| 423| electronics| 1802099| video.tv|The video.tv is a...| 8| 4813|
| blaupunkt|"Blaupunkt is a G...| 423| electronics| 1802107| video.tv|The video.tv is a...| 8| 4821|
+----------+--------------------+---------+-------------+----------+------------+--------------------+------------+-----------+
Verify output against the SQL query.
7. Join Performance
Now that we have two different implementations of joins, let’s compare them on exactly the same query. The first two are repeated from above.
Let’s call this J1 below. (Run the cell; it should just work. If it doesn’t, you’ll need to fix the implementation above.)
%%timemem
replicated_hash_join_rdd = replicated_hash_join(brands_rdd, products_rdd, "brand_key", "brand_key").filter(lambda row: row["brand_key"] == 423)
replicated_hash_join_rdd.count()
2
======================================
Wall time: 5.639 s
RSS Δ: +0.01 MB
Peak memory Δ: +0.00 MB (OS-dependent)
======================================
Let’s call this J3 below. (Run the cell; it should just work. If it doesn’t, you’ll need to fix the implementation above.)
%%timemem
spark.stop()
======================================
Wall time: 0.938 s
RSS Δ: -0.01 MB
Peak memory Δ: +0.00 MB (OS-dependent)
======================================
Performance notes
- Set and justify spark.sql.shuffle.partitions for local vs. cluster runs.
- Prefer DataFrame built-ins over Python UDFs; push logic to Catalyst when possible.
- Use AQE (adaptive query execution) to mitigate skew; consider salting for extreme keys.
- Cache only when reuse exists; unpersist when no longer needed.
- Use broadcast join only when the small side fits in memory; verify with explain.
- Capture df.explain(mode='formatted') for at least one analysis query and one join.
- A3 note: Python RDDs cross the Python/JVM boundary; slower runtimes are expected for the RDD parts.
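Salting, mentioned in the notes above, spreads a hot key over several sub-keys so that no single reducer receives all of its values. A second, much smaller aggregation then removes the salt. The pure-Python sketch below shows a salted sum, assuming a fixed number of salt buckets (NUM_SALTS, a hypothetical parameter). It uses a deterministic round-robin salt so the example is reproducible; real jobs often use a random salt. A sum is used only because it is easy to check, and the same two-stage pattern applies to skewed joins and to aggregations that cannot be combined map-side.

```python
from collections import Counter, defaultdict

NUM_SALTS = 4  # hypothetical number of sub-keys per key


def salted_sum(pairs, num_salts=NUM_SALTS):
    """Two-stage sum: aggregate by (key, salt) first, then by key alone."""
    # Stage 1: a round-robin salt splits each key into up to num_salts groups
    stage1 = defaultdict(float)
    for i, (key, value) in enumerate(pairs):
        stage1[(key, i % num_salts)] += value
    # Stage 2: drop the salt and combine the (at most num_salts) partial sums per key
    stage2 = defaultdict(float)
    for (key, _salt), partial in stage1.items():
        stage2[key] += partial
    return dict(stage2), dict(stage1)


pairs = [("apple", 1.0)] * 10 + [("acer", 2.0)] * 2  # hypothetical skewed input
totals, partials = salted_sum(pairs)
# Largest group of values for a single grouping key, without and with salting
unsalted_max = max(Counter(k for k, _ in pairs).values())
salted_max = max(Counter((k, i % NUM_SALTS) for i, (k, _) in enumerate(pairs)).values())
print(totals, len(partials), unsalted_max, salted_max)
```

In a DataFrame job the same idea is usually expressed by adding a salt column, grouping by (key, salt), and then grouping again by key.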
Self-check (toy data)
# Restart Spark if necessary
from pyspark.sql import SparkSession, Row

try:
    _ = spark.sparkContext.defaultParallelism
    print("Spark is already running")
except Exception:
    print("Restarting Spark...")
    spark = SparkSession.builder \
        .appName("A3-SelfCheck") \
        .master("local[*]") \
        .config("spark.driver.memory", "4g") \
        .getOrCreate()
    spark.sparkContext.setLogLevel("ERROR")
    print("Spark restarted successfully")

# ============================================
# Test Mean Functions
# ============================================
print("\n" + "="*50)
print("Testing Mean Functions (Version 1 vs Version 3)")
print("="*50)
test_data = [
    Row(brand_code='A', price=100.0),
    Row(brand_code='A', price=200.0),
    Row(brand_code='B', price=300.0),
    Row(brand_code='B', price=400.0),
]
test_rdd = spark.sparkContext.parallelize(test_data)

# Version 1 (groupByKey)
print("\n--- Version 1 (groupByKey) ---")
mean_v1 = (
    test_rdd
    .map(lambda row: (row['brand_code'], row['price']))
    .groupByKey()
    .mapValues(lambda prices: round(sum(prices) / len(list(prices)), 2))
    .sortByKey()
    .collect()
)
print(f"V1 Results: {mean_v1}")
print("Expected: [('A', 150.0), ('B', 350.0)]")

# Version 3 (reduceByKey)
print("\n--- Version 3 (reduceByKey) ---")
mean_v3 = (
    test_rdd
    .map(lambda row: (row['brand_code'], (row['price'], 1)))
    .reduceByKey(lambda x, y: (x[0] + y[0], x[1] + y[1]))
    .mapValues(lambda x: round(x[0] / x[1], 2))
    .sortByKey()
    .collect()
)
print(f"V3 Results: {mean_v3}")
print("Expected: [('A', 150.0), ('B', 350.0)]")

# Verify correctness
assert mean_v1 == mean_v3 == [('A', 150.0), ('B', 350.0)], "Mean calculation mismatch"
print("\n[PASS] Mean functions produce identical correct results")

# ============================================
# Test Join Functions
# ============================================
print("\n" + "="*50)
print("Testing Join Functions (Shuffle vs Hash Join)")
print("="*50)
left = spark.sparkContext.parallelize([
    Row(id=1, name='A'),
    Row(id=2, name='B'),
    Row(id=3, name='C')
])
right = spark.sparkContext.parallelize([
    Row(id=1, value=10),
    Row(id=2, value=20)
])

# Shuffle Join
print("\n--- Shuffle Join Test ---")
shuffle_result = shuffle_join(left, right, 'id', 'id').sortBy(lambda r: r.id).collect()
print(f"Shuffle Join: {len(shuffle_result)} rows")
for row in shuffle_result:
    print(f" {row}")

# Hash Join
print("\n--- Replicated Hash Join Test ---")
hash_result = replicated_hash_join(left, right, 'id', 'id').sortBy(lambda r: r.id).collect()
print(f"Hash Join: {len(hash_result)} rows")
for row in hash_result:
    print(f" {row}")

# Verify correctness
assert len(shuffle_result) == len(hash_result) == 2, "Join count mismatch"
print("\n[PASS] Both join implementations produce 2 rows as expected")

# Verify content
expected_ids = {1, 2}
shuffle_ids = {r.id for r in shuffle_result}
hash_ids = {r.id for r in hash_result}
assert shuffle_ids == hash_ids == expected_ids, "Join content mismatch"
print("[PASS] Both joins match on ids {1, 2}")

# ============================================
# Summary
# ============================================
print("\n" + "="*50)
print("ALL SELF-CHECKS PASSED")
print("="*50)
print("* Mean V1 (groupByKey) - PASSED")
print("* Mean V3 (reduceByKey) - PASSED")
print("* Shuffle Join - PASSED")
print("* Replicated Hash Join - PASSED")
Restarting Spark...
Spark restarted successfully
==================================================
Testing Mean Functions (Version 1 vs Version 3)
==================================================
--- Version 1 (groupByKey) ---
V1 Results: [('A', 150.0), ('B', 350.0)]
Expected: [('A', 150.0), ('B', 350.0)]
--- Version 3 (reduceByKey) ---
V3 Results: [('A', 150.0), ('B', 350.0)]
Expected: [('A', 150.0), ('B', 350.0)]
[PASS] Mean functions produce identical correct results
==================================================
Testing Join Functions (Shuffle vs Hash Join)
==================================================
--- Shuffle Join Test ---
Shuffle Join: 2 rows
Row(id=1, name='A', value=10)
Row(id=2, name='B', value=20)
--- Replicated Hash Join Test ---
Hash Join: 2 rows
Row(id=1, name='A', value=10)
Row(id=2, name='B', value=20)
[PASS] Both join implementations produce 2 rows as expected
[PASS] Both joins match on ids {1, 2}
==================================================
ALL SELF-CHECKS PASSED
==================================================
* Mean V1 (groupByKey) - PASSED
* Mean V3 (reduceByKey) - PASSED
* Shuffle Join - PASSED
* Replicated Hash Join - PASSED
Reproducibility checklist
- Record Python, Java, and Spark versions.
- Fix timezone to UTC and log run timestamp.
- Pin random seeds where randomness is used.
- Save configs: spark.sql.shuffle.partitions, AQE flags, broadcast thresholds if changed.
- Provide exact run commands and input/output paths.
- Export a minimal environment file (environment.yml or requirements.txt).
- Keep data paths relative to project root; avoid user-specific absolute paths.
- Include small sample outputs for verification.
## Complete reproducibility checklist
import sys, os, platform, subprocess, random, time
import numpy as np
from datetime import datetime, timezone

print("REPRODUCIBILITY")

# 1. Software versions
print("\n[1] SOFTWARE VERSIONS")
print("-" * 70)
print(f"Python: {sys.version.split()[0]}")
print(f"Spark: {spark.version}")
try:
    # `java -version` writes to stderr, so merge stderr into the captured output
    java_ver = subprocess.check_output(['java', '-version'], stderr=subprocess.STDOUT).decode().split('\n')[0]
    print(f"Java: {java_ver}")
except Exception:
    print("Java: version not detected")
print(f"OS: {platform.system()} {platform.release()} ({platform.machine()})")

# 2. UTC timestamp
print("\n[2] RUN TIMESTAMP")
print("-" * 70)
os.environ['TZ'] = 'UTC'
if hasattr(time, "tzset"):
    time.tzset()  # apply TZ to this process (POSIX only)
execution_time = datetime.now(timezone.utc).isoformat()
print(f"Start: {execution_time} UTC")

# 3. Random seed
print("\n[3] RANDOM SEED")
print("-" * 70)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
print(f"Seed set to: {SEED}")

# 4. Spark configuration (default avoids an error when a key is not set)
print("\n[4] SPARK CONFIGURATION")
print("-" * 70)
configs = {
    "spark.sql.shuffle.partitions": spark.conf.get("spark.sql.shuffle.partitions", "not set"),
    "spark.sql.adaptive.enabled": spark.conf.get("spark.sql.adaptive.enabled", "not set"),
    "spark.driver.memory": spark.conf.get("spark.driver.memory", "not set"),
}
for key, val in configs.items():
    print(f"{key}: {val}")

# 5. Paths
print("\n[5] DATA PATHS")
print("-" * 70)
print(f"Data path: {data_path}")
print(f"Exists: {os.path.exists(data_path)}")

# 6. Key packages
print("\n[6] KEY PACKAGES")
print("-" * 70)
for pkg in ["pyspark", "pandas", "numpy", "matplotlib"]:
    try:
        ver = subprocess.check_output([sys.executable, "-m", "pip", "show", pkg]).decode()
        ver_line = [l for l in ver.split('\n') if l.startswith('Version:')][0]
        print(f"{pkg}: {ver_line.split(':')[1].strip()}")
    except Exception:
        print(f"{pkg}: not installed")

print("\n" + "="*70)
print("All reproducibility information has been recorded")
print("="*70)
REPRODUCIBILITY
[1] SOFTWARE VERSIONS
----------------------------------------------------------------------
Python: 3.10.18
Spark: 4.0.1
Java: openjdk version "17.0.15-internal" 2025-04-15
OS: Linux 6.14.0-37-generic (x86_64)
[2] RUN TIMESTAMP
----------------------------------------------------------------------
Start: 2025-12-19T22:38:38.110215+00:00 UTC
[3] RANDOM SEED
----------------------------------------------------------------------
Seed set to: 42
[4] SPARK CONFIGURATION
----------------------------------------------------------------------
spark.sql.shuffle.partitions: 400
spark.sql.adaptive.enabled: true
spark.driver.memory: 4g
[5] DATA PATHS
----------------------------------------------------------------------
Data path: /home/sable/devops_base/td2/retail_dw_20250826
Exists: True
[6] KEY PACKAGES
----------------------------------------------------------------------
pyspark: 4.0.1
pandas: 2.3.3
numpy: 2.2.6
matplotlib: 3.10.8
======================================================================
All reproducibility information has been recorded
======================================================================