Commit bd6dd1a

Added a partition preserving flag to MapPartitionsWithSplitRDD.
1 parent f24bfd2 commit bd6dd1a
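
In short, this commit gives `mapPartitionsWithSplit` a `preservesPartitioning` flag (default `false`) that is passed through to `MapPartitionsWithSplitRDD`; when the flag is set, the resulting RDD keeps the parent RDD's `partitioner` instead of always dropping it. A minimal usage sketch against the Spark API of this era (the `pairs` RDD, the partition count, and the downstream remark are illustrative, not part of the commit):

    // Hypothetical example: a hash-partitioned pair RDD whose per-partition
    // function does not change any keys, so the partitioning still holds.
    val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("c", 3)))
      .partitionBy(new spark.HashPartitioner(4))

    // Tag each value with the index of the partition it came from; keys stay the same.
    val tagged = pairs.mapPartitionsWithSplit(
      (index, iter) => iter.map { case (k, v) => (k, (index, v)) },
      preservesPartitioning = true)

    // tagged.partitioner is now Some(HashPartitioner(4)) rather than None, so a
    // later co-partitioned join or reduceByKey can avoid an extra shuffle.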

File tree

2 files changed: +37 −33 lines changed


core/src/main/scala/spark/RDD.scala

Lines changed: 34 additions & 32 deletions
@@ -1,17 +1,17 @@
 package spark
 
 import java.io.EOFException
-import java.net.URL
 import java.io.ObjectInputStream
-import java.util.concurrent.atomic.AtomicLong
+import java.net.URL
 import java.util.Random
 import java.util.Date
 import java.util.{HashMap => JHashMap}
+import java.util.concurrent.atomic.AtomicLong
 
-import scala.collection.mutable.ArrayBuffer
 import scala.collection.Map
-import scala.collection.mutable.HashMap
 import scala.collection.JavaConversions.mapAsScalaMap
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
 
 import org.apache.hadoop.io.BytesWritable
 import org.apache.hadoop.io.NullWritable
@@ -47,7 +47,7 @@ import spark.storage.StorageLevel
 import SparkContext._
 
 /**
- * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable, 
+ * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
  * partitioned collection of elements that can be operated on in parallel. This class contains the
  * basic operations available on all RDDs, such as `map`, `filter`, and `persist`. In addition,
  * [[spark.PairRDDFunctions]] contains operations available only on RDDs of key-value pairs, such
@@ -86,28 +86,28 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
   @transient val dependencies: List[Dependency[_]]
 
   // Methods available on all RDDs:
-  
+
   /** Record user function generating this RDD. */
   private[spark] val origin = Utils.getSparkCallSite
-  
+
   /** Optionally overridden by subclasses to specify how they are partitioned. */
   val partitioner: Option[Partitioner] = None
 
   /** Optionally overridden by subclasses to specify placement preferences. */
   def preferredLocations(split: Split): Seq[String] = Nil
-  
+
   /** The [[spark.SparkContext]] that this RDD was created on. */
   def context = sc
 
   private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T]
-  
+
   /** A unique ID for this RDD (within its SparkContext). */
   val id = sc.newRddId()
-  
+
   // Variables relating to persistence
   private var storageLevel: StorageLevel = StorageLevel.NONE
-  
-  /** 
+
+  /**
   * Set this RDD's storage level to persist its values across operations after the first time
   * it is computed. Can only be called once on each RDD.
   */
@@ -123,32 +123,32 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
 
   /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
   def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY)
-  
+
   /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
   def cache(): RDD[T] = persist()
 
   /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
   def getStorageLevel = storageLevel
-  
+
   private[spark] def checkpoint(level: StorageLevel = StorageLevel.MEMORY_AND_DISK_2): RDD[T] = {
     if (!level.useDisk && level.replication < 2) {
       throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")")
-    } 
-    
+    }
+
     // This is a hack. Ideally this should re-use the code used by the CacheTracker
     // to generate the key.
     def getSplitKey(split: Split) = "rdd_%d_%d".format(this.id, split.index)
-    
+
     persist(level)
     sc.runJob(this, (iter: Iterator[T]) => {} )
-    
+
     val p = this.partitioner
-    
+
     new BlockRDD[T](sc, splits.map(getSplitKey).toArray) {
-      override val partitioner = p 
+      override val partitioner = p
     }
   }
-  
+
   /**
   * Internal method to this RDD; will read from cache if applicable, or otherwise compute it.
   * This should ''not'' be called by users directly, but is available for implementors of custom
@@ -161,9 +161,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
       compute(split)
     }
   }
-  
+
   // Transformations (return a new RDD)
-  
+
   /**
   * Return a new RDD by applying a function to all elements of this RDD.
   */
@@ -199,13 +199,13 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
     var multiplier = 3.0
     var initialCount = count()
     var maxSelected = 0
-    
+
     if (initialCount > Integer.MAX_VALUE - 1) {
       maxSelected = Integer.MAX_VALUE - 1
     } else {
       maxSelected = initialCount.toInt
     }
-    
+
     if (num > initialCount) {
       total = maxSelected
       fraction = math.min(multiplier * (maxSelected + 1) / initialCount, 1.0)
@@ -215,14 +215,14 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
       fraction = math.min(multiplier * (num + 1) / initialCount, 1.0)
       total = num
     }
-  
+
     val rand = new Random(seed)
     var samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
-  
+
     while (samples.length < total) {
       samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
     }
-  
+
     Utils.randomizeInPlace(samples, rand).take(total)
   }
 
@@ -290,8 +290,10 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
   * Return a new RDD by applying a function to each partition of this RDD, while tracking the index
   * of the original partition.
   */
-  def mapPartitionsWithSplit[U: ClassManifest](f: (Int, Iterator[T]) => Iterator[U]): RDD[U] =
-    new MapPartitionsWithSplitRDD(this, sc.clean(f))
+  def mapPartitionsWithSplit[U: ClassManifest](
+      f: (Int, Iterator[T]) => Iterator[U],
+      preservesPartitioning: Boolean = false): RDD[U] =
+    new MapPartitionsWithSplitRDD(this, sc.clean(f), preservesPartitioning)
 
   // Actions (launch a job to return a value to the user program)
 
@@ -342,7 +344,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
 
   /**
   * Aggregate the elements of each partition, and then the results for all the partitions, using a
-   * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to 
+   * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
   * modify t1 and return it as its result value to avoid object allocation; however, it should not
   * modify t2.
   */
@@ -443,7 +445,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
     val evaluator = new GroupedCountEvaluator[T](splits.size, confidence)
     sc.runApproximateJob(this, countPartition, evaluator, timeout)
   }
-  
+
   /**
   * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so
   * it will be slow if a lot of partitions are required. In that case, use collect() to get the
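
One detail worth calling out from the RDD.scala change above: `preservesPartitioning` defaults to `false`, so existing single-argument calls to `mapPartitionsWithSplit` keep compiling and behaving exactly as before; the flag is purely opt-in. A hedged sketch of an unchanged call site (the `lines` RDD is illustrative, not from the commit):

    // Pre-existing call style: still compiles after this commit, and the
    // result's partitioner is None, just as it was before.
    val perSplitCounts = lines.mapPartitionsWithSplit(
      (index, iter) => Iterator((index, iter.size)))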

core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala

Lines changed: 3 additions & 1 deletion
@@ -12,9 +12,11 @@ import spark.Split
 private[spark]
 class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
     prev: RDD[T],
-    f: (Int, Iterator[T]) => Iterator[U])
+    f: (Int, Iterator[T]) => Iterator[U],
+    preservesPartitioning: Boolean)
   extends RDD[U](prev.context) {
 
+  override val partitioner = if (preservesPartitioning) prev.partitioner else None
   override def splits = prev.splits
   override val dependencies = List(new OneToOneDependency(prev))
   override def compute(split: Split) = f(split.index, prev.iterator(split))
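
The added override is the whole mechanism: `if (preservesPartitioning) prev.partitioner else None`, where `None` matches the base class's default. A small sketch of the observable effect through the public API (the already-partitioned `keyed` pair RDD is assumed for illustration, not part of the commit):

    // A per-partition function that clearly leaves keys untouched.
    val identityFn = (index: Int, iter: Iterator[(String, Int)]) => iter

    // Flag set: the derived RDD reports the parent's partitioner.
    val kept = keyed.mapPartitionsWithSplit(identityFn, preservesPartitioning = true)
    // kept.partitioner == keyed.partitioner

    // Flag omitted (defaults to false): the partitioner is dropped, matching the
    // behaviour of mapPartitionsWithSplit before this commit.
    val dropped = keyed.mapPartitionsWithSplit(identityFn)
    // dropped.partitioner == None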
