Complete guide for Apache Spark data processing including RDDs, DataFrames, Spark SQL, streaming, MLlib, and production deployment.

Install: npx skill4agent add manutej/luxor-claude-marketplace apache-spark-data-processing

Key repartitioning APIs: repartition(), coalesce()

# Cache DataFrame in memory
df.cache() # Shorthand for persist(StorageLevel.MEMORY_AND_DISK)
# Different storage levels
df.persist(StorageLevel.MEMORY_ONLY) # Fast but may lose data if evicted
df.persist(StorageLevel.MEMORY_AND_DISK) # Spill to disk if memory full
df.persist(StorageLevel.DISK_ONLY) # Store only on disk
df.persist(StorageLevel.MEMORY_ONLY_SER) # Serialized in memory (more compact)
# Unpersist when done
df.unpersist()# Broadcast a lookup table
lookup_table = {"key1": "value1", "key2": "value2"}
broadcast_lookup = sc.broadcast(lookup_table)
# Use in transformations
rdd.map(lambda x: broadcast_lookup.value.get(x, "default"))# Create accumulator
error_count = sc.accumulator(0)
# Increment in tasks
rdd.foreach(lambda x: error_count.add(1) if is_error(x) else None)
# Read final value in driver
print(f"Total errors: {error_count.value}")from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("SparkSQLExample").getOrCreate()
# From structured data
data = [("Alice", 1), ("Bob", 2), ("Charlie", 3)]
columns = ["name", "id"]
df = spark.createDataFrame(data, columns)
# From files
df_json = spark.read.json("path/to/file.json")
df_parquet = spark.read.parquet("path/to/file.parquet")
df_csv = spark.read.option("header", "true").csv("path/to/file.csv")
# From JDBC sources
df_jdbc = spark.read \
.format("jdbc") \
.option("url", "jdbc:postgresql://host:port/database") \
.option("dbtable", "table_name") \
.option("user", "username") \
.option("password", "password") \
.load()# Select columns
df.select("name", "age").show()
# Filter rows
df.filter(df.age > 21).show()
df.where(df["age"] > 21).show() # Alternative syntax
# Add/modify columns
from pyspark.sql.functions import col, lit
df.withColumn("age_plus_10", col("age") + 10).show()
df.withColumn("country", lit("USA")).show()
# Aggregations
df.groupBy("department").count().show()
df.groupBy("department").agg({"salary": "avg", "age": "max"}).show()
# Sorting
df.orderBy("age").show()
df.orderBy(col("age").desc()).show()
# Joins
df1.join(df2, df1.id == df2.user_id, "inner").show()
df1.join(df2, "id", "left_outer").show()
# Unions
df1.union(df2).show()# Register DataFrame as temporary view
df.createOrReplaceTempView("people")
# Run SQL queries
sql_result = spark.sql("SELECT name FROM people WHERE age > 21")
sql_result.show()
# Complex queries
result = spark.sql("""
SELECT
department,
COUNT(*) as employee_count,
AVG(salary) as avg_salary,
MAX(age) as max_age
FROM people
WHERE age > 25
GROUP BY department
HAVING COUNT(*) > 5
ORDER BY avg_salary DESC
""")
result.show()# Write
df.write.parquet("output/path", mode="overwrite", compression="snappy")
# Read with partition pruning
df = spark.read.parquet("output/path").filter(col("date") == "2025-01-01")df.write.orc("output/path", mode="overwrite")
df = spark.read.orc("output/path")# Read with explicit schema
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
schema = StructType([
StructField("name", StringType(), True),
StructField("age", IntegerType(), True)
])
df = spark.read.schema(schema).json("data.json")df.write.csv("output.csv", header=True, mode="overwrite")
df = spark.read.option("header", "true").option("inferSchema", "true").csv("data.csv")from pyspark.sql.window import Window
from pyspark.sql.functions import row_number, rank, dense_rank, lag, lead, sum, avg
# Define window specification
window_spec = Window.partitionBy("department").orderBy(col("salary").desc())
# Ranking functions
df.withColumn("rank", rank().over(window_spec)).show()
df.withColumn("row_num", row_number().over(window_spec)).show()
df.withColumn("dense_rank", dense_rank().over(window_spec)).show()
# Aggregate functions over window
df.withColumn("dept_avg_salary", avg("salary").over(window_spec)).show()
df.withColumn("running_total", sum("salary").over(window_spec.rowsBetween(Window.unboundedPreceding, Window.currentRow))).show()
# Offset functions
df.withColumn("prev_salary", lag("salary", 1).over(window_spec)).show()
df.withColumn("next_salary", lead("salary", 1).over(window_spec)).show()from pyspark.sql.functions import udf
from pyspark.sql.types import StringType, IntegerType
# Python UDF (slower due to serialization overhead)
def categorize_age(age):
    """Map a numeric age to a coarse life-stage label.

    Returns "Minor" for ages below 18, "Adult" for 18 up to (but not
    including) 65, and "Senior" for 65 and above.
    """
    if age >= 65:
        return "Senior"
    if age >= 18:
        return "Adult"
    return "Minor"
