diff --git a/client-spark/common/pom.xml b/client-spark/common/pom.xml
index 60e0d21eb42..b115d0cd4f8 100644
--- a/client-spark/common/pom.xml
+++ b/client-spark/common/pom.xml
@@ -75,6 +75,10 @@
spark-sql_${scala.binary.version}
provided
+
+ org.apache.commons
+ commons-crypto
+
org.mockito
mockito-core
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java
index 84d74f8c145..697f86776ba 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java
@@ -19,11 +19,17 @@
import java.util.Collections;
import java.util.Map;
+import java.util.Optional;
+
+import scala.Option;
import org.apache.spark.SparkConf;
+import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
+import org.apache.spark.internal.config.package$;
import org.apache.spark.memory.SparkOutOfMemoryError;
+import org.apache.celeborn.client.security.CryptoHandler;
import org.apache.celeborn.reflect.DynConstructors;
import org.apache.celeborn.reflect.DynMethods;
@@ -96,4 +102,19 @@ public static void throwSparkOutOfMemoryError() {
}
}
}
+
+ public static Optional getCryptoHandler(SparkConf conf) {
+ if (!(Boolean) conf.get(package$.MODULE$.IO_ENCRYPTION_ENABLED())) {
+ return Optional.empty();
+ }
+ SparkEnv env = SparkEnv.get();
+ if (env == null) {
+ return Optional.empty();
+ }
+ Option key = env.securityManager().getIOEncryptionKey();
+ if (!key.isDefined()) {
+ return Optional.empty();
+ }
+ return Optional.of(new SparkCryptoHandler(conf, key.get()));
+ }
}
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandler.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandler.java
new file mode 100644
index 00000000000..84af2b34b92
--- /dev/null
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandler.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.security.CryptoStreamUtils;
+
+import org.apache.celeborn.client.security.CryptoHandler;
+
+public class SparkCryptoHandler implements CryptoHandler {
+ private final SparkConf sparkConf;
+ private final byte[] key;
+
+ public SparkCryptoHandler(SparkConf sparkConf, byte[] key) {
+ this.sparkConf = sparkConf;
+ this.key = key;
+ }
+
+ @Override
+ public byte[] encrypt(byte[] input, int offset, int length) throws IOException {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ DataOutputStream dos = new DataOutputStream(baos);
+ dos.writeInt(length);
+ try (OutputStream cos = CryptoStreamUtils.createCryptoOutputStream(dos, sparkConf, key)) {
+ cos.write(input, offset, length);
+ }
+ return baos.toByteArray();
+ }
+
+ @Override
+ public byte[] decrypt(byte[] input, int offset, int length) throws IOException {
+ ByteArrayInputStream bais = new ByteArrayInputStream(input, offset, length);
+ DataInputStream dis = new DataInputStream(bais);
+ int decryptedLength = dis.readInt();
+ // The encrypted payload format is: [4-byte plaintext length][ciphertext...].
+ // So the maximum valid decrypted length is length - 4 (the ciphertext portion).
+ // A value outside this range indicates corruption or a wrong key.
+ if (decryptedLength < 0 || decryptedLength > length - 4) {
+ throw new IOException(
+ "Invalid decrypted length: " + decryptedLength + ", encrypted length: " + length);
+ }
+ try (DataInputStream cis =
+ new DataInputStream(CryptoStreamUtils.createCryptoInputStream(dis, sparkConf, key))) {
+ byte[] decrypted = new byte[decryptedLength];
+ cis.readFully(decrypted);
+ return decrypted;
+ }
+ }
+}
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandlerSuiteJ.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandlerSuiteJ.java
new file mode 100644
index 00000000000..6dd16c36ac1
--- /dev/null
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandlerSuiteJ.java
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import static org.junit.Assert.*;
+
+import java.io.IOException;
+import java.security.SecureRandom;
+import java.util.Arrays;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.internal.config.package$;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.celeborn.client.security.CryptoHandler;
+
+public class SparkCryptoHandlerSuiteJ {
+
+ private byte[] key;
+ private CryptoHandler handler;
+
+ @Before
+ public void setUp() {
+ key = new byte[16];
+ new SecureRandom().nextBytes(key);
+ SparkConf sparkConf = new SparkConf(false);
+ sparkConf.set(package$.MODULE$.IO_ENCRYPTION_ENABLED(), true);
+ handler = new SparkCryptoHandler(sparkConf, key);
+ }
+
+ @Test
+ public void testRoundTrip() throws IOException {
+ byte[] plaintext = "hello world, this is a test of encryption".getBytes();
+
+ byte[] encrypted = handler.encrypt(plaintext, 0, plaintext.length);
+ assertFalse(
+ "Encrypted output should differ from plaintext", Arrays.equals(plaintext, encrypted));
+
+ byte[] decrypted = handler.decrypt(encrypted, 0, encrypted.length);
+ assertArrayEquals(plaintext, decrypted);
+ }
+
+ @Test
+ public void testEncryptedDiffersFromPlaintext() throws IOException {
+ byte[] plaintext = "deterministic test data for comparison".getBytes();
+
+ byte[] encrypted = handler.encrypt(plaintext, 0, plaintext.length);
+ assertFalse(
+ "Encrypted output should differ from plaintext", Arrays.equals(plaintext, encrypted));
+ }
+
+ @Test
+ public void testSameDataEncryptsThenDecrypts() throws IOException {
+ byte[] plaintext = "same data encrypted twice".getBytes();
+
+ byte[] encrypted1 = handler.encrypt(plaintext, 0, plaintext.length);
+ byte[] encrypted2 = handler.encrypt(plaintext, 0, plaintext.length);
+
+ // Both should decrypt to the same plaintext
+ byte[] decrypted1 = handler.decrypt(encrypted1, 0, encrypted1.length);
+ byte[] decrypted2 = handler.decrypt(encrypted2, 0, encrypted2.length);
+
+ assertArrayEquals(plaintext, decrypted1);
+ assertArrayEquals(plaintext, decrypted2);
+ }
+
+ @Test
+ public void testEncryptWithOffset() throws IOException {
+ byte[] actual = "offset test data".getBytes();
+ int offset = 10;
+ byte[] padded = new byte[offset + actual.length + 20];
+ System.arraycopy(actual, 0, padded, offset, actual.length);
+
+ byte[] encrypted = handler.encrypt(padded, offset, actual.length);
+ byte[] decrypted = handler.decrypt(encrypted, 0, encrypted.length);
+
+ assertArrayEquals(actual, decrypted);
+ }
+
+ @Test
+ public void testDecryptWithWrongKeyFails() throws IOException {
+ byte[] plaintext = "secret data".getBytes();
+ byte[] encrypted = handler.encrypt(plaintext, 0, plaintext.length);
+
+ byte[] wrongKey = new byte[16];
+ new SecureRandom().nextBytes(wrongKey);
+ SparkConf sparkConf = new SparkConf(false);
+ sparkConf.set(package$.MODULE$.IO_ENCRYPTION_ENABLED(), true);
+ CryptoHandler wrongHandler = new SparkCryptoHandler(sparkConf, wrongKey);
+
+ byte[] decrypted = null;
+ try {
+ decrypted = wrongHandler.decrypt(encrypted, 0, encrypted.length);
+ } catch (IOException e) {
+ // acceptable — some implementations throw on wrong key
+ return;
+ }
+ // CryptoStreamUtils may return garbage instead of throwing
+ assertFalse(
+ "Decryption with wrong key should not produce original plaintext",
+ Arrays.equals(plaintext, decrypted));
+ }
+
+ @Test
+ public void testLargeData() throws IOException {
+ byte[] plaintext = new byte[64 * 1024]; // 64KB
+ new SecureRandom().nextBytes(plaintext);
+
+ byte[] encrypted = handler.encrypt(plaintext, 0, plaintext.length);
+ byte[] decrypted = handler.decrypt(encrypted, 0, encrypted.length);
+
+ assertArrayEquals(plaintext, decrypted);
+ }
+
+ @Test
+ public void testEmptyData() throws IOException {
+ byte[] encrypted = handler.encrypt(new byte[0], 0, 0);
+
+ byte[] decrypted = handler.decrypt(encrypted, 0, encrypted.length);
+ assertEquals(0, decrypted.length);
+ }
+}
diff --git a/client-spark/spark-2-shaded/pom.xml b/client-spark/spark-2-shaded/pom.xml
index 9db62b423e0..7c3e0e9b852 100644
--- a/client-spark/spark-2-shaded/pom.xml
+++ b/client-spark/spark-2-shaded/pom.xml
@@ -72,6 +72,7 @@
com.google.guava:failureaccess
io.netty:*
org.apache.commons:commons-lang3
+ org.apache.commons:commons-crypto
org.roaringbitmap:RoaringBitmap
commons-io:commons-io
diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index 02d48a4fbbc..6e929c215da 100644
--- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -19,6 +19,7 @@
import java.io.IOException;
import java.util.Objects;
+import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import scala.Int;
@@ -35,6 +36,7 @@
import org.apache.celeborn.client.LifecycleManager;
import org.apache.celeborn.client.ShuffleClient;
+import org.apache.celeborn.client.security.CryptoHandler;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.protocol.ShuffleMode;
import org.apache.celeborn.reflect.DynMethods;
@@ -65,6 +67,17 @@ public class SparkShuffleManager implements ShuffleManager {
private ExecutorShuffleIdTracker shuffleIdTracker = new ExecutorShuffleIdTracker();
+ // The IO encryption key is fixed for the app lifetime. Lazily initialized on first
+ // writer/reader call (not in the constructor) to ensure SparkEnv is available.
+ private volatile Optional cryptoHandler = null;
+
+ private Optional getCryptoHandler() {
+ if (cryptoHandler == null) {
+ cryptoHandler = SparkCommonUtils.getCryptoHandler(conf);
+ }
+ return cryptoHandler;
+ }
+
public SparkShuffleManager(SparkConf conf, boolean isDriver) {
SparkCommonUtils.validateAttemptConfig(conf);
this.conf = conf;
@@ -208,7 +221,8 @@ public ShuffleWriter getWriter(
h.lifecycleManagerPort(),
celebornConf,
h.userIdentifier(),
- h.extension());
+ h.extension(),
+ getCryptoHandler());
if (h.stageRerunEnabled()) {
SparkUtils.addFailureListenerIfBarrierTask(client, context, h);
}
@@ -260,7 +274,8 @@ public ShuffleReader getReader(
Int.MaxValue(),
context,
celebornConf,
- shuffleIdTracker);
+ shuffleIdTracker,
+ getCryptoHandler());
}
checkUserClassPathFirst(handle);
return _sortShuffleManager.getReader(handle, startPartition, endPartition, context);
diff --git a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index 192cc647ded..dc8908205bb 100644
--- a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++ b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -18,6 +18,7 @@
package org.apache.spark.shuffle.celeborn
import java.io.IOException
+import java.util.Optional
import java.util.concurrent.{ThreadPoolExecutor, TimeUnit}
import java.util.concurrent.atomic.AtomicReference
import java.util.function.BiFunction
@@ -32,6 +33,7 @@ import org.apache.spark.util.collection.ExternalSorter
import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.client.read.CelebornInputStream
import org.apache.celeborn.client.read.MetricsCallback
+import org.apache.celeborn.client.security.CryptoHandler
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.exception.{CelebornBroadcastException, CelebornIOException, CelebornRuntimeException, PartitionUnRetryAbleException}
import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
@@ -45,7 +47,8 @@ class CelebornShuffleReader[K, C](
endMapIndex: Int = Int.MaxValue,
context: TaskContext,
conf: CelebornConf,
- shuffleIdTracker: ExecutorShuffleIdTracker)
+ shuffleIdTracker: ExecutorShuffleIdTracker,
+ cryptoHandler: Optional[CryptoHandler] = Optional.empty())
extends ShuffleReader[K, C] with Logging {
private val dep = handle.dependency
@@ -55,7 +58,8 @@ class CelebornShuffleReader[K, C](
handle.lifecycleManagerPort,
conf,
handle.userIdentifier,
- handle.extension)
+ handle.extension,
+ cryptoHandler)
private val exceptionRef = new AtomicReference[IOException]
private val encodedAttemptId = SparkCommonUtils.getEncodedAttemptNumber(context)
diff --git a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala
index fd888fb9dc1..c32dfaf2eb9 100644
--- a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala
+++ b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala
@@ -17,12 +17,15 @@
package org.apache.spark.shuffle.celeborn
+import java.util.Optional
+
import org.apache.spark.{ShuffleDependency, TaskContext}
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.shuffle.ShuffleReadMetricsReporter
import org.apache.spark.sql.execution.UnsafeRowSerializer
import org.apache.spark.sql.execution.columnar.{CelebornBatchBuilder, CelebornColumnarBatchSerializer}
+import org.apache.celeborn.client.security.CryptoHandler
import org.apache.celeborn.common.CelebornConf
class CelebornColumnarShuffleReader[K, C](
@@ -34,7 +37,8 @@ class CelebornColumnarShuffleReader[K, C](
context: TaskContext,
conf: CelebornConf,
metrics: ShuffleReadMetricsReporter,
- shuffleIdTracker: ExecutorShuffleIdTracker)
+ shuffleIdTracker: ExecutorShuffleIdTracker,
+ cryptoHandler: Optional[CryptoHandler] = Optional.empty())
extends CelebornShuffleReader[K, C](
handle,
startPartition,
@@ -44,7 +48,8 @@ class CelebornColumnarShuffleReader[K, C](
context,
conf,
metrics,
- shuffleIdTracker) {
+ shuffleIdTracker,
+ cryptoHandler) {
override def newSerializerInstance(dep: ShuffleDependency[K, _, C]): SerializerInstance = {
val schema = CustomShuffleDependencyUtils.getSchema(dep)
diff --git a/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala b/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
index d0f4462be3e..edc67d98335 100644
--- a/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
+++ b/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.shuffle.celeborn
+import java.util.Optional
+
import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext}
import org.apache.spark.serializer.{KryoSerializer, KryoSerializerInstance}
import org.apache.spark.sql.execution.UnsafeRowSerializer
@@ -58,7 +60,8 @@ class CelebornColumnarShuffleReaderSuite {
taskContext,
new CelebornConf(),
null,
- new ExecutorShuffleIdTracker())
+ new ExecutorShuffleIdTracker(),
+ Optional.empty())
assert(shuffleReader.getClass == classOf[CelebornColumnarShuffleReader[Int, String]])
} finally {
if (shuffleClient != null) {
@@ -92,7 +95,8 @@ class CelebornColumnarShuffleReaderSuite {
taskContext,
new CelebornConf(),
null,
- new ExecutorShuffleIdTracker())
+ new ExecutorShuffleIdTracker(),
+ Optional.empty())
val shuffleDependency = Mockito.mock(classOf[ShuffleDependency[Int, String, String]])
Mockito.when(shuffleDependency.shuffleId).thenReturn(0)
Mockito.when(shuffleDependency.serializer).thenReturn(new KryoSerializer(
diff --git a/client-spark/spark-3-shaded/pom.xml b/client-spark/spark-3-shaded/pom.xml
index bc8c2065e2d..d1e0cf834af 100644
--- a/client-spark/spark-3-shaded/pom.xml
+++ b/client-spark/spark-3-shaded/pom.xml
@@ -76,6 +76,7 @@
com.google.guava:failureaccess
io.netty:*
org.apache.commons:commons-lang3
+ org.apache.commons:commons-crypto
org.roaringbitmap:RoaringBitmap
commons-io:commons-io
diff --git a/client-spark/spark-3.5-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala b/client-spark/spark-3.5-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala
index fd888fb9dc1..c32dfaf2eb9 100644
--- a/client-spark/spark-3.5-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala
+++ b/client-spark/spark-3.5-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala
@@ -17,12 +17,15 @@
package org.apache.spark.shuffle.celeborn
+import java.util.Optional
+
import org.apache.spark.{ShuffleDependency, TaskContext}
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.shuffle.ShuffleReadMetricsReporter
import org.apache.spark.sql.execution.UnsafeRowSerializer
import org.apache.spark.sql.execution.columnar.{CelebornBatchBuilder, CelebornColumnarBatchSerializer}
+import org.apache.celeborn.client.security.CryptoHandler
import org.apache.celeborn.common.CelebornConf
class CelebornColumnarShuffleReader[K, C](
@@ -34,7 +37,8 @@ class CelebornColumnarShuffleReader[K, C](
context: TaskContext,
conf: CelebornConf,
metrics: ShuffleReadMetricsReporter,
- shuffleIdTracker: ExecutorShuffleIdTracker)
+ shuffleIdTracker: ExecutorShuffleIdTracker,
+ cryptoHandler: Optional[CryptoHandler] = Optional.empty())
extends CelebornShuffleReader[K, C](
handle,
startPartition,
@@ -44,7 +48,8 @@ class CelebornColumnarShuffleReader[K, C](
context,
conf,
metrics,
- shuffleIdTracker) {
+ shuffleIdTracker,
+ cryptoHandler) {
override def newSerializerInstance(dep: ShuffleDependency[K, _, C]): SerializerInstance = {
val schema = CustomShuffleDependencyUtils.getSchema(dep)
diff --git a/client-spark/spark-3.5-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala b/client-spark/spark-3.5-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
index d0f4462be3e..edc67d98335 100644
--- a/client-spark/spark-3.5-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
+++ b/client-spark/spark-3.5-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.shuffle.celeborn
+import java.util.Optional
+
import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext}
import org.apache.spark.serializer.{KryoSerializer, KryoSerializerInstance}
import org.apache.spark.sql.execution.UnsafeRowSerializer
@@ -58,7 +60,8 @@ class CelebornColumnarShuffleReaderSuite {
taskContext,
new CelebornConf(),
null,
- new ExecutorShuffleIdTracker())
+ new ExecutorShuffleIdTracker(),
+ Optional.empty())
assert(shuffleReader.getClass == classOf[CelebornColumnarShuffleReader[Int, String]])
} finally {
if (shuffleClient != null) {
@@ -92,7 +95,8 @@ class CelebornColumnarShuffleReaderSuite {
taskContext,
new CelebornConf(),
null,
- new ExecutorShuffleIdTracker())
+ new ExecutorShuffleIdTracker(),
+ Optional.empty())
val shuffleDependency = Mockito.mock(classOf[ShuffleDependency[Int, String, String]])
Mockito.when(shuffleDependency.shuffleId).thenReturn(0)
Mockito.when(shuffleDependency.serializer).thenReturn(new KryoSerializer(
diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index ed7865e19ff..3e45ba2c8ec 100644
--- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -19,6 +19,7 @@
import java.io.IOException;
import java.util.Objects;
+import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.spark.*;
@@ -33,6 +34,7 @@
import org.apache.celeborn.client.LifecycleManager;
import org.apache.celeborn.client.ShuffleClient;
+import org.apache.celeborn.client.security.CryptoHandler;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.protocol.ShuffleMode;
import org.apache.celeborn.reflect.DynMethods;
@@ -91,6 +93,17 @@ public class SparkShuffleManager implements ShuffleManager {
private ExecutorShuffleIdTracker shuffleIdTracker = new ExecutorShuffleIdTracker();
+ // The IO encryption key is fixed for the app lifetime. Lazily initialized on first
+ // writer/reader call (not in the constructor) to ensure SparkEnv is available.
+ private volatile Optional cryptoHandler = null;
+
+ private Optional getCryptoHandler() {
+ if (cryptoHandler == null) {
+ cryptoHandler = SparkCommonUtils.getCryptoHandler(conf);
+ }
+ return cryptoHandler;
+ }
+
public SparkShuffleManager(SparkConf conf, boolean isDriver) {
if (conf.getBoolean(SQLConf.LOCAL_SHUFFLE_READER_ENABLED().key(), true)) {
logger.warn(
@@ -288,7 +301,8 @@ public ShuffleWriter getWriter(
h.lifecycleManagerPort(),
celebornConf,
h.userIdentifier(),
- h.extension());
+ h.extension(),
+ getCryptoHandler());
if (h.stageRerunEnabled()) {
SparkUtils.addFailureListenerIfBarrierTask(shuffleClient, context, h);
}
@@ -445,7 +459,8 @@ public ShuffleReader getCelebornShuffleReader(
context,
celebornConf,
metrics,
- shuffleIdTracker);
+ shuffleIdTracker,
+ getCryptoHandler());
} else {
return new CelebornShuffleReader<>(
h,
@@ -456,7 +471,8 @@ public ShuffleReader getCelebornShuffleReader(
context,
celebornConf,
metrics,
- shuffleIdTracker);
+ shuffleIdTracker,
+ getCryptoHandler());
}
}
diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
index 0e68a4b46f0..412d839434e 100644
--- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
+++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
@@ -23,6 +23,7 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
+import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
@@ -65,6 +66,7 @@
import org.slf4j.LoggerFactory;
import org.apache.celeborn.client.ShuffleClient;
+import org.apache.celeborn.client.security.CryptoHandler;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.exception.CelebornRuntimeException;
import org.apache.celeborn.common.network.protocol.TransportMessage;
@@ -280,7 +282,8 @@ private static class ColumnarShuffleReaderConstructorHolder {
TaskContext.class,
CelebornConf.class,
ShuffleReadMetricsReporter.class,
- ExecutorShuffleIdTracker.class)
+ ExecutorShuffleIdTracker.class,
+ Optional.class)
.build();
}
@@ -293,7 +296,8 @@ public static CelebornShuffleReader createColumnarShuffleReader(
TaskContext context,
CelebornConf conf,
ShuffleReadMetricsReporter metrics,
- ExecutorShuffleIdTracker shuffleIdTracker) {
+ ExecutorShuffleIdTracker shuffleIdTracker,
+ Optional cryptoHandler) {
return ColumnarShuffleReaderConstructorHolder.INSTANCE.invoke(
null,
handle,
@@ -304,7 +308,8 @@ public static CelebornShuffleReader createColumnarShuffleReader(
context,
conf,
metrics,
- shuffleIdTracker);
+ shuffleIdTracker,
+ cryptoHandler);
}
// Added in SPARK-32920, for Spark 3.2 and above
diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index 155fc088616..32863c6a540 100644
--- a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++ b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -18,7 +18,7 @@
package org.apache.spark.shuffle.celeborn
import java.io.IOException
-import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap, Set => JSet}
+import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap, Optional, Set => JSet}
import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeoutException, TimeUnit}
import java.util.concurrent.atomic.AtomicReference
import java.util.function.BiFunction
@@ -40,6 +40,7 @@ import org.apache.spark.util.collection.ExternalSorter
import org.apache.celeborn.client.{ClientUtils, ShuffleClient}
import org.apache.celeborn.client.ShuffleClientImpl.ReduceFileGroups
import org.apache.celeborn.client.read.{CelebornInputStream, MetricsCallback}
+import org.apache.celeborn.client.security.CryptoHandler
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.exception.{CelebornBroadcastException, CelebornIOException, CelebornRuntimeException, PartitionUnRetryAbleException}
import org.apache.celeborn.common.network.client.TransportClient
@@ -59,7 +60,8 @@ class CelebornShuffleReader[K, C](
conf: CelebornConf,
metrics: ShuffleReadMetricsReporter,
shuffleIdTracker: ExecutorShuffleIdTracker,
- needDecompress: Boolean)
+ needDecompress: Boolean,
+ cryptoHandler: Optional[CryptoHandler])
extends ShuffleReader[K, C] with Logging {
def this(
@@ -81,7 +83,31 @@ class CelebornShuffleReader[K, C](
conf,
metrics,
shuffleIdTracker,
- true)
+ true,
+ Optional.empty())
+
+ def this(
+ handle: CelebornShuffleHandle[K, _, C],
+ startPartition: Int,
+ endPartition: Int,
+ startMapIndex: Int,
+ endMapIndex: Int,
+ context: TaskContext,
+ conf: CelebornConf,
+ metrics: ShuffleReadMetricsReporter,
+ shuffleIdTracker: ExecutorShuffleIdTracker,
+ cryptoHandler: Optional[CryptoHandler]) = this(
+ handle,
+ startPartition,
+ endPartition,
+ startMapIndex,
+ endMapIndex,
+ context,
+ conf,
+ metrics,
+ shuffleIdTracker,
+ true,
+ cryptoHandler)
private val dep = handle.dependency
@@ -92,7 +118,8 @@ class CelebornShuffleReader[K, C](
handle.lifecycleManagerPort,
conf,
handle.userIdentifier,
- handle.extension)
+ handle.extension,
+ cryptoHandler)
private val exceptionRef = new AtomicReference[IOException]
private val stageRerunEnabled = handle.stageRerunEnabled
diff --git a/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReaderSuite.scala b/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReaderSuite.scala
index d2cec3abf88..901d7797b57 100644
--- a/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReaderSuite.scala
+++ b/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReaderSuite.scala
@@ -60,7 +60,7 @@ class CelebornShuffleReaderSuite extends AnyFunSuite {
val tmpFile = Files.createTempFile("test", ".tmp").toFile
mockStatic(classOf[ShuffleClient]).when(() =>
- ShuffleClient.get(any(), any(), any(), any(), any(), any())).thenReturn(
+ ShuffleClient.get(any(), any(), any(), any(), any(), any(), any())).thenReturn(
new DummyShuffleClient(conf, tmpFile))
val shuffleReader =
diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala
index fd888fb9dc1..c32dfaf2eb9 100644
--- a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala
+++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala
@@ -17,12 +17,15 @@
package org.apache.spark.shuffle.celeborn
+import java.util.Optional
+
import org.apache.spark.{ShuffleDependency, TaskContext}
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.shuffle.ShuffleReadMetricsReporter
import org.apache.spark.sql.execution.UnsafeRowSerializer
import org.apache.spark.sql.execution.columnar.{CelebornBatchBuilder, CelebornColumnarBatchSerializer}
+import org.apache.celeborn.client.security.CryptoHandler
import org.apache.celeborn.common.CelebornConf
class CelebornColumnarShuffleReader[K, C](
@@ -34,7 +37,8 @@ class CelebornColumnarShuffleReader[K, C](
context: TaskContext,
conf: CelebornConf,
metrics: ShuffleReadMetricsReporter,
- shuffleIdTracker: ExecutorShuffleIdTracker)
+ shuffleIdTracker: ExecutorShuffleIdTracker,
+ cryptoHandler: Optional[CryptoHandler] = Optional.empty())
extends CelebornShuffleReader[K, C](
handle,
startPartition,
@@ -44,7 +48,8 @@ class CelebornColumnarShuffleReader[K, C](
context,
conf,
metrics,
- shuffleIdTracker) {
+ shuffleIdTracker,
+ cryptoHandler) {
override def newSerializerInstance(dep: ShuffleDependency[K, _, C]): SerializerInstance = {
val schema = CustomShuffleDependencyUtils.getSchema(dep)
diff --git a/client-spark/spark-4-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala b/client-spark/spark-4-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
index d0f4462be3e..edc67d98335 100644
--- a/client-spark/spark-4-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
+++ b/client-spark/spark-4-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.shuffle.celeborn
+import java.util.Optional
+
import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext}
import org.apache.spark.serializer.{KryoSerializer, KryoSerializerInstance}
import org.apache.spark.sql.execution.UnsafeRowSerializer
@@ -58,7 +60,8 @@ class CelebornColumnarShuffleReaderSuite {
taskContext,
new CelebornConf(),
null,
- new ExecutorShuffleIdTracker())
+ new ExecutorShuffleIdTracker(),
+ Optional.empty())
assert(shuffleReader.getClass == classOf[CelebornColumnarShuffleReader[Int, String]])
} finally {
if (shuffleClient != null) {
@@ -92,7 +95,8 @@ class CelebornColumnarShuffleReaderSuite {
taskContext,
new CelebornConf(),
null,
- new ExecutorShuffleIdTracker())
+ new ExecutorShuffleIdTracker(),
+ Optional.empty())
val shuffleDependency = Mockito.mock(classOf[ShuffleDependency[Int, String, String]])
Mockito.when(shuffleDependency.shuffleId).thenReturn(0)
Mockito.when(shuffleDependency.serializer).thenReturn(new KryoSerializer(
diff --git a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java
index 8390a9c1b5a..86e6a2d191e 100644
--- a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java
@@ -30,6 +30,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
@@ -41,6 +42,7 @@
import org.apache.celeborn.client.read.CelebornInputStream;
import org.apache.celeborn.client.read.MetricsCallback;
+import org.apache.celeborn.client.security.CryptoHandler;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.exception.CelebornIOException;
import org.apache.celeborn.common.network.client.TransportClientFactory;
@@ -79,6 +81,9 @@ public DummyShuffleClient(CelebornConf conf, File file) throws Exception {
this.shuffleIntegrityCheckEnabled = conf.clientShuffleIntegrityCheckEnabled();
}
+ @Override
+ public void setupCryptoHandler(Optional cryptoHandler) {}
+
@Override
public void setupLifecycleManagerRef(String host, int port) {}
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
index 7035478ebf3..6324505e952 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
@@ -35,6 +35,7 @@
import org.apache.celeborn.client.read.CelebornInputStream;
import org.apache.celeborn.client.read.MetricsCallback;
+import org.apache.celeborn.client.security.CryptoHandler;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.exception.CelebornIOException;
import org.apache.celeborn.common.identity.UserIdentifier;
@@ -90,6 +91,18 @@ public static ShuffleClient get(
CelebornConf conf,
UserIdentifier userIdentifier,
byte[] extension) {
+ return ShuffleClient.get(
+ appUniqueId, driverHost, port, conf, userIdentifier, extension, Optional.empty());
+ }
+
+ public static ShuffleClient get(
+ String appUniqueId,
+ String driverHost,
+ int port,
+ CelebornConf conf,
+ UserIdentifier userIdentifier,
+ byte[] extension,
+ Optional cryptoHandler) {
if (null == _instance || !initialized) {
synchronized (ShuffleClient.class) {
if (null == _instance) {
@@ -102,12 +115,14 @@ public static ShuffleClient get(
_instance = new ShuffleClientImpl(appUniqueId, conf, userIdentifier);
_instance.setupLifecycleManagerRef(driverHost, port);
_instance.setExtension(extension);
+ _instance.setupCryptoHandler(cryptoHandler);
initialized = true;
} else if (!initialized) {
_instance.shutdown();
_instance = new ShuffleClientImpl(appUniqueId, conf, userIdentifier);
_instance.setupLifecycleManagerRef(driverHost, port);
_instance.setExtension(extension);
+ _instance.setupCryptoHandler(cryptoHandler);
initialized = true;
}
}
@@ -150,6 +165,8 @@ public static void printReadStats(Logger logger) {
String.format("%.2f", (localReadCount * 1.0d / totalReadCount) * 100));
}
+ public abstract void setupCryptoHandler(Optional cryptoHandler);
+
public abstract void setupLifecycleManagerRef(String host, int port);
public abstract void setupLifecycleManagerRef(RpcEndpointRef endpointRef);
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index 358bc227a16..a5970fa0654 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -43,6 +43,7 @@
import org.apache.celeborn.client.compress.Compressor;
import org.apache.celeborn.client.read.CelebornInputStream;
import org.apache.celeborn.client.read.MetricsCallback;
+import org.apache.celeborn.client.security.CryptoHandler;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.exception.CelebornBroadcastException;
import org.apache.celeborn.common.exception.CelebornIOException;
@@ -101,6 +102,8 @@ public class ShuffleClientImpl extends ShuffleClient {
protected byte[] extension;
+ private volatile Optional cryptoHandler = Optional.empty();
+
// key: appShuffleIdentifier, value: shuffleId
protected Map> shuffleIdCache = JavaUtils.newConcurrentHashMap();
@@ -1054,6 +1057,22 @@ public int pushOrMergeData(
length = compressor.getCompressedTotalSize();
}
+ // Snapshot volatile field once to avoid a TOCTOU race between isPresent() and get().
+ Optional handler = cryptoHandler;
+ if (handler.isPresent()) {
+ byte[] encrypted = handler.get().encrypt(data, offset, length);
+ logger.debug(
+ "Encrypted shuffle data for shuffle {} map {} partition {}: {} bytes -> {} bytes.",
+ shuffleId,
+ mapId,
+ partitionId,
+ length,
+ encrypted.length);
+ data = encrypted;
+ offset = 0;
+ length = encrypted.length;
+ }
+
final byte[] body = new byte[BATCH_HEADER_SIZE + length];
Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET, mapId);
Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 4, attemptId);
@@ -2039,7 +2058,8 @@ public CelebornInputStream readPartition(
partitionId,
exceptionMaker,
metricsCallback,
- needDecompress);
+ needDecompress,
+ cryptoHandler);
}
}
@@ -2104,6 +2124,14 @@ public void setExtension(byte[] extension) {
this.extension = extension;
}
+ @Override
+ public void setupCryptoHandler(Optional cryptoHandler) {
+ this.cryptoHandler = cryptoHandler != null ? cryptoHandler : Optional.empty();
+ if (this.cryptoHandler.isPresent()) {
+ logger.info("IO encryption enabled for shuffle data (encryption at rest).");
+ }
+ }
+
boolean mapperEnded(int shuffleId, int mapId) {
return (mapperEndMap.containsKey(shuffleId) && mapperEndMap.get(shuffleId).contains(mapId))
|| isStageEnded(shuffleId);
diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
index 37e0be3e375..bbdaf2056ee 100644
--- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
+++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
@@ -39,6 +39,7 @@
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.client.compress.Decompressor;
import org.apache.celeborn.client.read.checkpoint.PartitionReaderCheckpointMetadata;
+import org.apache.celeborn.client.security.CryptoHandler;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.CommitMetadata;
import org.apache.celeborn.common.exception.CelebornIOException;
@@ -74,7 +75,8 @@ public static CelebornInputStream create(
int partitionId,
ExceptionMaker exceptionMaker,
MetricsCallback metricsCallback,
- boolean needDecompress)
+ boolean needDecompress,
+ Optional cryptoHandler)
throws IOException {
if (locations == null || locations.isEmpty()) {
return emptyInputStream;
@@ -106,7 +108,8 @@ public static CelebornInputStream create(
metricsCallback,
needDecompress,
startMapIndex,
- endMapIndex);
+ endMapIndex,
+ cryptoHandler);
} else {
return new CelebornInputStreamImpl(
conf,
@@ -131,7 +134,8 @@ public static CelebornInputStream create(
metricsCallback,
needDecompress,
-1,
- -1);
+ -1,
+ cryptoHandler);
}
}
}
@@ -188,6 +192,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
private final Map failedBatches;
+ private byte[] encryptedBuf;
private byte[] compressedBuf;
private byte[] rawDataBuf;
private Decompressor decompressor;
@@ -223,6 +228,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
private int shuffleId;
private int partitionId;
private ExceptionMaker exceptionMaker;
+ private Optional cryptoHandler;
private boolean closed = false;
private boolean integrityChecked = false;
private final CommitMetadata aggregatedActualCommitMetadata = new CommitMetadata();
@@ -250,7 +256,8 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
MetricsCallback metricsCallback,
boolean needDecompress,
int numberOfSubPartitions,
- int currentIndexOfSubPartition)
+ int currentIndexOfSubPartition,
+ Optional cryptoHandler)
throws IOException {
this(
conf,
@@ -275,7 +282,8 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
metricsCallback,
needDecompress,
numberOfSubPartitions,
- currentIndexOfSubPartition);
+ currentIndexOfSubPartition,
+ cryptoHandler);
}
CelebornInputStreamImpl(
@@ -301,7 +309,8 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
MetricsCallback metricsCallback,
boolean needDecompress,
int numberOfSubPartitions,
- int currentIndexOfSubPartition)
+ int currentIndexOfSubPartition,
+ Optional cryptoHandler)
throws IOException {
this.conf = conf;
this.clientFactory = clientFactory;
@@ -337,6 +346,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
this.retryWaitMs = conf.networkIoRetryWaitMs(TransportModuleConstants.DATA_MODULE);
this.callback = metricsCallback;
this.exceptionMaker = exceptionMaker;
+ this.cryptoHandler = cryptoHandler;
this.partitionId = partitionId;
this.appShuffleId = appShuffleId;
this.shuffleId = shuffleId;
@@ -730,6 +740,7 @@ public synchronized void close() {
compressedBuf = null;
rawDataBuf = null;
+ encryptedBuf = null;
batchesRead = null;
locations = null;
attempts = null;
@@ -790,9 +801,14 @@ private boolean moveToNextChunk() throws IOException {
private void init() {
int bufferSize = conf.clientFetchBufferSize();
+ int headerLen = shouldDecompress ? Decompressor.getCompressionHeaderLength(conf) : 0;
+ if (cryptoHandler.isPresent()) {
+ // Size to match compressedBuf/rawDataBuf headroom; exact overhead depends on the
+ // CryptoHandler implementation (e.g. SparkCryptoHandler prepends a 4-byte length).
+ encryptedBuf = new byte[bufferSize + headerLen];
+ }
if (shouldDecompress) {
- int headerLen = Decompressor.getCompressionHeaderLength(conf);
bufferSize += headerLen;
compressedBuf = new byte[bufferSize];
decompressor = Decompressor.getDecompressor(conf);
@@ -823,65 +839,94 @@ private boolean fillBuffer() throws IOException {
int batchId = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 8);
int size = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 12);
- if (shouldDecompress) {
+ // encryptedSize is the on-wire byte count (used for metrics); size will be
+ // reassigned to the decrypted length after decryption.
+ int encryptedSize = size;
+
+ // Perform dedup/stale-attempt checks before decrypting to avoid paying the
+ // crypto cost for batches that will be discarded anyway.
+ if (attemptId != attempts[mapId]) {
+ currentChunk.skipBytes(size);
+ continue;
+ }
+ if (readSkewPartitionWithoutMapRange) {
+ LocationPushFailedBatches locationPushFailedBatches =
+ this.failedBatches.get(currentReader.getLocation().getUniqueId());
+ if (null != locationPushFailedBatches) {
+ if (locationPushFailedBatches.contains(mapId, attemptId, batchId)) {
+ logger.warn(
+ "Skip duplicated batch: mapId={}, attemptId={}, batchId={}",
+ mapId,
+ attemptId,
+ batchId);
+ currentChunk.skipBytes(size);
+ continue;
+ }
+ }
+ }
+ Set batchSet = batchesRead.computeIfAbsent(mapId, k -> new HashSet<>());
+ if (batchSet.contains(batchId)) {
+ callback.incDuplicateBytesRead(BATCH_HEADER_SIZE + encryptedSize);
+ logger.debug(
+ "Skip duplicated batch: mapId {}, attemptId {}, batchId {}.",
+ mapId,
+ attemptId,
+ batchId);
+ currentChunk.skipBytes(size);
+ continue;
+ }
+
+ // Batch is unique and from the correct attempt — now read and optionally decrypt.
+ if (cryptoHandler.isPresent()) {
+ if (size > encryptedBuf.length) {
+ encryptedBuf = new byte[size];
+ }
+ currentChunk.readBytes(encryptedBuf, 0, size);
+ byte[] decrypted = cryptoHandler.get().decrypt(encryptedBuf, 0, size);
+ if (logger.isDebugEnabled()) {
+ logger.debug(
+ "Decrypted shuffle data for shuffle {} partition {}: {} bytes -> {} bytes.",
+ shuffleId,
+ partitionId,
+ size,
+ decrypted.length);
+ }
+ size = decrypted.length;
+ if (shouldDecompress) {
+ compressedBuf = decrypted;
+ } else {
+ rawDataBuf = decrypted;
+ }
+ } else if (shouldDecompress) {
if (size > compressedBuf.length) {
compressedBuf = new byte[size];
}
-
currentChunk.readBytes(compressedBuf, 0, size);
} else {
if (size > rawDataBuf.length) {
rawDataBuf = new byte[size];
}
-
currentChunk.readBytes(rawDataBuf, 0, size);
}
- // de-duplicate
- if (attemptId == attempts[mapId]) {
- if (readSkewPartitionWithoutMapRange) {
- LocationPushFailedBatches locationPushFailedBatches =
- this.failedBatches.get(currentReader.getLocation().getUniqueId());
- if (null != locationPushFailedBatches) {
- if (locationPushFailedBatches.contains(mapId, attemptId, batchId)) {
- logger.warn(
- "Skip duplicated batch: mapId={}, attemptId={}, batchId={}",
- mapId,
- attemptId,
- batchId);
- continue;
- }
- }
- }
- Set batchSet = batchesRead.computeIfAbsent(mapId, k -> new HashSet<>());
- if (!batchSet.contains(batchId)) {
- batchSet.add(batchId);
- callback.incBytesRead(BATCH_HEADER_SIZE + size);
- if (shouldDecompress) {
- // decompress data
- int originalLength = decompressor.getOriginalLen(compressedBuf);
- if (rawDataBuf.length < originalLength) {
- rawDataBuf = new byte[originalLength];
- }
- limit = decompressor.decompress(compressedBuf, rawDataBuf, 0);
- } else {
- limit = size;
- }
- if (shuffleIntegrityCheckEnabled) {
- aggregatedActualCommitMetadata.addDataWithOffsetAndLength(rawDataBuf, 0, limit);
- }
- position = 0;
- hasData = true;
- break;
- } else {
- callback.incDuplicateBytesRead(BATCH_HEADER_SIZE + size);
- logger.debug(
- "Skip duplicated batch: mapId {}, attemptId {}, batchId {}.",
- mapId,
- attemptId,
- batchId);
+ batchSet.add(batchId);
+ callback.incBytesRead(BATCH_HEADER_SIZE + encryptedSize);
+ if (shouldDecompress) {
+ // decompress data
+ int originalLength = decompressor.getOriginalLen(compressedBuf);
+ if (rawDataBuf.length < originalLength) {
+ rawDataBuf = new byte[originalLength];
}
+ limit = decompressor.decompress(compressedBuf, rawDataBuf, 0);
+ } else {
+ limit = size;
+ }
+ if (shuffleIntegrityCheckEnabled) {
+ aggregatedActualCommitMetadata.addDataWithOffsetAndLength(rawDataBuf, 0, limit);
}
+ position = 0;
+ hasData = true;
+ break;
}
if (!hasData) {
diff --git a/client/src/main/java/org/apache/celeborn/client/security/CryptoHandler.java b/client/src/main/java/org/apache/celeborn/client/security/CryptoHandler.java
new file mode 100644
index 00000000000..38da34fa797
--- /dev/null
+++ b/client/src/main/java/org/apache/celeborn/client/security/CryptoHandler.java
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.client.security;
+
+import java.io.IOException;
+
+public interface CryptoHandler {
+ byte[] encrypt(byte[] input, int offset, int length) throws IOException;
+
+ byte[] decrypt(byte[] input, int offset, int length) throws IOException;
+}
diff --git a/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamCryptoRoundTripSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamCryptoRoundTripSuiteJ.java
new file mode 100644
index 00000000000..394898a7a24
--- /dev/null
+++ b/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamCryptoRoundTripSuiteJ.java
@@ -0,0 +1,298 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.client.read;
+
+import static org.junit.Assert.*;
+import static org.mockito.ArgumentMatchers.*;
+import static org.mockito.Mockito.*;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Optional;
+import java.util.concurrent.ConcurrentHashMap;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import org.junit.Test;
+
+import org.apache.celeborn.client.ShuffleClient;
+import org.apache.celeborn.client.compress.Compressor;
+import org.apache.celeborn.client.security.CryptoHandler;
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.network.buffer.NettyManagedBuffer;
+import org.apache.celeborn.common.network.client.ChunkReceivedCallback;
+import org.apache.celeborn.common.network.client.TransportClient;
+import org.apache.celeborn.common.network.client.TransportClientFactory;
+import org.apache.celeborn.common.network.protocol.TransportMessage;
+import org.apache.celeborn.common.protocol.MessageType;
+import org.apache.celeborn.common.protocol.PartitionLocation;
+import org.apache.celeborn.common.protocol.PbStreamHandler;
+import org.apache.celeborn.common.protocol.StorageInfo;
+import org.apache.celeborn.common.unsafe.Platform;
+
+/**
+ * Integration-style round-trip tests for EAR (Encryption At Rest) wiring in {@link
+ * CelebornInputStream}. These tests verify that the encrypt-on-write / decrypt-on-read path works
+ * end-to-end, including interactions with compression and the shuffle integrity check.
+ */
+public class CelebornInputStreamCryptoRoundTripSuiteJ {
+
+ private static final int BATCH_HEADER_SIZE = 16;
+ private static final String SHUFFLE_KEY = "app-1-1";
+
+ /**
+ * A minimal CryptoHandler for testing: the encrypted format is [4-byte plaintext length
+ * (int)][XOR-encrypted payload]. This matches the structural contract of SparkCryptoHandler so
+ * the bounds check (decryptedLength > length - 4) is also exercised.
+ */
+ static class XorCryptoHandler implements CryptoHandler {
+ private final byte key;
+
+ XorCryptoHandler(byte key) {
+ this.key = key;
+ }
+
+ @Override
+ public byte[] encrypt(byte[] input, int offset, int length) throws IOException {
+ // Prefix with 4-byte plaintext length, then XOR-encrypt the payload
+ byte[] out = new byte[4 + length];
+ Platform.putInt(out, Platform.BYTE_ARRAY_OFFSET, length);
+ for (int i = 0; i < length; i++) {
+ out[4 + i] = (byte) (input[offset + i] ^ key);
+ }
+ return out;
+ }
+
+ @Override
+ public byte[] decrypt(byte[] input, int offset, int length) throws IOException {
+ // Validate the buffer is large enough to hold the 4-byte length prefix
+ if (length < 4) {
+ throw new IOException("Encrypted buffer too short: " + length);
+ }
+ // Read the plaintext length from the 4-byte prefix
+ int decryptedLength = Platform.getInt(input, Platform.BYTE_ARRAY_OFFSET + offset);
+ // Validate bounds: the 4-byte prefix must fit inside the encrypted buffer
+ if (decryptedLength < 0 || decryptedLength > length - 4) {
+ throw new IOException(
+ "Invalid decrypted length: " + decryptedLength + ", encrypted length: " + length);
+ }
+ byte[] out = new byte[decryptedLength];
+ for (int i = 0; i < decryptedLength; i++) {
+ out[i] = (byte) (input[offset + 4 + i] ^ key);
+ }
+ return out;
+ }
+ }
+
+ /**
+ * Build a single batch ByteBuf as ShuffleClientImpl.pushOrMergeData does: optionally compress,
+ * optionally encrypt, then prepend the 16-byte batch header.
+ */
+ private ByteBuf buildBatch(
+ byte[] plaintext, boolean compress, CryptoHandler cryptoHandler, CelebornConf conf)
+ throws IOException {
+ byte[] data = plaintext;
+ int offset = 0;
+ int length = plaintext.length;
+
+ // Step 1: optionally compress (compress-then-encrypt ordering matches ShuffleClientImpl)
+ if (compress) {
+ Compressor compressor = Compressor.getCompressor(conf);
+ compressor.compress(data, offset, length);
+ data = compressor.getCompressedBuffer();
+ offset = 0;
+ length = compressor.getCompressedTotalSize();
+ }
+
+ // Step 2: optionally encrypt the (possibly compressed) payload
+ if (cryptoHandler != null) {
+ data = cryptoHandler.encrypt(data, offset, length);
+ offset = 0;
+ length = data.length;
+ }
+
+ // Step 3: prepend the 16-byte batch header [mapId|attemptId|batchId|payloadLen]
+ byte[] body = new byte[BATCH_HEADER_SIZE + length];
+ Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET, 0); // mapId
+ Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 4, 0); // attemptId
+ Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 8, 0); // batchId
+ Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 12, length); // payload length
+ System.arraycopy(data, offset, body, BATCH_HEADER_SIZE, length);
+ return Unpooled.wrappedBuffer(body);
+ }
+
+ /**
+ * Create a CelebornInputStream backed by a mock TransportClient that serves the given batchBuf as
+ * a single chunk.
+ */
+ private CelebornInputStream createStream(
+ ByteBuf batchBuf,
+ boolean needDecompress,
+ Optional cryptoHandler,
+ CelebornConf conf)
+ throws IOException, InterruptedException {
+ TransportClient client = mock(TransportClient.class);
+ PbStreamHandler pbHandler =
+ PbStreamHandler.newBuilder().setStreamId(1L).setNumChunks(1).build();
+ // Encode the stream handler into an RPC response that CelebornInputStream expects
+ ByteBuffer rpcResponse =
+ new TransportMessage(MessageType.STREAM_HANDLER, pbHandler.toByteArray()).toByteBuffer();
+ when(client.sendRpcSync(any(ByteBuffer.class), anyLong())).thenReturn(rpcResponse);
+ doNothing().when(client).sendRpc(any(ByteBuffer.class));
+ doAnswer(
+ invocation -> {
+ ChunkReceivedCallback cb = invocation.getArgument(3);
+ // Serve the pre-built batch buffer immediately as chunk 0; duplicate() shares
+ // the underlying data without incrementing the ref count, so the stream's
+ // single release correctly frees the buffer.
+ cb.onSuccess(0, new NettyManagedBuffer(batchBuf.duplicate()));
+ return null;
+ })
+ .when(client)
+ .fetchChunk(anyLong(), anyInt(), anyLong(), any(ChunkReceivedCallback.class));
+
+ TransportClientFactory clientFactory = mock(TransportClientFactory.class);
+ when(clientFactory.createClient(anyString(), anyInt())).thenReturn(client);
+
+ ShuffleClient shuffleClient = mock(ShuffleClient.class);
+
+ // PRIMARY location pointing to a single HDD partition
+ PartitionLocation location =
+ new PartitionLocation(
+ 0, 0, "host1", 9001, 9002, 9003, 9004, PartitionLocation.Mode.PRIMARY);
+ location.setStorageInfo(new StorageInfo(StorageInfo.Type.HDD, true, "/mnt/disk1"));
+
+ ArrayList locations = new ArrayList<>();
+ locations.add(location);
+ ArrayList handlers = new ArrayList<>();
+ handlers.add(PbStreamHandler.newBuilder().setStreamId(1L).setNumChunks(1).build());
+
+ return CelebornInputStream.create(
+ conf,
+ clientFactory,
+ SHUFFLE_KEY,
+ locations,
+ handlers,
+ new int[] {0},
+ new HashMap<>(),
+ new HashMap<>(),
+ 0,
+ 1L,
+ 0,
+ 100,
+ new ConcurrentHashMap<>(),
+ shuffleClient,
+ 1,
+ 1,
+ 0,
+ null,
+ new MetricsCallback() {
+ @Override
+ public void incBytesRead(long bytes) {}
+
+ @Override
+ public void incReadTime(long time) {}
+ },
+ needDecompress,
+ cryptoHandler);
+ }
+
+ private byte[] readAll(CelebornInputStream stream) throws IOException {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ byte[] buf = new byte[4096];
+ int n;
+ while ((n = stream.read(buf)) != -1) {
+ baos.write(buf, 0, n);
+ }
+ return baos.toByteArray();
+ }
+
+ @Test
+ public void testEncryptDecryptRoundTrip() throws IOException, InterruptedException {
+ byte[] plaintext = "hello, EAR round-trip without compression".getBytes();
+ CelebornConf conf = new CelebornConf();
+ XorCryptoHandler handler = new XorCryptoHandler((byte) 0x5A);
+
+ // Build an encrypted batch and read it back through CelebornInputStream
+ ByteBuf batchBuf = buildBatch(plaintext, false, handler, conf);
+ try (CelebornInputStream stream = createStream(batchBuf, false, Optional.of(handler), conf)) {
+ assertArrayEquals(plaintext, readAll(stream));
+ }
+ }
+
+ @Test
+ public void testNoEncryptionRoundTrip() throws IOException, InterruptedException {
+ byte[] plaintext = "unencrypted shuffle data sanity check".getBytes();
+ CelebornConf conf = new CelebornConf();
+
+ // Baseline: with no CryptoHandler the data flows through unchanged
+ ByteBuf batchBuf = buildBatch(plaintext, false, null, conf);
+ try (CelebornInputStream stream = createStream(batchBuf, false, Optional.empty(), conf)) {
+ assertArrayEquals(plaintext, readAll(stream));
+ }
+ }
+
+ @Test
+ public void testCompressThenEncryptRoundTrip() throws IOException, InterruptedException {
+ // Reproduce the compress-then-encrypt ordering used in ShuffleClientImpl.
+ byte[] plaintext = "shuffle data with compression and encryption enabled for EAR".getBytes();
+ CelebornConf conf = new CelebornConf();
+ // Use LZ4 (default)
+ conf.set(CelebornConf.SHUFFLE_COMPRESSION_CODEC().key(), "lz4");
+ XorCryptoHandler handler = new XorCryptoHandler((byte) 0x3C);
+
+ // Writer: LZ4-compress then XOR-encrypt; Reader: decrypt then decompress
+ ByteBuf batchBuf = buildBatch(plaintext, true, handler, conf);
+ try (CelebornInputStream stream = createStream(batchBuf, true, Optional.of(handler), conf)) {
+ assertArrayEquals(plaintext, readAll(stream));
+ }
+ }
+
+ @Test
+ public void testEncryptWithIntegrityCheckEnabled() throws IOException, InterruptedException {
+ // Verify that EAR + shuffle integrity check (celeborn.client.shuffle.integrityCheck.enabled)
+ // work together: the checksum is computed over plaintext, so decrypt-then-verify must hold.
+ byte[] plaintext = "integrity check should pass after decryption".getBytes();
+ CelebornConf conf = new CelebornConf();
+ conf.set(CelebornConf.CLIENT_SHUFFLE_INTEGRITY_CHECK_ENABLED().key(), "true");
+ XorCryptoHandler handler = new XorCryptoHandler((byte) 0x7F);
+
+ // The integrity metadata (checksum) is added by CelebornInputStream over the decrypted data
+ ByteBuf batchBuf = buildBatch(plaintext, false, handler, conf);
+ try (CelebornInputStream stream = createStream(batchBuf, false, Optional.of(handler), conf)) {
+ assertArrayEquals(plaintext, readAll(stream));
+ }
+ }
+
+ @Test
+ public void testLargePayloadEncryptDecrypt() throws IOException, InterruptedException {
+ // 128 KB payload exercises buffer-boundary handling in fillBuffer()
+ byte[] plaintext = new byte[128 * 1024];
+ for (int i = 0; i < plaintext.length; i++) plaintext[i] = (byte) (i % 251);
+ CelebornConf conf = new CelebornConf();
+ XorCryptoHandler handler = new XorCryptoHandler((byte) 0xAB);
+
+ ByteBuf batchBuf = buildBatch(plaintext, false, handler, conf);
+ try (CelebornInputStream stream = createStream(batchBuf, false, Optional.of(handler), conf)) {
+ assertArrayEquals(plaintext, readAll(stream));
+ }
+ }
+}
diff --git a/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamPeerFailoverTest.java b/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamPeerFailoverTest.java
index a7dd7db3178..456486b9e8c 100644
--- a/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamPeerFailoverTest.java
+++ b/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamPeerFailoverTest.java
@@ -32,6 +32,7 @@
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
+import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
@@ -39,6 +40,7 @@
import org.junit.Test;
import org.apache.celeborn.client.ShuffleClient;
+import org.apache.celeborn.client.security.CryptoHandler;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.exception.CelebornIOException;
import org.apache.celeborn.common.network.client.TransportClient;
@@ -174,7 +176,8 @@ public void testFailureWithoutPeer() throws Exception {
0,
null,
new TestMetricsCallback(),
- false);
+ false,
+ Optional.empty());
}
private void createInputStream(String primaryHost, String replicaHost) throws IOException {
@@ -209,7 +212,8 @@ private void createInputStream(String primaryHost, String replicaHost) throws IO
0,
null,
new TestMetricsCallback(),
- false);
+ false,
+ Optional.empty());
}
/**