2 changes: 1 addition & 1 deletion .github/workflows/python-ci.yml
@@ -6,7 +6,7 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - spark-version: 3.5.0
+          - spark-version: 3.5.4
             scala-version: 2.12.18
             python-version: 3.9.19
     runs-on: ubuntu-22.04
4 changes: 2 additions & 2 deletions .github/workflows/scala-ci.yml
@@ -6,9 +6,9 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - spark-version: 3.5.0
+          - spark-version: 3.5.4
             scala-version: 2.13.8
-          - spark-version: 3.5.0
+          - spark-version: 3.5.4
             scala-version: 2.12.12
     runs-on: ubuntu-22.04
     env:
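Note (not part of the PR): a quick way to confirm which Spark and Scala versions the bumped matrix actually resolves to at runtime is a small check like the sketch below; the object name and the local-mode SparkSession setup are illustrative assumptions.

import org.apache.spark.sql.SparkSession

object VersionCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    // Expect "3.5.4" for the updated matrix entries, with a 2.12.x or 2.13.x Scala library.
    println(s"Spark ${spark.version}, Scala ${scala.util.Properties.versionNumberString}")
    spark.stop()
  }
}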
253 changes: 131 additions & 122 deletions src/main/scala/org/graphframes/lib/ConnectedComponents.scala
@@ -281,140 +281,149 @@ object ConnectedComponents extends Logging {
       return runGraphX(graph)
     }

-    val runId = UUID.randomUUID().toString.takeRight(8)
-    val logPrefix = s"[CC $runId]"
-    logInfo(s"$logPrefix Start connected components with run ID $runId.")
-
     val spark = graph.spark
     val sc = spark.sparkContext

-    val shouldCheckpoint = checkpointInterval > 0
-    val checkpointDir: Option[String] = if (shouldCheckpoint) {
-      val dir = sc.getCheckpointDir.map { d =>
-        new Path(d, s"$CHECKPOINT_NAME_PREFIX-$runId").toString
-      }.getOrElse {
-        throw new IOException(
-          "Checkpoint directory is not set. Please set it first using sc.setCheckpointDir().")
-      }
-      logInfo(s"$logPrefix Using $dir for checkpointing with interval $checkpointInterval.")
-      Some(dir)
-    } else {
-      logInfo(
-        s"$logPrefix Checkpointing is disabled because checkpointInterval=$checkpointInterval.")
-      None
-    }
-
-    logInfo(s"$logPrefix Preparing the graph for connected component computation ...")
-    val g = prepare(graph)
-    val vv = g.vertices
-    var ee = g.edges.persist(intermediateStorageLevel) // src < dst
-    val numEdges = ee.count()
-    logInfo(s"$logPrefix Found $numEdges edges after preparation.")
-
-    var converged = false
-    var iteration = 1
-
-    def _calcMinNbrSum(minNbrsDF: DataFrame): BigDecimal = {
-      // Taking the sum in DecimalType to preserve precision.
-      // We use 20 digits for long values and Spark SQL will add 10 digits for the sum.
-      // It should be able to handle 200 billion edges without overflow.
-      val (minNbrSum, cnt) = minNbrsDF.select(sum(col(MIN_NBR).cast(DecimalType(20, 0))), count("*")).rdd
-        .map { r =>
-          (r.getAs[BigDecimal](0), r.getLong(1))
-        }.first()
-      if (cnt != 0L && minNbrSum == null) {
-        throw new ArithmeticException(
-          s"""
-             |The total sum of edge src IDs is used to determine convergence during iterations.
-             |However, the total sum at iteration $iteration exceeded 30 digits (1e30),
-             |which should happen only if the graph contains more than 200 billion edges.
-             |If not, please file a bug report at https://github.com/graphframes/graphframes/issues.
-           """.stripMargin)
-      }
-      minNbrSum
-    }
-
-    // compute min neighbors (including self-min)
-    var minNbrs1: DataFrame = minNbrs(ee) // src >= min_nbr
-      .persist(intermediateStorageLevel)
-
-    var prevSum: BigDecimal = _calcMinNbrSum(minNbrs1)
-
-    var lastRoundPersistedDFs = Seq[DataFrame](ee, minNbrs1)
-    while (!converged) {
-      var currRoundPersistedDFs = Seq[DataFrame]()
-      // large-star step
-      // connect all strictly larger neighbors to the min neighbor (including self)
-      ee = skewedJoin(ee, minNbrs1, broadcastThreshold, logPrefix)
-        .select(col(DST).as(SRC), col(MIN_NBR).as(DST)) // src > dst
-        .distinct()
-        .persist(intermediateStorageLevel)
-      currRoundPersistedDFs = currRoundPersistedDFs :+ ee
-
-      // small-star step
-      // compute min neighbors (excluding self-min)
-      val minNbrs2 = ee.groupBy(col(SRC)).agg(min(col(DST)).as(MIN_NBR), count("*").as(CNT)) // src > min_nbr
-        .persist(intermediateStorageLevel)
-      currRoundPersistedDFs = currRoundPersistedDFs :+ minNbrs2
-
-      // connect all smaller neighbors to the min neighbor
-      ee = skewedJoin(ee, minNbrs2, broadcastThreshold, logPrefix)
-        .select(col(MIN_NBR).as(SRC), col(DST)) // src <= dst
-        .filter(col(SRC) =!= col(DST)) // src < dst
-      // connect self to the min neighbor
-      ee = ee.union(minNbrs2.select(col(MIN_NBR).as(SRC), col(SRC).as(DST))) // src < dst
-        .distinct()
-
-      // checkpointing
-      if (shouldCheckpoint && (iteration % checkpointInterval == 0)) {
-        // TODO: remove this after DataFrame.checkpoint is implemented
-        val out = s"${checkpointDir.get}/$iteration"
-        ee.write.parquet(out)
-        // may hit S3 eventually consistent issue
-        ee = spark.read.parquet(out)
-
-        // remove previous checkpoint
-        if (iteration > checkpointInterval) {
-          val path = new Path(s"${checkpointDir.get}/${iteration - checkpointInterval}")
-          path.getFileSystem(sc.hadoopConfiguration).delete(path, true)
-        }
-
-        System.gc() // hint Spark to clean shuffle directories
-      }
-
-      ee.persist(intermediateStorageLevel)
-      currRoundPersistedDFs = currRoundPersistedDFs :+ ee
-
-      minNbrs1 = minNbrs(ee) // src >= min_nbr
-        .persist(intermediateStorageLevel)
-      currRoundPersistedDFs = currRoundPersistedDFs :+ minNbrs1
-
-      // test convergence
-      val currSum = _calcMinNbrSum(minNbrs1)
-      logInfo(s"$logPrefix Sum of assigned components in iteration $iteration: $currSum.")
-      if (currSum == prevSum) {
-        // This also covers the case when cnt = 0 and currSum is null, which means no edges.
-        converged = true
-      } else {
-        prevSum = currSum
-      }
-
-      // materialize all persisted DataFrames in current round,
-      // then we can unpersist last round persisted DataFrames.
-      for (persisted_df <- currRoundPersistedDFs) {
-        persisted_df.count() // materialize it.
-      }
-      for (persisted_df <- lastRoundPersistedDFs) {
-        persisted_df.unpersist()
-      }
-      lastRoundPersistedDFs = currRoundPersistedDFs
-      iteration += 1
-    }
-
-    logInfo(s"$logPrefix Connected components converged in ${iteration - 1} iterations.")
-
-    logInfo(s"$logPrefix Join and return component assignments with original vertex IDs.")
-    vv.join(ee, vv(ID) === ee(DST), "left_outer")
-      .select(vv(ATTR), when(ee(SRC).isNull, vv(ID)).otherwise(ee(SRC)).as(COMPONENT))
-      .select(col(s"$ATTR.*"), col(COMPONENT))
+    // Store original AQE setting
+    val originalAQE = spark.conf.get("spark.sql.adaptive.enabled")
+
+    try {
+      spark.conf.set("spark.sql.adaptive.enabled", "false")
+
+      val runId = UUID.randomUUID().toString.takeRight(8)
+      val logPrefix = s"[CC $runId]"
+      logInfo(s"$logPrefix Start connected components with run ID $runId.")
+
+      val shouldCheckpoint = checkpointInterval > 0
+      val checkpointDir: Option[String] = if (shouldCheckpoint) {
+        val dir = sc.getCheckpointDir.map { d =>
+          new Path(d, s"$CHECKPOINT_NAME_PREFIX-$runId").toString
+        }.getOrElse {
+          throw new IOException(
+            "Checkpoint directory is not set. Please set it first using sc.setCheckpointDir().")
+        }
+        logInfo(s"$logPrefix Using $dir for checkpointing with interval $checkpointInterval.")
+        Some(dir)
+      } else {
+        logInfo(
+          s"$logPrefix Checkpointing is disabled because checkpointInterval=$checkpointInterval.")
+        None
+      }
+
+      logInfo(s"$logPrefix Preparing the graph for connected component computation ...")
+      val g = prepare(graph)
+      val vv = g.vertices
+      var ee = g.edges.persist(intermediateStorageLevel) // src < dst
+      val numEdges = ee.count()
+      logInfo(s"$logPrefix Found $numEdges edges after preparation.")
+
+      var converged = false
+      var iteration = 1
+
+      def _calcMinNbrSum(minNbrsDF: DataFrame): BigDecimal = {
+        // Taking the sum in DecimalType to preserve precision.
+        // We use 20 digits for long values and Spark SQL will add 10 digits for the sum.
+        // It should be able to handle 200 billion edges without overflow.
+        val (minNbrSum, cnt) = minNbrsDF.select(sum(col(MIN_NBR).cast(DecimalType(20, 0))), count("*")).rdd
+          .map { r =>
+            (r.getAs[BigDecimal](0), r.getLong(1))
+          }.first()
+        if (cnt != 0L && minNbrSum == null) {
+          throw new ArithmeticException(
+            s"""
+               |The total sum of edge src IDs is used to determine convergence during iterations.
+               |However, the total sum at iteration $iteration exceeded 30 digits (1e30),
+               |which should happen only if the graph contains more than 200 billion edges.
+               |If not, please file a bug report at https://github.com/graphframes/graphframes/issues.
+             """.stripMargin)
+        }
+        minNbrSum
+      }
+
+      // compute min neighbors (including self-min)
+      var minNbrs1: DataFrame = minNbrs(ee) // src >= min_nbr
+        .persist(intermediateStorageLevel)
+
+      var prevSum: BigDecimal = _calcMinNbrSum(minNbrs1)
+
+      var lastRoundPersistedDFs = Seq[DataFrame](ee, minNbrs1)
+      while (!converged) {
+        var currRoundPersistedDFs = Seq[DataFrame]()
+        // large-star step
+        // connect all strictly larger neighbors to the min neighbor (including self)
+        ee = skewedJoin(ee, minNbrs1, broadcastThreshold, logPrefix)
+          .select(col(DST).as(SRC), col(MIN_NBR).as(DST)) // src > dst
+          .distinct()
+          .persist(intermediateStorageLevel)
+        currRoundPersistedDFs = currRoundPersistedDFs :+ ee
+
+        // small-star step
+        // compute min neighbors (excluding self-min)
+        val minNbrs2 = ee.groupBy(col(SRC)).agg(min(col(DST)).as(MIN_NBR), count("*").as(CNT)) // src > min_nbr
+          .persist(intermediateStorageLevel)
+        currRoundPersistedDFs = currRoundPersistedDFs :+ minNbrs2
+
+        // connect all smaller neighbors to the min neighbor
+        ee = skewedJoin(ee, minNbrs2, broadcastThreshold, logPrefix)
+          .select(col(MIN_NBR).as(SRC), col(DST)) // src <= dst
+          .filter(col(SRC) =!= col(DST)) // src < dst
+        // connect self to the min neighbor
+        ee = ee.union(minNbrs2.select(col(MIN_NBR).as(SRC), col(SRC).as(DST))) // src < dst
+          .distinct()
+
+        // checkpointing
+        if (shouldCheckpoint && (iteration % checkpointInterval == 0)) {
+          // TODO: remove this after DataFrame.checkpoint is implemented
+          val out = s"${checkpointDir.get}/$iteration"
+          ee.write.parquet(out)
+          // may hit S3 eventually consistent issue
+          ee = spark.read.parquet(out)
+
+          // remove previous checkpoint
+          if (iteration > checkpointInterval) {
+            val path = new Path(s"${checkpointDir.get}/${iteration - checkpointInterval}")
+            path.getFileSystem(sc.hadoopConfiguration).delete(path, true)
+          }
+
+          System.gc() // hint Spark to clean shuffle directories
+        }
+
+        ee.persist(intermediateStorageLevel)
+        currRoundPersistedDFs = currRoundPersistedDFs :+ ee
+
+        minNbrs1 = minNbrs(ee) // src >= min_nbr
+          .persist(intermediateStorageLevel)
+        currRoundPersistedDFs = currRoundPersistedDFs :+ minNbrs1
+
+        // test convergence
+        val currSum = _calcMinNbrSum(minNbrs1)
+        logInfo(s"$logPrefix Sum of assigned components in iteration $iteration: $currSum.")
+        if (currSum == prevSum) {
+          // This also covers the case when cnt = 0 and currSum is null, which means no edges.
+          converged = true
+        } else {
+          prevSum = currSum
+        }
+
+        // materialize all persisted DataFrames in current round,
+        // then we can unpersist last round persisted DataFrames.
+        for (persisted_df <- currRoundPersistedDFs) {
+          persisted_df.count() // materialize it.
+        }
+        for (persisted_df <- lastRoundPersistedDFs) {
+          persisted_df.unpersist()
+        }
+        lastRoundPersistedDFs = currRoundPersistedDFs
+        iteration += 1
+      }
+
+      logInfo(s"$logPrefix Connected components converged in ${iteration - 1} iterations.")
+
+      logInfo(s"$logPrefix Join and return component assignments with original vertex IDs.")
+      vv.join(ee, vv(ID) === ee(DST), "left_outer")
+        .select(vv(ATTR), when(ee(SRC).isNull, vv(ID)).otherwise(ee(SRC)).as(COMPONENT))
+        .select(col(s"$ATTR.*"), col(COMPONENT))
+    } finally {
+      // Restore original AQE setting
+      spark.conf.set("spark.sql.adaptive.enabled", originalAQE)
+    }
   }
 }
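For readers skimming the diff above: the functional change is a save/disable/restore cycle around Spark's spark.sql.adaptive.enabled flag, with the existing algorithm body re-indented inside the new try block. A minimal standalone sketch of that pattern follows; the withAQEDisabled helper, the local SparkSession setup, and the /tmp checkpoint path are illustrative assumptions, not code from this PR.

import org.apache.spark.sql.SparkSession

object AqeToggleSketch {
  // Run `body` with adaptive query execution turned off, restoring the
  // caller's original setting even if `body` throws.
  def withAQEDisabled[T](spark: SparkSession)(body: => T): T = {
    val originalAQE = spark.conf.get("spark.sql.adaptive.enabled")
    try {
      spark.conf.set("spark.sql.adaptive.enabled", "false")
      body
    } finally {
      spark.conf.set("spark.sql.adaptive.enabled", originalAQE)
    }
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    // ConnectedComponents also needs a checkpoint directory when checkpointInterval > 0,
    // per the IOException message in the diff.
    spark.sparkContext.setCheckpointDir("/tmp/cc-checkpoints")

    val result = withAQEDisabled(spark) {
      // Stand-in for the iterative DataFrame computation wrapped by the PR.
      spark.range(10).count()
    }
    val restored = spark.conf.get("spark.sql.adaptive.enabled")
    println(s"result = $result, spark.sql.adaptive.enabled restored to $restored")
    spark.stop()
  }
}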