categorize_udf = udf(categorize_age, StringType())
df.withColumn("age_category", categorize_udf(col("age"))).show()
# Pandas UDF (vectorized, faster for large datasets)
from pyspark.sql.functions import pandas_udf
import pandas as pd
@pandas_udf(IntegerType())
def square(series: pd.Series) -> pd.Series:
    """Vectorized pandas UDF: return the element-wise square of *series*."""
    return series.pow(2)
df.withColumn("age_squared", square(col("age"))).show()# RDD
rdd = sc.parallelize([1, 2, 3, 4, 5])
squared = rdd.map(lambda x: x * 2) # [2, 4, 6, 8, 10]
# DataFrame (use select with functions)
from pyspark.sql.functions import col
df.select(col("value") * 2).show()# RDD
rdd.filter(lambda x: x > 2).collect() # [3, 4, 5]
# DataFrame
df.filter(col("age") > 25).show()# RDD - Split text into words
lines = sc.parallelize(["hello world", "apache spark"])
words = lines.flatMap(lambda line: line.split(" ")) # ["hello", "world", "apache", "spark"]# Word count example
words = sc.parallelize(["apple", "banana", "apple", "cherry", "banana", "apple"])
word_pairs = words.map(lambda word: (word, 1))
word_counts = word_pairs.reduceByKey(lambda a, b: a + b)
# Result: [("apple", 3), ("banana", 2), ("cherry", 1)]# Less efficient than reduceByKey
word_pairs.groupByKey().mapValues(list).collect()
# Result: [("apple", [1, 1, 1]), ("banana", [1, 1]), ("cherry", [1])]# RDD join
users = sc.parallelize([("user1", "Alice"), ("user2", "Bob")])
orders = sc.parallelize([("user1", 100), ("user2", 200), ("user1", 150)])
users.join(orders).collect()
# Result: [("user1", ("Alice", 100)), ("user1", ("Alice", 150)), ("user2", ("Bob", 200))]
# DataFrame join (more efficient)
df_users.join(df_orders, "user_id", "inner").show()# RDD
rdd.distinct().collect()
# DataFrame
df.distinct().show()
df.dropDuplicates(["user_id"]).show() # Drop based on specific columns# Reduce partitions (no shuffle, more efficient)
df.coalesce(1).write.parquet("output")
# Increase/decrease partitions (involves shuffle)
df.repartition(10).write.parquet("output")
df.repartition(10, "user_id").write.parquet("output") # Partition by columnresults = rdd.collect() # Returns list
# WARNING: Only use on small datasets that fit in driver memorytotal = df.count() # Number of rowsfirst_elem = rdd.first()
first_five = rdd.take(5)total_sum = rdd.reduce(lambda a, b: a + b)# Side effects only (no return value)
rdd.foreach(lambda x: print(x))rdd.saveAsTextFile("hdfs://path/to/output")df.show(20, truncate=False) # Show 20 rows, don't truncate columnsfrom pyspark.sql import SparkSession
from pyspark.sql.functions import col
spark = SparkSession.builder.appName("StreamingExample").getOrCreate()
# Read stream from JSON files
input_stream = spark.readStream \
.format("json") \
.schema(schema) \
.option("maxFilesPerTrigger", 1) \
.load("input/directory")
# Transform streaming data
processed = input_stream \
.filter(col("value") > 10) \
.select("id", "value", "timestamp")
# Write stream to Parquet
query = processed.writeStream \
.format("parquet") \
.option("path", "output/directory") \
.option("checkpointLocation", "checkpoint/directory") \
.outputMode("append") \
.start()
# Wait for termination
query.awaitTermination()# Static DataFrame (loaded once)
static_df = spark.read.parquet("reference/data")
# Streaming DataFrame
streaming_df = spark.readStream.format("kafka").load()
# Inner join (supported)
joined = streaming_df.join(static_df, "type")
# Left outer join (supported)
joined = streaming_df.join(static_df, "type", "left_outer")
# Write result
joined.writeStream \
.format("parquet") \
.option("path", "output") \
.option("checkpointLocation", "checkpoint") \
.start()from pyspark.sql.functions import window, col, count
# 10-minute tumbling window
windowed_counts = streaming_df \
.groupBy(
window(col("timestamp"), "10 minutes"),
col("word")
) \
.count()
# 10-minute sliding window with 5-minute slide
windowed_counts = streaming_df \
.groupBy(
window(col("timestamp"), "10 minutes", "5 minutes"),
col("word")
) \
.count()
# Write to console for debugging
query = windowed_counts.writeStream \
.outputMode("complete") \
.format("console") \
.option("truncate", "false") \
.start()from pyspark.sql.functions import window
# Define watermark (10 minutes tolerance for late data)
windowed_counts = streaming_df \
.withWatermark("timestamp", "10 minutes") \
.groupBy(
window(col("timestamp"), "10 minutes"),
col("word")
) \
.count()
# Data arriving more than 10 minutes late will be droppedfrom pyspark.sql.functions import session_window, when
# Dynamic session window based on user
session_window_spec = session_window(
col("timestamp"),
when(col("userId") == "user1", "5 seconds")
.when(col("userId") == "user2", "20 seconds")
.otherwise("5 minutes")
)
sessionized_counts = streaming_df \
.withWatermark("timestamp", "10 minutes") \
.groupBy(session_window_spec, col("userId")) \
.count()from pyspark.sql.functions import expr
# Deduplication using state
deduplicated = streaming_df \
.withWatermark("timestamp", "1 hour") \
.dropDuplicates(["user_id", "event_id"])
# Stream-stream joins (stateful)
stream1 = spark.readStream.format("kafka").option("subscribe", "topic1").load()
stream2 = spark.readStream.format("kafka").option("subscribe", "topic2").load()
joined = stream1 \
.withWatermark("timestamp", "10 minutes") \
.join(
stream2.withWatermark("timestamp", "20 minutes"),
expr("stream1.user_id = stream2.user_id AND stream1.timestamp >= stream2.timestamp AND stream1.timestamp <= stream2.timestamp + interval 15 minutes"),
"inner"
)# Checkpoint location stores:
# - Stream metadata (offsets, configuration)
# - State information (for stateful operations)
# - Write-ahead logs
query = streaming_df.writeStream \
.format("parquet") \
.option("path", "output") \
.option("checkpointLocation", "checkpoint/dir") # REQUIRED for production \
.start()
# Recovery: Restart query with same checkpoint location
# Spark will resume from last committed offsetfrom pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.ml.classification import LogisticRegression
# Load data
df = spark.read.format("libsvm").load("data/sample_libsvm_data.txt")
# Define pipeline stages
assembler = VectorAssembler(
inputCols=["feature1", "feature2", "feature3"],
outputCol="features"
)
scaler = StandardScaler(
inputCol="features",
outputCol="scaled_features",
withStd=True,
withMean=True
)
lr = LogisticRegression(
featuresCol="scaled_features",
labelCol="label",
maxIter=10,
regParam=0.01
)
# Create pipeline
pipeline = Pipeline(stages=[assembler, scaler, lr])
# Split data
train_df, test_df = df.randomSplit([0.8, 0.2], seed=42)
# Train model
model = pipeline.fit(train_df)
# Make predictions
predictions = model.transform(test_df)
predictions.select("label", "prediction", "probability").show()from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler, MinMaxScaler
# Categorical encoding
indexer = StringIndexer(inputCol="category", outputCol="category_index")
encoder = OneHotEncoder(inputCol="category_index", outputCol="category_vec")
# Numerical scaling
scaler = MinMaxScaler(inputCol="features", outputCol="scaled_features")
# Assemble features
assembler = VectorAssembler(
inputCols=["category_vec", "numeric_feature1", "numeric_feature2"],
outputCol="features"
)
# Text processing
from pyspark.ml.feature import Tokenizer, HashingTF, IDF
tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashing_tf = HashingTF(inputCol="words", outputCol="raw_features", numFeatures=10000)
idf = IDF(inputCol="raw_features", outputCol="features")from pyspark.mllib.regression import LabeledPoint
from pyspark.streaming import StreamingContext
from pyspark.streaming.ml import StreamingLinearRegressionWithSGD
# Create StreamingContext
ssc = StreamingContext(sc, batchDuration=1)
# Define data streams
training_stream = ssc.textFileStream("training/data/path")
testing_stream = ssc.textFileStream("testing/data/path")
# Parse streams into LabeledPoint objects
def parse_point(line):
    """Parse one comma-separated text line into a LabeledPoint.

    The first field is the label; the remaining fields are the feature
    vector. All fields are interpreted as floats.
    """
    numbers = [float(token) for token in line.strip().split(',')]
    label, features = numbers[0], numbers[1:]
    return LabeledPoint(label, features)
