from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, StructType
from pyspark.sql.functions import col, explode_outer
def flatten(df):
    """
    Recursively flattens a PySpark DataFrame with nested structures.

    For any column whose type is either ArrayType or StructType:
    - If it is a StructType, the function expands each field of the struct into a
      new column. The new column names are in the form "parentField_childField".
    - If it is an ArrayType, the function uses `explode_outer` to convert each
      element of the array into a separate row (which is useful for arrays of
      structs or primitive types).

    Parameters:
        df (DataFrame): The input DataFrame that may contain nested columns.

    Returns:
        DataFrame: A flattened DataFrame with no nested columns.
    """
    # Identify columns whose type is either ArrayType or StructType.
    complex_fields = {field.name: field.dataType
                      for field in df.schema.fields
                      if isinstance(field.dataType, (ArrayType, StructType))}

    while complex_fields:
        for col_name, col_type in complex_fields.items():
            if isinstance(col_type, StructType):
                # For a struct: expand its fields as separate columns.
                expanded = [
                    col(col_name + '.' + subfield.name).alias(col_name + '_' + subfield.name)
                    for subfield in col_type.fields
                ]
                df = df.select("*", *expanded).drop(col_name)
            elif isinstance(col_type, ArrayType):
                # For an array: explode it so that each element becomes a separate row.
                df = df.withColumn(col_name, explode_outer(col_name))

        # Recompute the schema to check whether more nested columns remain.
        complex_fields = {field.name: field.dataType
                          for field in df.schema.fields
                          if isinstance(field.dataType, (ArrayType, StructType))}

    return df
Example Usage
if __name__ == "__main__":
    spark = SparkSession.builder.appName("FlattenNestedJson").getOrCreate()

    # Replace this with the path to your nested employee JSON file.
    json_file_path = "/path/to/employee_record.json"

    # Read the nested JSON file.
    df = spark.read.json(json_file_path)

    # Apply the flatten function.
    flat_df = flatten(df)

    # Display the flattened DataFrame.
    flat_df.show(truncate=False)

    spark.stop()
Detecting Complex Types: The function first builds a dictionary (complex_fields) mapping column names to their data types for every field that is a nested structure (either an array or a struct).
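For instance, the same dictionary comprehension applied to a small made-up DataFrame (the column names and data below are purely illustrative) picks up one struct column and one array column:

from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, StructType

spark = SparkSession.builder.getOrCreate()
sample_df = spark.createDataFrame(
    [(("Ada", "Lovelace"), ["math", "computing"])],
    "name struct<first:string,last:string>, skills array<string>",
)
complex_fields = {field.name: field.dataType
                  for field in sample_df.schema.fields
                  if isinstance(field.dataType, (ArrayType, StructType))}
print(list(complex_fields))  # ['name', 'skills']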
Processing Structs: For each field of type StructType, the code iterates over its subfields and creates new columns named in the pattern "parent_subfield". The original nested column is then dropped.
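Continuing with the hypothetical sample_df above, expanding its name struct produces name_first and name_last as top-level columns and drops the original nested column:

from pyspark.sql.functions import col

expanded = [col("name." + sub).alias("name_" + sub) for sub in ["first", "last"]]
struct_flat_df = sample_df.select("*", *expanded).drop("name")
struct_flat_df.printSchema()
# The schema now shows skills (still an array) plus name_first and name_last as strings.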
Processing Arrays: For columns with ArrayType, the function calls explode_outer, which converts each element of the array into a separate row. Unlike explode, explode_outer keeps a row (with a NULL value) when the array is null or empty, so records are not silently dropped.
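On the same hypothetical data, exploding the skills array gives one output row per element:

from pyspark.sql.functions import explode_outer

array_flat_df = struct_flat_df.withColumn("skills", explode_outer("skills"))
array_flat_df.show(truncate=False)
# Two rows: one with skills = "math" and one with skills = "computing".
# A row whose skills array was null or empty would survive with skills = NULL.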
Iterative Flattening: After processing the current set of nested fields, the function rebuilds the dictionary to catch any newly exposed nested fields. This loop continues until no more complex types remain.
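To see the loop take more than one pass, consider a made-up array-of-structs column (again assuming the flatten function and the SparkSession defined earlier): the first pass explodes the array, which exposes a struct, and the second pass expands that struct, after which no complex types remain and the loop exits.

nested_df = spark.createDataFrame(
    [("e1", [("alpha", 10), ("beta", 5)])],
    "emp_id string, projects array<struct<name:string,hours:int>>",
)
flatten(nested_df).show(truncate=False)
# Pass 1: `projects` is exploded, giving one row per struct element.
# Pass 2: the exposed struct is expanded into `projects_name` and `projects_hours`.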