Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions src/main/scala/org/graphframes/lib/ConnectedComponents.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.graphframes.GraphFrame
import org.graphframes.Logging
import org.graphframes.WithAlgorithmChoice
import org.graphframes.WithCheckpointInterval
import org.graphframes.WithMaxIter

import java.io.IOException
import java.math.BigDecimal
Expand All @@ -45,7 +46,8 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
extends Arguments
with Logging
with WithAlgorithmChoice
with WithCheckpointInterval {
with WithCheckpointInterval
with WithMaxIter {

private var broadcastThreshold: Int = 1000000
setAlgorithm(ALGO_GRAPHFRAMES)
Expand Down Expand Up @@ -105,7 +107,8 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
runInGraphX = algorithm == ALGO_GRAPHX,
broadcastThreshold = broadcastThreshold,
checkpointInterval = checkpointInterval,
intermediateStorageLevel = intermediateStorageLevel)
intermediateStorageLevel = intermediateStorageLevel,
maxIter = maxIter)
}
}

Expand Down Expand Up @@ -205,9 +208,9 @@ object ConnectedComponents extends Logging {
new ConnectedComponents(graph).run()
}

private def runGraphX(graph: GraphFrame): DataFrame = {
private def runGraphX(graph: GraphFrame, maxIter: Int): DataFrame = {
val components =
org.apache.spark.graphx.lib.ConnectedComponents.run(graph.cachedTopologyGraphX)
org.apache.spark.graphx.lib.ConnectedComponents.run(graph.cachedTopologyGraphX, maxIter)
GraphXConversions.fromGraphX(graph, components, vertexNames = Seq(COMPONENT)).vertices
}

Expand All @@ -216,9 +219,10 @@ object ConnectedComponents extends Logging {
runInGraphX: Boolean,
broadcastThreshold: Int,
checkpointInterval: Int,
intermediateStorageLevel: StorageLevel): DataFrame = {
intermediateStorageLevel: StorageLevel,
maxIter: Option[Int]): DataFrame = {
if (runInGraphX) {
return runGraphX(graph)
return runGraphX(graph, maxIter.getOrElse(Int.MaxValue))
}

val spark = graph.spark
Expand Down
16 changes: 4 additions & 12 deletions src/main/scala/org/graphframes/lib/LabelPropagation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.graphframes.lib
import org.apache.spark.graphx.{lib => graphxlib}
import org.apache.spark.sql.DataFrame
import org.graphframes.GraphFrame
import org.graphframes.WithMaxIter

/**
* Run static Label Propagation for detecting communities in networks.
Expand All @@ -35,18 +36,9 @@ import org.graphframes.GraphFrame
* The resulting DataFrame contains all the original vertex information and one additional column:
* - label (`LongType`): label of community affiliation
*/
class LabelPropagation private[graphframes] (private val graph: GraphFrame) extends Arguments {

private var maxIter: Option[Int] = None

/**
* The max number of iterations of LPA to be performed. Because this is a static implementation,
* the algorithm will run for exactly this many iterations.
*/
def maxIter(value: Int): this.type = {
maxIter = Some(value)
this
}
class LabelPropagation private[graphframes] (private val graph: GraphFrame)
extends Arguments
with WithMaxIter {

def run(): DataFrame = {
LabelPropagation.run(graph, check(maxIter, "maxIter"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.graphframes.lib

import org.apache.spark.graphx.{lib => graphxlib}
import org.graphframes.GraphFrame
import org.graphframes.WithMaxIter

/**
* Parallel Personalized PageRank algorithm implementation.
Expand Down Expand Up @@ -52,10 +53,10 @@ import org.graphframes.GraphFrame
* - weight (`DoubleType`): the normalized weight of this edge after running PageRank
*/
class ParallelPersonalizedPageRank private[graphframes] (private val graph: GraphFrame)
extends Arguments {
extends Arguments
with WithMaxIter {

private var resetProb: Option[Double] = Some(0.15)
private var maxIter: Option[Int] = None
private var srcIds: Array[Any] = Array()

/** Source vertices for a Personalized Page Rank */
Expand All @@ -70,12 +71,6 @@ class ParallelPersonalizedPageRank private[graphframes] (private val graph: Grap
this
}

/** Number of iterations to run */
def maxIter(value: Int): this.type = {
this.maxIter = Some(value)
this
}

def run(): GraphFrame = {
require(maxIter != None, "Max number of iterations maxIter() must be provided")
require(srcIds.nonEmpty, "Source vertices Ids sourceIds() must be provided")
Expand Down
13 changes: 5 additions & 8 deletions src/main/scala/org/graphframes/lib/SVDPlusPlus.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
import org.graphframes.GraphFrame
import org.graphframes.GraphFramesUnreachableException
import org.graphframes.WithMaxIter

/**
* Implement SVD++ based on "Factorization Meets the Neighborhood: a Multifaceted Collaborative
Expand All @@ -39,9 +40,10 @@ import org.graphframes.GraphFramesUnreachableException
* Returns a DataFrame with vertex attributes containing the trained model. See the object
* (static) members for the names of the output columns.
*/
class SVDPlusPlus private[graphframes] (private val graph: GraphFrame) extends Arguments {
class SVDPlusPlus private[graphframes] (private val graph: GraphFrame)
extends Arguments
with WithMaxIter {
private var _rank: Int = 10
private var _maxIter: Int = 2
private var _minVal: Double = 0.0
private var _maxVal: Double = 5.0
private var _gamma1: Double = 0.007
Expand All @@ -56,11 +58,6 @@ class SVDPlusPlus private[graphframes] (private val graph: GraphFrame) extends A
this
}

def maxIter(value: Int): this.type = {
_maxIter = value
this
}

def minValue(value: Double): this.type = {
_minVal = value
this
Expand Down Expand Up @@ -94,7 +91,7 @@ class SVDPlusPlus private[graphframes] (private val graph: GraphFrame) extends A
def run(): DataFrame = {
val conf = new graphxlib.SVDPlusPlus.Conf(
rank = _rank,
maxIters = _maxIter,
maxIters = maxIter.getOrElse(2),
minVal = _minVal,
maxVal = _maxVal,
gamma1 = _gamma1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.graphframes.lib
import org.apache.spark.graphx.{lib => graphxlib}
import org.apache.spark.sql.DataFrame
import org.graphframes.GraphFrame
import org.graphframes.WithMaxIter

/**
* Compute the strongly connected component (SCC) of each vertex and return a DataFrame with each
Expand All @@ -29,14 +30,8 @@ import org.graphframes.GraphFrame
* - component (`LongType`): unique ID for this component
*/
class StronglyConnectedComponents private[graphframes] (private val graph: GraphFrame)
extends Arguments {

private var maxIter: Option[Int] = None

def maxIter(value: Int): this.type = {
maxIter = Some(value)
this
}
extends Arguments
with WithMaxIter {

def run(): DataFrame = {
StronglyConnectedComponents.run(graph, check(maxIter, "maxIter"))
Expand Down
12 changes: 12 additions & 0 deletions src/main/scala/org/graphframes/mixins.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,15 @@ private[graphframes] trait WithCheckpointInterval extends Logging {
*/
def getCheckpointInterval: Int = checkpointInterval
}

/**
 * Mix-in providing a configurable maximum-iteration bound for iterative graph algorithms.
 *
 * Implementations read the protected `maxIter` value; callers configure it through the
 * fluent `maxIter(value)` setter. The value remains `None` until explicitly set, and each
 * algorithm decides how to interpret an unset bound.
 */
private[graphframes] trait WithMaxIter {
  // Unset (None) until maxIter(value) is called; holds the configured iteration cap.
  protected var maxIter: Option[Int] = None

  /**
   * Sets the maximum number of iterations the algorithm is allowed to perform.
   *
   * @param value the iteration cap to apply
   * @return this instance, enabling chained builder-style configuration
   */
  def maxIter(value: Int): this.type = {
    maxIter = Some(value)
    this
  }
}