parsed_training = training_stream.map(parse_point)
parsed_testing = testing_stream.map(parse_point)
# Initialize model
num_features = 3
model = StreamingLinearRegressionWithSGD(initialWeights=[0.0] * num_features)
# Train and predict
model.trainOn(parsed_training)
predictions = model.predictOnValues(parsed_testing.map(lambda lp: (lp.label, lp.features)))
# Print predictions
predictions.pprint()
# Start streaming
ssc.start()
ssc.awaitTermination()from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator, RegressionEvaluator
# Binary classification
binary_evaluator = BinaryClassificationEvaluator(
labelCol="label",
rawPredictionCol="rawPrediction",
metricName="areaUnderROC"
)
auc = binary_evaluator.evaluate(predictions)
print(f"AUC: {auc}")
# Multiclass classification
multi_evaluator = MulticlassClassificationEvaluator(
labelCol="label",
predictionCol="prediction",
metricName="accuracy"
)
accuracy = multi_evaluator.evaluate(predictions)
print(f"Accuracy: {accuracy}")
# Regression
regression_evaluator = RegressionEvaluator(
labelCol="label",
predictionCol="prediction",
metricName="rmse"
)
rmse = regression_evaluator.evaluate(predictions)
print(f"RMSE: {rmse}")from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
# Define model
rf = RandomForestClassifier(labelCol="label", featuresCol="features")
# Build parameter grid
param_grid = ParamGridBuilder() \
.addGrid(rf.numTrees, [10, 20, 50]) \
.addGrid(rf.maxDepth, [5, 10, 15]) \
.addGrid(rf.minInstancesPerNode, [1, 5, 10]) \
.build()
# Define evaluator
evaluator = MulticlassClassificationEvaluator(metricName="accuracy")
# Cross-validation
cv = CrossValidator(
estimator=rf,
estimatorParamMaps=param_grid,
evaluator=evaluator,
numFolds=5,
parallelism=4
)
# Train
cv_model = cv.fit(train_df)
# Best model
best_model = cv_model.bestModel
print(f"Best numTrees: {best_model.getNumTrees}")
print(f"Best maxDepth: {best_model.getMaxDepth()}")
# Evaluate on test set
predictions = cv_model.transform(test_df)
accuracy = evaluator.evaluate(predictions)
print(f"Test Accuracy: {accuracy}")from pyspark.mllib.linalg.distributed import RowMatrix, IndexedRowMatrix, CoordinateMatrix
from pyspark.mllib.linalg import Vectors
# RowMatrix: Distributed matrix without row indices
rows = sc.parallelize([
Vectors.dense([1.0, 2.0, 3.0]),
Vectors.dense([4.0, 5.0, 6.0]),
Vectors.dense([7.0, 8.0, 9.0])
])
row_matrix = RowMatrix(rows)
# Compute statistics
print(f"Rows: {row_matrix.numRows()}")
print(f"Cols: {row_matrix.numCols()}")
print(f"Column means: {row_matrix.computeColumnSummaryStatistics().mean()}")
# IndexedRowMatrix: Matrix with row indices
from pyspark.mllib.linalg.distributed import IndexedRow
indexed_rows = sc.parallelize([
IndexedRow(0, Vectors.dense([1.0, 2.0, 3.0])),
IndexedRow(1, Vectors.dense([4.0, 5.0, 6.0]))
])
indexed_matrix = IndexedRowMatrix(indexed_rows)
# CoordinateMatrix: Sparse matrix using (row, col, value) entries
from pyspark.mllib.linalg.distributed import MatrixEntry
entries = sc.parallelize([
MatrixEntry(0, 0, 1.0),
MatrixEntry(0, 2, 3.0),
MatrixEntry(1, 1, 5.0)
])
coord_matrix = CoordinateMatrix(entries)# Scala/Java approach
data = [("a", 1), ("b", 2), ("a", 3), ("b", 4), ("a", 5), ("c", 6)]
rdd = sc.parallelize(data)
# Define sampling fractions per key
fractions = {"a": 0.5, "b": 0.5, "c": 0.5}
# Approximate sample (faster, one pass)
sampled_rdd = rdd.sampleByKey(withReplacement=False, fractions=fractions)
# Exact sample (slower, guaranteed exact counts)
exact_sampled = rdd.sampleByKeyExact(withReplacement=False, fractions=fractions)
print(sampled_rdd.collect())spark = SparkSession.builder \
.appName("MemoryTuning") \
.config("spark.executor.memory", "4g") \
.config("spark.driver.memory", "2g") \
.config("spark.memory.fraction", "0.6") # Fraction for execution + storage \
.config("spark.memory.storageFraction", "0.5") # Fraction of above for storage \
.getOrCreate()# 1. Use reduceByKey instead of groupByKey
# Bad: groupByKey shuffles all data
word_pairs.groupByKey().mapValues(sum)
# Good: reduceByKey combines locally before shuffle
word_pairs.reduceByKey(lambda a, b: a + b)
# 2. Broadcast small tables in joins
from pyspark.sql.functions import broadcast
large_df.join(broadcast(small_df), "key")
# 3. Partition data appropriately
df.repartition(200, "user_id") # Partition by key for subsequent aggregations
# 4. Coalesce instead of repartition when reducing partitions
df.coalesce(10) # No shuffle, just merge partitions
# 5. Tune shuffle partitions
spark.conf.set("spark.sql.shuffle.partitions", 200) # Default is 200spark = SparkSession.builder \
.config("spark.sql.shuffle.partitions", 200) \
.config("spark.default.parallelism", 200) \
.config("spark.sql.adaptive.enabled", "true") # Enable AQE \
.config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
.getOrCreate()# Partition writes by date for easy filtering
df.write.partitionBy("date", "country").parquet("output")
# Read with partition pruning (only reads relevant partitions)
spark.read.parquet("output").filter(col("date") == "2025-01-15").show()from pyspark.rdd import portable_hash
# Custom partitioner for RDD
def custom_partitioner(key, num_partitions=100):
    """Deterministically map *key* to a partition index.

    Uses Spark's portable_hash so the same key always lands in the same
    partition across JVM/Python boundaries.

    Args:
        key: The record key to partition on.
        num_partitions: Total number of partitions; defaults to 100 so
            existing single-argument callers (e.g. ``rdd.partitionBy``)
            keep their behavior.

    Returns:
        An int in ``[0, num_partitions)``.
    """
    return portable_hash(key) % num_partitions
