1 change: 1 addition & 0 deletions .github/workflows/python-ci.yml
@@ -21,6 +21,7 @@ jobs:
JVM_OPTS: -Xms2048M -Xmx2048M -Xss6M -XX:ReservedCodeCacheSize=256M -Dfile.encoding=UTF-8
SPARK_VERSION: ${{ matrix.spark-version }}
SCALA_VERSION: ${{ matrix.scala-version }}
PIP_REQUESTS_TIMEOUT: 100
steps:
- uses: actions/checkout@v4
- uses: actions/setup-java@v4
2 changes: 2 additions & 0 deletions .github/workflows/python-publish.yml
@@ -10,6 +10,8 @@ on:
jobs:
release:
runs-on: ubuntu-latest
env:
PIP_REQUESTS_TIMEOUT: 100
environment:
name: pypi
url: https://pypi.org/p/graphframes-py
107 changes: 71 additions & 36 deletions core/src/main/scala/org/graphframes/GraphFrame.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.functions.array
import org.apache.spark.sql.functions.broadcast
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.count
import org.apache.spark.sql.functions.countDistinct
import org.apache.spark.sql.functions.explode
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.functions.lit
@@ -570,6 +571,72 @@ class GraphFrame private (
}
}

/**
* Validates the consistency and integrity of a graph by performing checks on the vertices and
* edges.
*
* @return
* Unit; the method performs validation checks and throws an exception if validation fails.
* @throws InvalidGraphException
* if there are any inconsistencies in the graph, such as duplicate vertices or src/dst
* values in the edge DataFrame that are missing from the vertex DataFrame.
*/
def validate(): Unit =
validate(checkVertices = true, intermediateStorageLevel = StorageLevel.MEMORY_AND_DISK)

/**
* Validates the consistency and integrity of a graph by performing checks on the vertices and
* edges.
*
* @param checkVertices
* a flag indicating whether additional vertex consistency checks should be performed. If
* true, the method verifies that every src and dst value in the edge DataFrame is present
* in the vertex DataFrame. This check can be slow on big graphs.
* @param intermediateStorageLevel
* the storage level to be used when persisting intermediate DataFrame computations during the
* validation process.
* @return
* Unit; the method performs validation checks and throws an exception if validation fails.
* @throws InvalidGraphException
* if there are any inconsistencies in the graph, such as duplicate vertices or src/dst
* values in the edge DataFrame that are missing from the vertex DataFrame.
*/
def validate(checkVertices: Boolean, intermediateStorageLevel: StorageLevel): Unit = {
val persistedVertices = vertices.persist(intermediateStorageLevel)
val countDistinctVertices = persistedVertices.select(countDistinct(ID)).first().getLong(0)
val verticesCount = persistedVertices.count()
if (countDistinctVertices != verticesCount) {
throw new InvalidGraphException(
s"Graph contains ${verticesCount - countDistinctVertices} duplicate vertices.")
}
if (checkVertices) {
val verticesSetFromEdges = edges
.select(col(SRC).alias(ID))
.union(edges.select(col(DST).alias(ID)))
.distinct()
.persist(intermediateStorageLevel)
val countVerticesFromEdges = verticesSetFromEdges.count()
if (countVerticesFromEdges > countDistinctVertices) {
throw new InvalidGraphException(
s"Graph is inconsistent: edges has ${countVerticesFromEdges} " +
s"vertices, but vertices has ${countDistinctVertices} vertices.")
}

val combined = verticesSetFromEdges.join(persistedVertices, ID, "left_anti")
val countOfBadVertices = combined.count()
if (countOfBadVertices > 0) {
throw new InvalidGraphException(
"Vertices DataFrame does not contain all edge src/dst values. " +
s"Found ${countOfBadVertices} edge src/dst values missing from the vertices DataFrame.")
}
verticesSetFromEdges.unpersist()
}
// Unpersist outside the branch so the cached vertices are released even when checkVertices is false.
persistedVertices.unpersist()
()
}
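
A minimal usage sketch of the new API (it assumes a live SparkSession named `spark`; the toy DataFrames are illustrative, not part of this PR):

import org.apache.spark.storage.StorageLevel
import org.graphframes.GraphFrame

// The vertex DataFrame needs an "id" column; the edge DataFrame needs "src" and "dst".
val verticesDF = spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "attr")
val edgesDF = spark.createDataFrame(Seq((1L, 2L), (2L, 3L))).toDF("src", "dst")
val g = GraphFrame(verticesDF, edgesDF)

g.validate() // full check; throws InvalidGraphException on an inconsistent graph
// Cheaper variant for big graphs: skip the slow edge/vertex cross-check.
g.validate(checkVertices = false, intermediateStorageLevel = StorageLevel.MEMORY_ONLY)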

// ========= Motif finding (private) =========

/**
@@ -627,7 +694,6 @@ class GraphFrame private (
val withLongIds = vertices
.select(ID)
.repartition(col(ID))
- .distinct()
.sortWithinPartitions(ID)
.withColumn(LONG_ID, monotonically_increasing_id())
.persist(StorageLevel.MEMORY_AND_DISK)
@@ -656,25 +722,11 @@
col(DST).cast("long").as(LONG_DST),
col(ATTR))
} else {
- val threshold = broadcastThreshold
- val hubs: Set[Any] = degrees
- .filter(col("degree") >= threshold)
- .select(ID)
- .collect()
- .map(_.get(0))
- .toSet
- val indexedSourceEdges = GraphFrame.skewedJoin(
- packedEdges,
- indexedVertices.select(col(ID).as(SRC), col(LONG_ID).as(LONG_SRC)),
- SRC,
- hubs,
- "GraphFrame.indexedEdges:")
- val indexedEdges = GraphFrame.skewedJoin(
- indexedSourceEdges,
val indexedSourceEdges =
packedEdges.join(indexedVertices.select(col(ID).as(SRC), col(LONG_ID).as(LONG_SRC)), SRC)
val indexedEdges = indexedSourceEdges.join(
indexedVertices.select(col(ID).as(DST), col(LONG_ID).as(LONG_DST)),
- DST,
- hubs,
- "GraphFrame.indexedEdges:")
DST)
indexedEdges.select(SRC, LONG_SRC, DST, LONG_DST, ATTR)
}
}
@@ -1141,21 +1193,4 @@ object GraphFrame extends Serializable with Logging {
}
}
}

- /**
- * Controls broadcast threshold in skewed joins. Use normal joins for vertices with degrees less
- * than the threshold, and broadcast joins otherwise. The default value is 1000000. If we have
- * less than 100 billion edges, this would collect at most 2e11 / 1000000 = 200000 hubs, which
- * could be handled by the driver.
- */
- private[this] var _broadcastThreshold: Int = 1000000
-
- private[graphframes] def broadcastThreshold: Int = _broadcastThreshold
-
- // for unit testing only
- private[graphframes] def setBroadcastThreshold(value: Int): this.type = {
- require(value >= 0)
- _broadcastThreshold = value
- this
- }
}
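
For reference, a sketch of the skewed-join pattern that the deleted broadcastThreshold machinery supported; the shape is inferred from the removed doc comment, and the helper name and signature below are illustrative, not the removed implementation:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{broadcast, col}

// Rows whose join key is a high-degree "hub" take a broadcast join (the hub
// slice of the right side is small); all other rows take a regular shuffle join.
def skewedJoinSketch(left: DataFrame, right: DataFrame, key: String, hubs: Set[Any]): DataFrame = {
  val isHub = col(key).isin(hubs.toSeq: _*)
  val hubPart = left.filter(isHub).join(broadcast(right.filter(isHub)), key)
  val restPart = left.filter(!isHub).join(right.filter(!isHub), key)
  hubPart.unionByName(restPart)
}

With this PR the indexing path uses plain joins instead, so the threshold knob and its dedicated test become dead code and are removed.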
9 changes: 9 additions & 0 deletions core/src/main/scala/org/graphframes/exceptions.scala
@@ -34,3 +34,12 @@ class GraphFramesUnreachableException()
* A detailed error message describing the issue.
*/
class InvalidPropertyGroupException(message: String) extends Exception(message)

/**
* Exception thrown when the graph is invalid, e.g. it contains duplicate vertices or the
* vertex set is inconsistent with the edge src/dst values.
*
* @param message
* A descriptive error message providing details about why the graph is invalid.
*/
class InvalidGraphException(message: String) extends Exception(message)
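
A sketch of handling the new exception at a call site (the repair strategy is illustrative only):

import org.graphframes.{GraphFrame, InvalidGraphException}

def validateOrRepair(g: GraphFrame): GraphFrame =
  try {
    g.validate()
    g
  } catch {
    case _: InvalidGraphException =>
      // One possible repair: drop duplicate vertex ids, then re-validate
      // (this still throws if edges reference vertices that do not exist).
      val repaired = GraphFrame(g.vertices.dropDuplicates("id"), g.edges)
      repaired.validate()
      repaired
  }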
64 changes: 17 additions & 47 deletions core/src/test/scala/org/graphframes/GraphFrameSuite.scala
@@ -63,6 +63,23 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext {
super.afterAll()
}

test("test validate") {
val goodG = GraphFrame(
spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "attr"),
spark.createDataFrame(Seq((1L, 2L), (2L, 1L), (2L, 3L))).toDF("src", "dst"))
goodG.validate() // no exception should be thrown

val notDistinctVertices = GraphFrame(
spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"), (1L, "d"))).toDF("id", "attr"),
spark.createDataFrame(Seq((1L, 2L), (2L, 1L), (2L, 3L))).toDF("src", "dst"))
assertThrows[InvalidGraphException](notDistinctVertices.validate())

val missingVertices = GraphFrame(
spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "attr"),
spark.createDataFrame(Seq((1L, 2L), (2L, 1L), (2L, 3L), (1L, 4L))).toDF("src", "dst"))
assertThrows[InvalidGraphException](missingVertices.validate())
}

test("construction from DataFrames") {
val g = GraphFrame(vertices, edges)
g.vertices.collect().foreach {
@@ -364,53 +381,6 @@
nullable = false))))
}

test("skewed long ID assignments") {
val spark = this.spark
import spark.implicits._
val n = 5L
// union a star graph and a chain graph and cast integral IDs to strings
val star = Graphs.star(n)
val chain = Graphs.chain(n + 1)
val vertices = star.vertices.select(col(ID).cast("string").as(ID))
val edges =
star.edges
.select(col(SRC).cast("string").as(SRC), col(DST).cast("string").as(DST))
.unionAll(
chain.edges.select(col(SRC).cast("string").as(SRC), col(DST).cast("string").as(DST)))

val localVertices = vertices.select(ID).as[String].collect().toSet
val localEdges = edges.select(SRC, DST).as[(String, String)].collect().toSet

val defaultThreshold = GraphFrame.broadcastThreshold
assert(
defaultThreshold === 1000000,
s"Default broadcast threshold should be 1000000 but got $defaultThreshold.")

for (threshold <- Seq(0, 4, 10)) {
GraphFrame.setBroadcastThreshold(threshold)

val g = GraphFrame(vertices, edges)
g.persist(StorageLevel.MEMORY_AND_DISK)

val indexedVertices =
g.indexedVertices.select(ID, LONG_ID).as[(String, Long)].collect().toMap
assert(indexedVertices.keySet === localVertices)
assert(indexedVertices.values.toSeq.distinct.size === localVertices.size)
val origEdges = g.indexedEdges.select(SRC, DST).as[(String, String)].collect().toSet
assert(origEdges === localEdges)
g.indexedEdges
.select(SRC, LONG_SRC, DST, LONG_DST)
.as[(String, Long, String, Long)]
.collect()
.foreach { case (src, longSrc, dst, longDst) =>
assert(indexedVertices(src) === longSrc)
assert(indexedVertices(dst) === longDst)
}
}

GraphFrame.setBroadcastThreshold(defaultThreshold)
}

test("power iteration clustering wrapper") {
val spark = this.spark
import spark.implicits._