|
| 1 | +import datetime |
| 2 | +import re |
| 3 | +import sys |
| 4 | +import time |
| 5 | + |
| 6 | +from google.cloud import storage |
| 7 | +from py4j.protocol import Py4JJavaError |
| 8 | +from pyspark.sql import SparkSession |
| 9 | +from pyspark.sql.functions import UserDefinedFunction |
| 10 | +from pyspark.sql.types import FloatType, IntegerType, StringType |
| 11 | + |
| 12 | + |
def trip_duration_udf(duration):
    """Convert a trip duration string to whole seconds.

    Accepts values such as "750", "12.5 min", or "1.5h"; the unit is
    inferred from the presence of "m" (minutes) or "h" (hours) anywhere
    in the string, otherwise the value is taken as seconds.

    Returns:
        int seconds, or None when the input is empty, non-numeric,
        or negative.
    """
    if not duration:
        return None

    # Anchor on a signed decimal number. The previous pattern r"\d*.\d*"
    # used an unescaped "." that could match a lone non-digit character
    # (e.g. input "m"), making float() raise instead of returning None.
    number = re.match(r"-?\d+\.?\d*", duration.strip())
    if not number:
        return None

    seconds = float(number[0])
    if seconds < 0:
        return None

    # Scale by unit; "m" is checked before "h" as in the original logic.
    if "m" in duration:
        seconds *= 60
    elif "h" in duration:
        seconds *= 60 * 60

    return int(seconds)
| 34 | + |
| 35 | + |
def station_name_udf(name):
    """Replaces '/' with '&'."""
    if not name:
        return None
    return name.replace("/", "&")
| 39 | + |
| 40 | + |
def user_type_udf(user):
    """Converts user type to 'Subscriber' or 'Customer'."""
    if not user:
        return None

    # Normalize once, then dispatch on the recognized prefixes.
    normalized = user.lower()
    for prefix, label in (("sub", "Subscriber"), ("cust", "Customer")):
        if normalized.startswith(prefix):
            return label
    return None
| 50 | + |
| 51 | + |
def gender_udf(gender):
    """Converts gender to 'Male' or 'Female'."""
    if not gender:
        return None

    # Dispatch on the first-letter prefix, case-insensitively.
    normalized = gender.lower()
    for prefix, label in (("m", "Male"), ("f", "Female")):
        if normalized.startswith(prefix):
            return label
    return None
| 61 | + |
| 62 | + |
def angle_udf(angle):
    """Converts DMS notation to degrees. Return None if not in DMS or degrees notation.

    DMS input looks like 40°44'54" (any single character may separate the
    degrees from the minutes). Plain decimal-degree strings are passed
    through as floats.
    """
    if not angle:
        return None

    # DMS branch, kept as-is from the original.
    # NOTE(review): each component may carry its own sign, and the minutes/
    # seconds fractions are ADDED to the degrees, so -73°58'10" yields
    # -73 + 58/60 + 10/3600 — confirm this matches the data source's sign
    # convention.
    dms = re.match(r'(-?\d*).(-?\d*)\'(-?\d*)"', angle)
    if dms:
        return int(dms[1]) + int(dms[2]) / 60 + int(dms[3]) / (60 * 60)

    # Decimal-degrees fallback. Require at least one digit and keep the sign
    # and fractional part: the previous pattern r"\d*.\d*" let the unescaped
    # "." consume the minus sign, truncating "-73.97" to -73.0, and could
    # match a lone non-digit so float() crashed on non-numeric input.
    degrees = re.match(r"-?\d+\.?\d*", angle)
    if degrees:
        return float(degrees[0])

    return None
| 75 | + |
| 76 | + |
def compute_time(duration, start, end):
    """Calculates duration, start time, and end time from each other if one value is null."""
    fmt = "%Y-%m-%dT%H:%M:%S"

    # Parse the timestamp strings (dropping any fractional-second suffix
    # after the ".") into datetime objects, and the duration in seconds
    # into a timedelta.
    if start:
        start = datetime.datetime.strptime(start.split(".", 1)[0], fmt)
    if end:
        end = datetime.datetime.strptime(end.split(".", 1)[0], fmt)
    if duration:
        duration = datetime.timedelta(seconds=duration)

    # Fill in whichever one of the three values is missing from the
    # other two.
    if not duration and start and end:
        duration = end - start
    elif not start and duration and end:
        start = end - duration
    elif not end and duration and start:
        end = start + duration

    # Convert back to primitives: whole seconds and formatted strings.
    if duration:
        duration = int(duration.total_seconds())
    if start:
        start = start.strftime(fmt)
    if end:
        end = end.strftime(fmt)

    return (duration, start, end)