rdd.partitionBy(100, custom_partitioner)# Iterative algorithms (ML)
training_data.cache()
for i in range(num_iterations):
model = train_model(training_data)
# Multiple aggregations on same data
base_df.cache()
result1 = base_df.groupBy("country").count()
result2 = base_df.groupBy("city").avg("sales")
# Interactive analysis
df.cache()
df.filter(condition1).show()
df.filter(condition2).show()
df.groupBy("category").count().show()from pyspark import StorageLevel
# Memory only (fastest, but may lose data)
df.persist(StorageLevel.MEMORY_ONLY)
# Memory and disk (spill to disk if needed)
df.persist(StorageLevel.MEMORY_AND_DISK)
# Serialized in memory (more compact, slower access)
df.persist(StorageLevel.MEMORY_ONLY_SER)
# Disk only (slowest, but always available)
df.persist(StorageLevel.DISK_ONLY)
# Replicated (fault tolerance)
df.persist(StorageLevel.MEMORY_AND_DISK_2) # 2 replicasfrom pyspark.sql.functions import broadcast
# Automatic broadcast (tables < spark.sql.autoBroadcastJoinThreshold)
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 10 * 1024 * 1024) # 10 MB
# Explicit broadcast hint
large_df.join(broadcast(small_df), "key")
# Benefits:
# - No shuffle of large table
# - Small table sent to all executors once
# - Much faster for small dimension tablesspark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
# AQE Benefits:
# - Dynamically coalesce partitions after shuffle
# - Handle skewed joins by splitting large partitions
# - Optimize join strategy at runtime# View physical plan
df.explain(mode="extended")
# Optimizations include:
# - Predicate pushdown: Push filters to data source
# - Column pruning: Read only required columns
# - Constant folding: Evaluate constants at compile time
# - Join reordering: Optimize join order
# - Partition pruning: Skip irrelevant partitions# Start master
$SPARK_HOME/sbin/start-master.sh
# Start workers
$SPARK_HOME/sbin/start-worker.sh spark://master:7077
# Submit application
spark-submit --master spark://master:7077 app.py# Cluster mode (driver runs on YARN)
spark-submit --master yarn --deploy-mode cluster app.py
# Client mode (driver runs locally)
spark-submit --master yarn --deploy-mode client app.pyspark-submit \
--master k8s://https://k8s-master:443 \
--deploy-mode cluster \
--name spark-app \
--conf spark.executor.instances=5 \
--conf spark.kubernetes.container.image=spark:latest \
app.pyspark-submit \
--master yarn \
--deploy-mode cluster \
--driver-memory 4g \
--executor-memory 8g \
--executor-cores 4 \
--num-executors 10 \
--conf spark.sql.shuffle.partitions=200 \
--py-files dependencies.zip \
--files config.json \
application.py

