Thursday, April 24, 2025

How to flatten a complex JSON file - Example 2

from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, StructType
from pyspark.sql.functions import col, explode_outer

def flatten(df):
    """Recursively flattens a PySpark DataFrame with nested structures.

    For any column whose type is either ArrayType or StructType:
    - If it is a StructType, the function expands each field of the struct
      into a new column. The new column names are in the form
      "parentField_childField".
    - If it is an ArrayType, the function uses `explode_outer` to convert
      each element of the array into a separate row (which is useful for
      arrays of structs or primitive types).

    Parameters:
        df (DataFrame): The input DataFrame that may contain nested columns.

    Returns:
        DataFrame: A flattened DataFrame with no nested columns.
    """
    # Identify columns whose type is either ArrayType or StructType.
    complex_fields = {field.name: field.dataType
                      for field in df.schema.fields
                      if isinstance(field.dataType, (ArrayType, StructType))}

    while complex_fields:
        for col_name, col_type in complex_fields.items():
            if isinstance(col_type, StructType):
                # For a struct: expand its fields as separate columns.
                expanded = [col(col_name + '.' + subfield.name).alias(col_name + '_' + subfield.name)
                            for subfield in col_type.fields]
                df = df.select("*", *expanded).drop(col_name)
            elif isinstance(col_type, ArrayType):
                # For an array: explode it so that each element becomes a separate row.
                df = df.withColumn(col_name, explode_outer(col_name))
        # Recompute the schema to check whether more nested columns remain.
        complex_fields = {field.name: field.dataType
                          for field in df.schema.fields
                          if isinstance(field.dataType, (ArrayType, StructType))}

    return df

Example Usage
if __name__ == "__main__":
    spark = SparkSession.builder.appName("FlattenNestedJson").getOrCreate()

    # Replace this with the path to your nested employee JSON file.
    json_file_path = "/path/to/employee_record.json"

    # Read the nested JSON file.
    df = spark.read.json(json_file_path)

    # Apply the flatten function.
    flat_df = flatten(df)

    # Display the flattened DataFrame.
    flat_df.show(truncate=False)

    spark.stop()
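To make the end-to-end behavior concrete, suppose the file held a single, hypothetical record like the one below (your actual employee file may of course look different):

{
  "id": 1,
  "name": {"first": "Ada", "last": "Lovelace"},
  "skills": [
    {"name": "SQL", "level": 5},
    {"name": "Python", "level": 4}
  ]
}

The first pass expands the name struct into name_first and name_last and explodes the skills array; the second pass then expands the exploded skills struct. flat_df.show() prints one row per skill with the columns id, name_first, name_last, skills_level, and skills_name (Spark's JSON reader sorts inferred struct fields alphabetically, hence skills_level before skills_name).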



Detecting Complex Types: The function first builds a dictionary (complex_fields) mapping column names to their data types for every field whose type is a nested structure (either an array or a struct).
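For example, on a small, hypothetical DataFrame (built inline here rather than read from JSON), the same dictionary comprehension picks out only the nested columns:

from pyspark.sql.types import ArrayType, StructType

# Hypothetical schema: one struct column and one array column.
df = spark.createDataFrame(
    [(1, ("Alice", "NY"), [10, 20])],
    "id INT, info STRUCT<name: STRING, city: STRING>, scores ARRAY<INT>")

complex_fields = {field.name: field.dataType
                  for field in df.schema.fields
                  if isinstance(field.dataType, (ArrayType, StructType))}

print(list(complex_fields))  # ['info', 'scores'] -- 'id' is a plain column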

Processing Structs: For each field of type StructType, the code iterates over its subfields and creates new columns named in the pattern "parent_subfield". The original nested column is then dropped.
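Tracing one expansion step by hand on the hypothetical info column above makes the renaming pattern visible:

from pyspark.sql.functions import col

# Expand the struct's subfields into top-level columns, then drop the struct.
expanded = [col("info." + subfield.name).alias("info_" + subfield.name)
            for subfield in df.schema["info"].dataType.fields]
df_structs_flat = df.select("*", *expanded).drop("info")

print(df_structs_flat.columns)  # ['id', 'scores', 'info_name', 'info_city']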

Processing Arrays: For columns with ArrayType, the function calls explode_outer, which converts each element of the array into a separate row; unlike explode, it also keeps rows whose array is null or empty, so no records are silently dropped.
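The difference from a plain explode is easiest to see on a tiny, hypothetical table where some arrays are empty or null:

from pyspark.sql.functions import explode_outer

arr_df = spark.createDataFrame(
    [(1, [10, 20]), (2, []), (3, None)],
    "id INT, scores ARRAY<INT>")

# explode would drop rows 2 and 3; explode_outer keeps them with scores = null.
arr_df.withColumn("scores", explode_outer("scores")).show()
# +---+------+
# | id|scores|
# +---+------+
# |  1|    10|
# |  1|    20|
# |  2|  null|
# |  3|  null|
# +---+------+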

Iterative Flattening: After processing the current set of nested fields, the function rebuilds the dictionary to catch any newly exposed nested fields. This loop continues until no more complex types remain.
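An array of structs shows why this recomputation matters: the first pass only explodes the array, leaving behind a struct column that the second pass then expands (schema and column names below are hypothetical):

# Pass 1 explodes 'items' (leaving a STRUCT column); pass 2 expands it.
nested = spark.createDataFrame(
    [(1, [("a", 10), ("b", 20)])],
    "id INT, items ARRAY<STRUCT<k: STRING, v: INT>>")

flatten(nested).show()
# +---+-------+-------+
# | id|items_k|items_v|
# +---+-------+-------+
# |  1|      a|     10|
# |  1|      b|     20|
# +---+-------+-------+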