| 115 | + |
| 116 | + |
def compute_duration_udf(duration, start, end):
    """Calculates duration from start and end time if null."""
    filled_duration, _, _ = compute_time(duration, start, end)
    return filled_duration
| 120 | + |
| 121 | + |
def compute_start_udf(duration, start, end):
    """Calculates start time from duration and end time if null."""
    _, filled_start, _ = compute_time(duration, start, end)
    return filled_start
| 125 | + |
| 126 | + |
def compute_end_udf(duration, start, end):
    """Calculates end time from duration and start time if null."""
    _, _, filled_end = compute_time(duration, start, end)
    return filled_end
| 130 | + |
| 131 | + |
if __name__ == "__main__":
    # Usage: script.py TABLE BUCKET_NAME [--dry-run]
    TABLE = sys.argv[1]
    BUCKET_NAME = sys.argv[2]

    # Create a SparkSession, viewable via the Spark UI
    spark = SparkSession.builder.appName("data_cleaning").getOrCreate()

    # Load data into dataframe if table exists
    try:
        df = spark.read.format("bigquery").option("table", TABLE).load()
    except Py4JJavaError as e:
        raise Exception(f"Error reading {TABLE}") from e

    # Single-parameter udfs: each cleans one column in place.
    udfs = {
        "start_station_name": UserDefinedFunction(station_name_udf, StringType()),
        "end_station_name": UserDefinedFunction(station_name_udf, StringType()),
        "tripduration": UserDefinedFunction(trip_duration_udf, IntegerType()),
        "usertype": UserDefinedFunction(user_type_udf, StringType()),
        "gender": UserDefinedFunction(gender_udf, StringType()),
        "start_station_latitude": UserDefinedFunction(angle_udf, FloatType()),
        "start_station_longitude": UserDefinedFunction(angle_udf, FloatType()),
        "end_station_latitude": UserDefinedFunction(angle_udf, FloatType()),
        "end_station_longitude": UserDefinedFunction(angle_udf, FloatType()),
    }

    for name, udf in udfs.items():
        df = df.withColumn(name, udf(name))

    # Multi-parameter udfs: derive each of duration/start/stop from the
    # other two when it is null.
    multi_udfs = {
        "tripduration": {
            "udf": UserDefinedFunction(compute_duration_udf, IntegerType()),
            "params": ("tripduration", "starttime", "stoptime"),
        },
        "starttime": {
            "udf": UserDefinedFunction(compute_start_udf, StringType()),
            "params": ("tripduration", "starttime", "stoptime"),
        },
        "stoptime": {
            "udf": UserDefinedFunction(compute_end_udf, StringType()),
            "params": ("tripduration", "starttime", "stoptime"),
        },
    }

    for name, obj in multi_udfs.items():
        df = df.withColumn(name, obj["udf"](*obj["params"]))

    # Display sample of rows
    df.show(n=20)

    # Write results to GCS
    if "--dry-run" in sys.argv:
        print("Data will not be uploaded to GCS")
    else:
        # Set GCS temp location
        path = str(time.time())
        temp_path = "gs://" + BUCKET_NAME + "/" + path

        # Write dataframe to temp location to preserve the data in final location
        # This takes time, so final location should not be overwritten with partial data
        print("Uploading data to GCS...")
        (
            df.write
            # gzip the output file
            .options(codec="org.apache.hadoop.io.compress.GzipCodec")
            # write as csv
            .csv(temp_path)
        )

        # Get GCS bucket
        storage_client = storage.Client()
        source_bucket = storage_client.get_bucket(BUCKET_NAME)

        # Get all files in temp location
        blobs = list(source_bucket.list_blobs(prefix=path))

        # Copy files from temp location to the final location
        # This is much quicker than the original write to the temp location
        final_path = "clean_data/"
        for blob in blobs:
            # re.escape the temp prefix (str(time.time()) contains a "."),
            # and escape the ".csv.gz" dots, so they match literally rather
            # than acting as regex wildcards.
            file_match = re.match(
                re.escape(path) + r"/(part-\d*)[0-9a-zA-Z\-]*\.csv\.gz", blob.name
            )
            if file_match:
                new_blob = final_path + file_match[1] + ".csv.gz"
                source_bucket.copy_blob(blob, source_bucket, new_blob)

        # Delete the temp location
        for blob in blobs:
            blob.delete()

        print(
            "Data successfully uploaded to " + "gs://" + BUCKET_NAME + "/" + final_path
        )
0 commit comments