Key spark-submit options: --master, --deploy-mode, --driver-memory, --executor-memory, --executor-cores, --num-executors, --conf, --py-files, --files

Cluster: 10 nodes, 32 cores each, 128 GB RAM each
Option 1: Many small executors
- 30 executors (3 per node)
- 10 cores per executor
- 40 GB memory per executor
- Total: 300 cores
Option 2: Fewer large executors (RECOMMENDED)
- 50 executors (5 per node)
- 5 cores per executor
- 24 GB memory per executor
- Total: 250 coresspark = SparkSession.builder \
.appName("DynamicAllocation") \
.config("spark.dynamicAllocation.enabled", "true") \
.config("spark.dynamicAllocation.minExecutors", 2) \
.config("spark.dynamicAllocation.maxExecutors", 100) \
.config("spark.dynamicAllocation.initialExecutors", 10) \
.config("spark.dynamicAllocation.executorIdleTimeout", "60s") \
.getOrCreate()# Start history server
$SPARK_HOME/sbin/start-history-server.sh
# Configure event logging
spark.conf.set("spark.eventLog.enabled", "true")
spark.conf.set("spark.eventLog.dir", "hdfs://namenode/spark-logs")# Enable metrics collection
spark.conf.set("spark.metrics.conf.*.sink.console.class", "org.apache.spark.metrics.sink.ConsoleSink")
spark.conf.set("spark.metrics.conf.*.sink.console.period", 10)# Configure log level
spark.sparkContext.setLogLevel("WARN") # ERROR, WARN, INFO, DEBUG
# Custom logging
import logging
logger = logging.getLogger(__name__)
logger.info("Custom log message")# Set checkpoint directory
spark.sparkContext.setCheckpointDir("hdfs://namenode/checkpoints")
# Checkpoint RDD (breaks lineage for very long chains)
rdd.checkpoint()
# Streaming checkpoint (required for production)
query = streaming_df.writeStream \
.option("checkpointLocation", "hdfs://namenode/streaming-checkpoint") \
.start()# Enable speculative execution for slow tasks
spark.conf.set("spark.speculation", "true")
spark.conf.set("spark.speculation.multiplier", 1.5)
spark.conf.set("spark.speculation.quantile", 0.75)# Increase locality wait time
spark.conf.set("spark.locality.wait", "10s")
spark.conf.set("spark.locality.wait.node", "5s")
spark.conf.set("spark.locality.wait.rack", "3s")
# Partition data to match cluster topology
df.repartition(num_nodes * cores_per_node)def etl_pipeline(spark, input_path, output_path):
# Extract
raw_df = spark.read.parquet(input_path)
# Transform
cleaned_df = raw_df \
.dropDuplicates(["id"]) \
.filter(col("value").isNotNull()) \
.withColumn("processed_date", current_date())
# Enrich
enriched_df = cleaned_df.join(broadcast(reference_df), "key")
# Aggregate
aggregated_df = enriched_df \
.groupBy("category", "date") \
.agg(
count("*").alias("count"),
sum("amount").alias("total_amount"),
avg("value").alias("avg_value")
)
# Load
aggregated_df.write \
.partitionBy("date") \
.mode("overwrite") \
.parquet(output_path)def incremental_process(spark, input_path, output_path, checkpoint_path):
# Read last processed timestamp
last_timestamp = read_checkpoint(checkpoint_path)
# Read new data
new_data = spark.read.parquet(input_path) \
.filter(col("timestamp") > last_timestamp)
# Process
processed = transform(new_data)
# Write
processed.write.mode("append").parquet(output_path)
# Update checkpoint
max_timestamp = new_data.agg(max("timestamp")).collect()[0][0]
write_checkpoint(checkpoint_path, max_timestamp)def scd_type2_upsert(spark, dimension_df, updates_df):
# Mark existing records as inactive if updated
inactive_records = dimension_df \
.join(updates_df, "business_key") \
.select(
dimension_df["*"],
lit(False).alias("is_active"),
current_date().alias("end_date")
)
# Add new records
new_records = updates_df \
.withColumn("is_active", lit(True)) \
.withColumn("start_date", current_date()) \
.withColumn("end_date", lit(None))
# Union unchanged, inactive, and new records
result = dimension_df \
.join(updates_df, "business_key", "left_anti") \
.union(inactive_records) \
.union(new_records)
return resultdef calculate_running_metrics(df):
from pyspark.sql.window import Window
from pyspark.sql.functions import row_number, lag, sum, avg
# Define window
window_spec = Window.partitionBy("user_id").orderBy("timestamp")
# Calculate metrics
result = df \
.withColumn("row_num", row_number().over(window_spec)) \
.withColumn("prev_value", lag("value", 1).over(window_spec)) \
.withColumn("running_total", sum("value").over(window_spec.rowsBetween(Window.unboundedPreceding, Window.currentRow))) \
.withColumn("moving_avg", avg("value").over(window_spec.rowsBetween(-2, 0)))
return resultjava.lang.OutOfMemoryError# Increase executor memory
spark.conf.set("spark.executor.memory", "8g")
# Increase driver memory (if collecting data)
spark.conf.set("spark.driver.memory", "4g")
# Reduce memory pressure
df.persist(StorageLevel.MEMORY_AND_DISK) # Spill to disk
df.coalesce(100) # Reduce partition count
spark.conf.set("spark.sql.shuffle.partitions", 400) # Increase shuffle partitions
# Avoid collect() on large datasets
# Use take() or limit() instead
df.take(100)# Increase shuffle partitions
spark.conf.set("spark.sql.shuffle.partitions", 400)
# Handle skew with salting
df_salted = df.withColumn("salt", (rand() * 10).cast("int"))
result = df_salted.groupBy("key", "salt").agg(...)
# Use broadcast for small tables
large_df.join(broadcast(small_df), "key")
# Enable AQE for automatic optimization
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")# Increase executor memory for stateful operations
spark.conf.set("spark.executor.memory", "8g")
# Tune watermark for late data
.withWatermark("timestamp", "15 minutes")
# Increase trigger interval to reduce micro-batch overhead
.trigger(processingTime="30 seconds")
# Monitor lag and adjust parallelism
spark.conf.set("spark.sql.shuffle.partitions", 200)
# Recover from checkpoint corruption
# Delete checkpoint directory and restart (data loss possible)
# Or implement custom state recovery logic# 1. Salting technique (add random prefix to keys)
from pyspark.sql.functions import concat, lit, rand
df_salted = df.withColumn("salted_key", concat(col("key"), lit("_"), (rand() * 10).cast("int")))
result = df_salted.groupBy("salted_key").agg(...)
# 2. Repartition by skewed column
df.repartition(200, "skewed_column")
# 3. Isolate skewed keys
skewed_keys = df.groupBy("key").count().filter(col("count") > threshold).select("key")
skewed_df = df.join(broadcast(skewed_keys), "key")
normal_df = df.join(broadcast(skewed_keys), "key", "left_anti")
# Process separately
skewed_result = process_with_salting(skewed_df)
normal_result = process_normally(normal_df)
final = skewed_result.union(normal_result)