diff --git a/client-spark/common/pom.xml b/client-spark/common/pom.xml index 60e0d21eb42..b115d0cd4f8 100644 --- a/client-spark/common/pom.xml +++ b/client-spark/common/pom.xml @@ -75,6 +75,10 @@ spark-sql_${scala.binary.version} provided + + org.apache.commons + commons-crypto + org.mockito mockito-core diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java index 84d74f8c145..697f86776ba 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java @@ -19,11 +19,17 @@ import java.util.Collections; import java.util.Map; +import java.util.Optional; + +import scala.Option; import org.apache.spark.SparkConf; +import org.apache.spark.SparkEnv; import org.apache.spark.TaskContext; +import org.apache.spark.internal.config.package$; import org.apache.spark.memory.SparkOutOfMemoryError; +import org.apache.celeborn.client.security.CryptoHandler; import org.apache.celeborn.reflect.DynConstructors; import org.apache.celeborn.reflect.DynMethods; @@ -96,4 +102,19 @@ public static void throwSparkOutOfMemoryError() { } } } + + public static Optional getCryptoHandler(SparkConf conf) { + if (!(Boolean) conf.get(package$.MODULE$.IO_ENCRYPTION_ENABLED())) { + return Optional.empty(); + } + SparkEnv env = SparkEnv.get(); + if (env == null) { + return Optional.empty(); + } + Option key = env.securityManager().getIOEncryptionKey(); + if (!key.isDefined()) { + return Optional.empty(); + } + return Optional.of(new SparkCryptoHandler(conf, key.get())); + } } diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandler.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandler.java new file mode 100644 index 00000000000..84af2b34b92 --- /dev/null +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandler.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.celeborn; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.OutputStream; + +import org.apache.spark.SparkConf; +import org.apache.spark.security.CryptoStreamUtils; + +import org.apache.celeborn.client.security.CryptoHandler; + +public class SparkCryptoHandler implements CryptoHandler { + private final SparkConf sparkConf; + private final byte[] key; + + public SparkCryptoHandler(SparkConf sparkConf, byte[] key) { + this.sparkConf = sparkConf; + this.key = key; + } + + @Override + public byte[] encrypt(byte[] input, int offset, int length) throws IOException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(baos); + dos.writeInt(length); + try (OutputStream cos = CryptoStreamUtils.createCryptoOutputStream(dos, sparkConf, key)) { + cos.write(input, offset, length); + } + return baos.toByteArray(); + } + + @Override + public byte[] decrypt(byte[] input, int offset, int length) throws IOException { + ByteArrayInputStream bais = new ByteArrayInputStream(input, offset, length); + DataInputStream dis = new DataInputStream(bais); + int decryptedLength = dis.readInt(); + // The encrypted payload format is: [4-byte plaintext length][ciphertext...]. + // So the maximum valid decrypted length is length - 4 (the ciphertext portion). + // A value outside this range indicates corruption or a wrong key. + if (decryptedLength < 0 || decryptedLength > length - 4) { + throw new IOException( + "Invalid decrypted length: " + decryptedLength + ", encrypted length: " + length); + } + try (DataInputStream cis = + new DataInputStream(CryptoStreamUtils.createCryptoInputStream(dis, sparkConf, key))) { + byte[] decrypted = new byte[decryptedLength]; + cis.readFully(decrypted); + return decrypted; + } + } +} diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandlerSuiteJ.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandlerSuiteJ.java new file mode 100644 index 00000000000..6dd16c36ac1 --- /dev/null +++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandlerSuiteJ.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.celeborn; + +import static org.junit.Assert.*; + +import java.io.IOException; +import java.security.SecureRandom; +import java.util.Arrays; + +import org.apache.spark.SparkConf; +import org.apache.spark.internal.config.package$; +import org.junit.Before; +import org.junit.Test; + +import org.apache.celeborn.client.security.CryptoHandler; + +public class SparkCryptoHandlerSuiteJ { + + private byte[] key; + private CryptoHandler handler; + + @Before + public void setUp() { + key = new byte[16]; + new SecureRandom().nextBytes(key); + SparkConf sparkConf = new SparkConf(false); + sparkConf.set(package$.MODULE$.IO_ENCRYPTION_ENABLED(), true); + handler = new SparkCryptoHandler(sparkConf, key); + } + + @Test + public void testRoundTrip() throws IOException { + byte[] plaintext = "hello world, this is a test of encryption".getBytes(); + + byte[] encrypted = handler.encrypt(plaintext, 0, plaintext.length); + assertFalse( + "Encrypted output should differ from plaintext", Arrays.equals(plaintext, encrypted)); + + byte[] decrypted = handler.decrypt(encrypted, 0, encrypted.length); + assertArrayEquals(plaintext, decrypted); + } + + @Test + public void testEncryptedDiffersFromPlaintext() throws IOException { + byte[] plaintext = "deterministic test data for comparison".getBytes(); + + byte[] encrypted = handler.encrypt(plaintext, 0, plaintext.length); + assertFalse( + "Encrypted output should differ from plaintext", Arrays.equals(plaintext, encrypted)); + } + + @Test + public void testSameDataEncryptsThenDecrypts() throws IOException { + byte[] plaintext = "same data encrypted twice".getBytes(); + + byte[] encrypted1 = handler.encrypt(plaintext, 0, plaintext.length); + byte[] encrypted2 = handler.encrypt(plaintext, 0, plaintext.length); + + // Both should decrypt to the same plaintext + byte[] decrypted1 = handler.decrypt(encrypted1, 0, encrypted1.length); + byte[] decrypted2 = handler.decrypt(encrypted2, 0, encrypted2.length); + + assertArrayEquals(plaintext, decrypted1); + assertArrayEquals(plaintext, decrypted2); + } + + @Test + public void testEncryptWithOffset() throws IOException { + byte[] actual = "offset test data".getBytes(); + int offset = 10; + byte[] padded = new byte[offset + actual.length + 20]; + System.arraycopy(actual, 0, padded, offset, actual.length); + + byte[] encrypted = handler.encrypt(padded, offset, actual.length); + byte[] decrypted = handler.decrypt(encrypted, 0, encrypted.length); + + assertArrayEquals(actual, decrypted); + } + + @Test + public void testDecryptWithWrongKeyFails() throws IOException { + byte[] plaintext = "secret data".getBytes(); + byte[] encrypted = handler.encrypt(plaintext, 0, plaintext.length); + + byte[] wrongKey = new byte[16]; + new SecureRandom().nextBytes(wrongKey); + SparkConf sparkConf = new SparkConf(false); + sparkConf.set(package$.MODULE$.IO_ENCRYPTION_ENABLED(), true); + CryptoHandler wrongHandler = new SparkCryptoHandler(sparkConf, wrongKey); + + byte[] decrypted = null; + try { + decrypted = wrongHandler.decrypt(encrypted, 0, encrypted.length); + } catch (IOException e) { + // acceptable — some implementations throw on wrong key + return; + } + // CryptoStreamUtils may return garbage instead of throwing + assertFalse( + "Decryption with wrong key should not produce original plaintext", + Arrays.equals(plaintext, decrypted)); + } + + @Test + public void testLargeData() throws IOException { + byte[] plaintext = new byte[64 * 1024]; // 64KB + new SecureRandom().nextBytes(plaintext); + + byte[] encrypted = handler.encrypt(plaintext, 0, plaintext.length); + byte[] decrypted = handler.decrypt(encrypted, 0, encrypted.length); + + assertArrayEquals(plaintext, decrypted); + } + + @Test + public void testEmptyData() throws IOException { + byte[] encrypted = handler.encrypt(new byte[0], 0, 0); + + byte[] decrypted = handler.decrypt(encrypted, 0, encrypted.length); + assertEquals(0, decrypted.length); + } +} diff --git a/client-spark/spark-2-shaded/pom.xml b/client-spark/spark-2-shaded/pom.xml index 9db62b423e0..7c3e0e9b852 100644 --- a/client-spark/spark-2-shaded/pom.xml +++ b/client-spark/spark-2-shaded/pom.xml @@ -72,6 +72,7 @@ com.google.guava:failureaccess io.netty:* org.apache.commons:commons-lang3 + org.apache.commons:commons-crypto org.roaringbitmap:RoaringBitmap commons-io:commons-io diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java index 02d48a4fbbc..6e929c215da 100644 --- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java +++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.util.Objects; +import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import scala.Int; @@ -35,6 +36,7 @@ import org.apache.celeborn.client.LifecycleManager; import org.apache.celeborn.client.ShuffleClient; +import org.apache.celeborn.client.security.CryptoHandler; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.protocol.ShuffleMode; import org.apache.celeborn.reflect.DynMethods; @@ -65,6 +67,17 @@ public class SparkShuffleManager implements ShuffleManager { private ExecutorShuffleIdTracker shuffleIdTracker = new ExecutorShuffleIdTracker(); + // The IO encryption key is fixed for the app lifetime. Lazily initialized on first + // writer/reader call (not in the constructor) to ensure SparkEnv is available. + private volatile Optional cryptoHandler = null; + + private Optional getCryptoHandler() { + if (cryptoHandler == null) { + cryptoHandler = SparkCommonUtils.getCryptoHandler(conf); + } + return cryptoHandler; + } + public SparkShuffleManager(SparkConf conf, boolean isDriver) { SparkCommonUtils.validateAttemptConfig(conf); this.conf = conf; @@ -208,7 +221,8 @@ public ShuffleWriter getWriter( h.lifecycleManagerPort(), celebornConf, h.userIdentifier(), - h.extension()); + h.extension(), + getCryptoHandler()); if (h.stageRerunEnabled()) { SparkUtils.addFailureListenerIfBarrierTask(client, context, h); } @@ -260,7 +274,8 @@ public ShuffleReader getReader( Int.MaxValue(), context, celebornConf, - shuffleIdTracker); + shuffleIdTracker, + getCryptoHandler()); } checkUserClassPathFirst(handle); return _sortShuffleManager.getReader(handle, startPartition, endPartition, context); diff --git a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index 192cc647ded..dc8908205bb 100644 --- a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -18,6 +18,7 @@ package org.apache.spark.shuffle.celeborn import java.io.IOException +import java.util.Optional import java.util.concurrent.{ThreadPoolExecutor, TimeUnit} import java.util.concurrent.atomic.AtomicReference import java.util.function.BiFunction @@ -32,6 +33,7 @@ import org.apache.spark.util.collection.ExternalSorter import org.apache.celeborn.client.ShuffleClient import org.apache.celeborn.client.read.CelebornInputStream import org.apache.celeborn.client.read.MetricsCallback +import org.apache.celeborn.client.security.CryptoHandler import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.exception.{CelebornBroadcastException, CelebornIOException, CelebornRuntimeException, PartitionUnRetryAbleException} import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse @@ -45,7 +47,8 @@ class CelebornShuffleReader[K, C]( endMapIndex: Int = Int.MaxValue, context: TaskContext, conf: CelebornConf, - shuffleIdTracker: ExecutorShuffleIdTracker) + shuffleIdTracker: ExecutorShuffleIdTracker, + cryptoHandler: Optional[CryptoHandler] = Optional.empty()) extends ShuffleReader[K, C] with Logging { private val dep = handle.dependency @@ -55,7 +58,8 @@ class CelebornShuffleReader[K, C]( handle.lifecycleManagerPort, conf, handle.userIdentifier, - handle.extension) + handle.extension, + cryptoHandler) private val exceptionRef = new AtomicReference[IOException] private val encodedAttemptId = SparkCommonUtils.getEncodedAttemptNumber(context) diff --git a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala index fd888fb9dc1..c32dfaf2eb9 100644 --- a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala +++ b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala @@ -17,12 +17,15 @@ package org.apache.spark.shuffle.celeborn +import java.util.Optional + import org.apache.spark.{ShuffleDependency, TaskContext} import org.apache.spark.serializer.SerializerInstance import org.apache.spark.shuffle.ShuffleReadMetricsReporter import org.apache.spark.sql.execution.UnsafeRowSerializer import org.apache.spark.sql.execution.columnar.{CelebornBatchBuilder, CelebornColumnarBatchSerializer} +import org.apache.celeborn.client.security.CryptoHandler import org.apache.celeborn.common.CelebornConf class CelebornColumnarShuffleReader[K, C]( @@ -34,7 +37,8 @@ class CelebornColumnarShuffleReader[K, C]( context: TaskContext, conf: CelebornConf, metrics: ShuffleReadMetricsReporter, - shuffleIdTracker: ExecutorShuffleIdTracker) + shuffleIdTracker: ExecutorShuffleIdTracker, + cryptoHandler: Optional[CryptoHandler] = Optional.empty()) extends CelebornShuffleReader[K, C]( handle, startPartition, @@ -44,7 +48,8 @@ class CelebornColumnarShuffleReader[K, C]( context, conf, metrics, - shuffleIdTracker) { + shuffleIdTracker, + cryptoHandler) { override def newSerializerInstance(dep: ShuffleDependency[K, _, C]): SerializerInstance = { val schema = CustomShuffleDependencyUtils.getSchema(dep) diff --git a/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala b/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala index d0f4462be3e..edc67d98335 100644 --- a/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala +++ b/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.shuffle.celeborn +import java.util.Optional + import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext} import org.apache.spark.serializer.{KryoSerializer, KryoSerializerInstance} import org.apache.spark.sql.execution.UnsafeRowSerializer @@ -58,7 +60,8 @@ class CelebornColumnarShuffleReaderSuite { taskContext, new CelebornConf(), null, - new ExecutorShuffleIdTracker()) + new ExecutorShuffleIdTracker(), + Optional.empty()) assert(shuffleReader.getClass == classOf[CelebornColumnarShuffleReader[Int, String]]) } finally { if (shuffleClient != null) { @@ -92,7 +95,8 @@ class CelebornColumnarShuffleReaderSuite { taskContext, new CelebornConf(), null, - new ExecutorShuffleIdTracker()) + new ExecutorShuffleIdTracker(), + Optional.empty()) val shuffleDependency = Mockito.mock(classOf[ShuffleDependency[Int, String, String]]) Mockito.when(shuffleDependency.shuffleId).thenReturn(0) Mockito.when(shuffleDependency.serializer).thenReturn(new KryoSerializer( diff --git a/client-spark/spark-3-shaded/pom.xml b/client-spark/spark-3-shaded/pom.xml index bc8c2065e2d..d1e0cf834af 100644 --- a/client-spark/spark-3-shaded/pom.xml +++ b/client-spark/spark-3-shaded/pom.xml @@ -76,6 +76,7 @@ com.google.guava:failureaccess io.netty:* org.apache.commons:commons-lang3 + org.apache.commons:commons-crypto org.roaringbitmap:RoaringBitmap commons-io:commons-io diff --git a/client-spark/spark-3.5-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala b/client-spark/spark-3.5-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala index fd888fb9dc1..c32dfaf2eb9 100644 --- a/client-spark/spark-3.5-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala +++ b/client-spark/spark-3.5-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala @@ -17,12 +17,15 @@ package org.apache.spark.shuffle.celeborn +import java.util.Optional + import org.apache.spark.{ShuffleDependency, TaskContext} import org.apache.spark.serializer.SerializerInstance import org.apache.spark.shuffle.ShuffleReadMetricsReporter import org.apache.spark.sql.execution.UnsafeRowSerializer import org.apache.spark.sql.execution.columnar.{CelebornBatchBuilder, CelebornColumnarBatchSerializer} +import org.apache.celeborn.client.security.CryptoHandler import org.apache.celeborn.common.CelebornConf class CelebornColumnarShuffleReader[K, C]( @@ -34,7 +37,8 @@ class CelebornColumnarShuffleReader[K, C]( context: TaskContext, conf: CelebornConf, metrics: ShuffleReadMetricsReporter, - shuffleIdTracker: ExecutorShuffleIdTracker) + shuffleIdTracker: ExecutorShuffleIdTracker, + cryptoHandler: Optional[CryptoHandler] = Optional.empty()) extends CelebornShuffleReader[K, C]( handle, startPartition, @@ -44,7 +48,8 @@ class CelebornColumnarShuffleReader[K, C]( context, conf, metrics, - shuffleIdTracker) { + shuffleIdTracker, + cryptoHandler) { override def newSerializerInstance(dep: ShuffleDependency[K, _, C]): SerializerInstance = { val schema = CustomShuffleDependencyUtils.getSchema(dep) diff --git a/client-spark/spark-3.5-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala b/client-spark/spark-3.5-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala index d0f4462be3e..edc67d98335 100644 --- a/client-spark/spark-3.5-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala +++ b/client-spark/spark-3.5-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.shuffle.celeborn +import java.util.Optional + import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext} import org.apache.spark.serializer.{KryoSerializer, KryoSerializerInstance} import org.apache.spark.sql.execution.UnsafeRowSerializer @@ -58,7 +60,8 @@ class CelebornColumnarShuffleReaderSuite { taskContext, new CelebornConf(), null, - new ExecutorShuffleIdTracker()) + new ExecutorShuffleIdTracker(), + Optional.empty()) assert(shuffleReader.getClass == classOf[CelebornColumnarShuffleReader[Int, String]]) } finally { if (shuffleClient != null) { @@ -92,7 +95,8 @@ class CelebornColumnarShuffleReaderSuite { taskContext, new CelebornConf(), null, - new ExecutorShuffleIdTracker()) + new ExecutorShuffleIdTracker(), + Optional.empty()) val shuffleDependency = Mockito.mock(classOf[ShuffleDependency[Int, String, String]]) Mockito.when(shuffleDependency.shuffleId).thenReturn(0) Mockito.when(shuffleDependency.serializer).thenReturn(new KryoSerializer( diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java index ed7865e19ff..3e45ba2c8ec 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.util.Objects; +import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import org.apache.spark.*; @@ -33,6 +34,7 @@ import org.apache.celeborn.client.LifecycleManager; import org.apache.celeborn.client.ShuffleClient; +import org.apache.celeborn.client.security.CryptoHandler; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.protocol.ShuffleMode; import org.apache.celeborn.reflect.DynMethods; @@ -91,6 +93,17 @@ public class SparkShuffleManager implements ShuffleManager { private ExecutorShuffleIdTracker shuffleIdTracker = new ExecutorShuffleIdTracker(); + // The IO encryption key is fixed for the app lifetime. Lazily initialized on first + // writer/reader call (not in the constructor) to ensure SparkEnv is available. + private volatile Optional cryptoHandler = null; + + private Optional getCryptoHandler() { + if (cryptoHandler == null) { + cryptoHandler = SparkCommonUtils.getCryptoHandler(conf); + } + return cryptoHandler; + } + public SparkShuffleManager(SparkConf conf, boolean isDriver) { if (conf.getBoolean(SQLConf.LOCAL_SHUFFLE_READER_ENABLED().key(), true)) { logger.warn( @@ -288,7 +301,8 @@ public ShuffleWriter getWriter( h.lifecycleManagerPort(), celebornConf, h.userIdentifier(), - h.extension()); + h.extension(), + getCryptoHandler()); if (h.stageRerunEnabled()) { SparkUtils.addFailureListenerIfBarrierTask(shuffleClient, context, h); } @@ -445,7 +459,8 @@ public ShuffleReader getCelebornShuffleReader( context, celebornConf, metrics, - shuffleIdTracker); + shuffleIdTracker, + getCryptoHandler()); } else { return new CelebornShuffleReader<>( h, @@ -456,7 +471,8 @@ public ShuffleReader getCelebornShuffleReader( context, celebornConf, metrics, - shuffleIdTracker); + shuffleIdTracker, + getCryptoHandler()); } } diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java index 0e68a4b46f0..412d839434e 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java @@ -23,6 +23,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; @@ -65,6 +66,7 @@ import org.slf4j.LoggerFactory; import org.apache.celeborn.client.ShuffleClient; +import org.apache.celeborn.client.security.CryptoHandler; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.exception.CelebornRuntimeException; import org.apache.celeborn.common.network.protocol.TransportMessage; @@ -280,7 +282,8 @@ private static class ColumnarShuffleReaderConstructorHolder { TaskContext.class, CelebornConf.class, ShuffleReadMetricsReporter.class, - ExecutorShuffleIdTracker.class) + ExecutorShuffleIdTracker.class, + Optional.class) .build(); } @@ -293,7 +296,8 @@ public static CelebornShuffleReader createColumnarShuffleReader( TaskContext context, CelebornConf conf, ShuffleReadMetricsReporter metrics, - ExecutorShuffleIdTracker shuffleIdTracker) { + ExecutorShuffleIdTracker shuffleIdTracker, + Optional cryptoHandler) { return ColumnarShuffleReaderConstructorHolder.INSTANCE.invoke( null, handle, @@ -304,7 +308,8 @@ public static CelebornShuffleReader createColumnarShuffleReader( context, conf, metrics, - shuffleIdTracker); + shuffleIdTracker, + cryptoHandler); } // Added in SPARK-32920, for Spark 3.2 and above diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index 155fc088616..32863c6a540 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -18,7 +18,7 @@ package org.apache.spark.shuffle.celeborn import java.io.IOException -import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap, Set => JSet} +import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap, Optional, Set => JSet} import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeoutException, TimeUnit} import java.util.concurrent.atomic.AtomicReference import java.util.function.BiFunction @@ -40,6 +40,7 @@ import org.apache.spark.util.collection.ExternalSorter import org.apache.celeborn.client.{ClientUtils, ShuffleClient} import org.apache.celeborn.client.ShuffleClientImpl.ReduceFileGroups import org.apache.celeborn.client.read.{CelebornInputStream, MetricsCallback} +import org.apache.celeborn.client.security.CryptoHandler import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.exception.{CelebornBroadcastException, CelebornIOException, CelebornRuntimeException, PartitionUnRetryAbleException} import org.apache.celeborn.common.network.client.TransportClient @@ -59,7 +60,8 @@ class CelebornShuffleReader[K, C]( conf: CelebornConf, metrics: ShuffleReadMetricsReporter, shuffleIdTracker: ExecutorShuffleIdTracker, - needDecompress: Boolean) + needDecompress: Boolean, + cryptoHandler: Optional[CryptoHandler]) extends ShuffleReader[K, C] with Logging { def this( @@ -81,7 +83,31 @@ class CelebornShuffleReader[K, C]( conf, metrics, shuffleIdTracker, - true) + true, + Optional.empty()) + + def this( + handle: CelebornShuffleHandle[K, _, C], + startPartition: Int, + endPartition: Int, + startMapIndex: Int, + endMapIndex: Int, + context: TaskContext, + conf: CelebornConf, + metrics: ShuffleReadMetricsReporter, + shuffleIdTracker: ExecutorShuffleIdTracker, + cryptoHandler: Optional[CryptoHandler]) = this( + handle, + startPartition, + endPartition, + startMapIndex, + endMapIndex, + context, + conf, + metrics, + shuffleIdTracker, + true, + cryptoHandler) private val dep = handle.dependency @@ -92,7 +118,8 @@ class CelebornShuffleReader[K, C]( handle.lifecycleManagerPort, conf, handle.userIdentifier, - handle.extension) + handle.extension, + cryptoHandler) private val exceptionRef = new AtomicReference[IOException] private val stageRerunEnabled = handle.stageRerunEnabled diff --git a/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReaderSuite.scala b/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReaderSuite.scala index d2cec3abf88..901d7797b57 100644 --- a/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReaderSuite.scala +++ b/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReaderSuite.scala @@ -60,7 +60,7 @@ class CelebornShuffleReaderSuite extends AnyFunSuite { val tmpFile = Files.createTempFile("test", ".tmp").toFile mockStatic(classOf[ShuffleClient]).when(() => - ShuffleClient.get(any(), any(), any(), any(), any(), any())).thenReturn( + ShuffleClient.get(any(), any(), any(), any(), any(), any(), any())).thenReturn( new DummyShuffleClient(conf, tmpFile)) val shuffleReader = diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala index fd888fb9dc1..c32dfaf2eb9 100644 --- a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala +++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala @@ -17,12 +17,15 @@ package org.apache.spark.shuffle.celeborn +import java.util.Optional + import org.apache.spark.{ShuffleDependency, TaskContext} import org.apache.spark.serializer.SerializerInstance import org.apache.spark.shuffle.ShuffleReadMetricsReporter import org.apache.spark.sql.execution.UnsafeRowSerializer import org.apache.spark.sql.execution.columnar.{CelebornBatchBuilder, CelebornColumnarBatchSerializer} +import org.apache.celeborn.client.security.CryptoHandler import org.apache.celeborn.common.CelebornConf class CelebornColumnarShuffleReader[K, C]( @@ -34,7 +37,8 @@ class CelebornColumnarShuffleReader[K, C]( context: TaskContext, conf: CelebornConf, metrics: ShuffleReadMetricsReporter, - shuffleIdTracker: ExecutorShuffleIdTracker) + shuffleIdTracker: ExecutorShuffleIdTracker, + cryptoHandler: Optional[CryptoHandler] = Optional.empty()) extends CelebornShuffleReader[K, C]( handle, startPartition, @@ -44,7 +48,8 @@ class CelebornColumnarShuffleReader[K, C]( context, conf, metrics, - shuffleIdTracker) { + shuffleIdTracker, + cryptoHandler) { override def newSerializerInstance(dep: ShuffleDependency[K, _, C]): SerializerInstance = { val schema = CustomShuffleDependencyUtils.getSchema(dep) diff --git a/client-spark/spark-4-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala b/client-spark/spark-4-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala index d0f4462be3e..edc67d98335 100644 --- a/client-spark/spark-4-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala +++ b/client-spark/spark-4-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.shuffle.celeborn +import java.util.Optional + import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext} import org.apache.spark.serializer.{KryoSerializer, KryoSerializerInstance} import org.apache.spark.sql.execution.UnsafeRowSerializer @@ -58,7 +60,8 @@ class CelebornColumnarShuffleReaderSuite { taskContext, new CelebornConf(), null, - new ExecutorShuffleIdTracker()) + new ExecutorShuffleIdTracker(), + Optional.empty()) assert(shuffleReader.getClass == classOf[CelebornColumnarShuffleReader[Int, String]]) } finally { if (shuffleClient != null) { @@ -92,7 +95,8 @@ class CelebornColumnarShuffleReaderSuite { taskContext, new CelebornConf(), null, - new ExecutorShuffleIdTracker()) + new ExecutorShuffleIdTracker(), + Optional.empty()) val shuffleDependency = Mockito.mock(classOf[ShuffleDependency[Int, String, String]]) Mockito.when(shuffleDependency.shuffleId).thenReturn(0) Mockito.when(shuffleDependency.serializer).thenReturn(new KryoSerializer( diff --git a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java index 8390a9c1b5a..86e6a2d191e 100644 --- a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java @@ -30,6 +30,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; @@ -41,6 +42,7 @@ import org.apache.celeborn.client.read.CelebornInputStream; import org.apache.celeborn.client.read.MetricsCallback; +import org.apache.celeborn.client.security.CryptoHandler; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.exception.CelebornIOException; import org.apache.celeborn.common.network.client.TransportClientFactory; @@ -79,6 +81,9 @@ public DummyShuffleClient(CelebornConf conf, File file) throws Exception { this.shuffleIntegrityCheckEnabled = conf.clientShuffleIntegrityCheckEnabled(); } + @Override + public void setupCryptoHandler(Optional cryptoHandler) {} + @Override public void setupLifecycleManagerRef(String host, int port) {} diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java index 7035478ebf3..6324505e952 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java @@ -35,6 +35,7 @@ import org.apache.celeborn.client.read.CelebornInputStream; import org.apache.celeborn.client.read.MetricsCallback; +import org.apache.celeborn.client.security.CryptoHandler; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.exception.CelebornIOException; import org.apache.celeborn.common.identity.UserIdentifier; @@ -90,6 +91,18 @@ public static ShuffleClient get( CelebornConf conf, UserIdentifier userIdentifier, byte[] extension) { + return ShuffleClient.get( + appUniqueId, driverHost, port, conf, userIdentifier, extension, Optional.empty()); + } + + public static ShuffleClient get( + String appUniqueId, + String driverHost, + int port, + CelebornConf conf, + UserIdentifier userIdentifier, + byte[] extension, + Optional cryptoHandler) { if (null == _instance || !initialized) { synchronized (ShuffleClient.class) { if (null == _instance) { @@ -102,12 +115,14 @@ public static ShuffleClient get( _instance = new ShuffleClientImpl(appUniqueId, conf, userIdentifier); _instance.setupLifecycleManagerRef(driverHost, port); _instance.setExtension(extension); + _instance.setupCryptoHandler(cryptoHandler); initialized = true; } else if (!initialized) { _instance.shutdown(); _instance = new ShuffleClientImpl(appUniqueId, conf, userIdentifier); _instance.setupLifecycleManagerRef(driverHost, port); _instance.setExtension(extension); + _instance.setupCryptoHandler(cryptoHandler); initialized = true; } } @@ -150,6 +165,8 @@ public static void printReadStats(Logger logger) { String.format("%.2f", (localReadCount * 1.0d / totalReadCount) * 100)); } + public abstract void setupCryptoHandler(Optional cryptoHandler); + public abstract void setupLifecycleManagerRef(String host, int port); public abstract void setupLifecycleManagerRef(RpcEndpointRef endpointRef); diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index 358bc227a16..a5970fa0654 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -43,6 +43,7 @@ import org.apache.celeborn.client.compress.Compressor; import org.apache.celeborn.client.read.CelebornInputStream; import org.apache.celeborn.client.read.MetricsCallback; +import org.apache.celeborn.client.security.CryptoHandler; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.exception.CelebornBroadcastException; import org.apache.celeborn.common.exception.CelebornIOException; @@ -101,6 +102,8 @@ public class ShuffleClientImpl extends ShuffleClient { protected byte[] extension; + private volatile Optional cryptoHandler = Optional.empty(); + // key: appShuffleIdentifier, value: shuffleId protected Map> shuffleIdCache = JavaUtils.newConcurrentHashMap(); @@ -1054,6 +1057,22 @@ public int pushOrMergeData( length = compressor.getCompressedTotalSize(); } + // Snapshot volatile field once to avoid a TOCTOU race between isPresent() and get(). + Optional handler = cryptoHandler; + if (handler.isPresent()) { + byte[] encrypted = handler.get().encrypt(data, offset, length); + logger.debug( + "Encrypted shuffle data for shuffle {} map {} partition {}: {} bytes -> {} bytes.", + shuffleId, + mapId, + partitionId, + length, + encrypted.length); + data = encrypted; + offset = 0; + length = encrypted.length; + } + final byte[] body = new byte[BATCH_HEADER_SIZE + length]; Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET, mapId); Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 4, attemptId); @@ -2039,7 +2058,8 @@ public CelebornInputStream readPartition( partitionId, exceptionMaker, metricsCallback, - needDecompress); + needDecompress, + cryptoHandler); } } @@ -2104,6 +2124,14 @@ public void setExtension(byte[] extension) { this.extension = extension; } + @Override + public void setupCryptoHandler(Optional cryptoHandler) { + this.cryptoHandler = cryptoHandler != null ? cryptoHandler : Optional.empty(); + if (this.cryptoHandler.isPresent()) { + logger.info("IO encryption enabled for shuffle data (encryption at rest)."); + } + } + boolean mapperEnded(int shuffleId, int mapId) { return (mapperEndMap.containsKey(shuffleId) && mapperEndMap.get(shuffleId).contains(mapId)) || isStageEnded(shuffleId); diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index 37e0be3e375..bbdaf2056ee 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -39,6 +39,7 @@ import org.apache.celeborn.client.ShuffleClient; import org.apache.celeborn.client.compress.Decompressor; import org.apache.celeborn.client.read.checkpoint.PartitionReaderCheckpointMetadata; +import org.apache.celeborn.client.security.CryptoHandler; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.CommitMetadata; import org.apache.celeborn.common.exception.CelebornIOException; @@ -74,7 +75,8 @@ public static CelebornInputStream create( int partitionId, ExceptionMaker exceptionMaker, MetricsCallback metricsCallback, - boolean needDecompress) + boolean needDecompress, + Optional cryptoHandler) throws IOException { if (locations == null || locations.isEmpty()) { return emptyInputStream; @@ -106,7 +108,8 @@ public static CelebornInputStream create( metricsCallback, needDecompress, startMapIndex, - endMapIndex); + endMapIndex, + cryptoHandler); } else { return new CelebornInputStreamImpl( conf, @@ -131,7 +134,8 @@ public static CelebornInputStream create( metricsCallback, needDecompress, -1, - -1); + -1, + cryptoHandler); } } } @@ -188,6 +192,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { private final Map failedBatches; + private byte[] encryptedBuf; private byte[] compressedBuf; private byte[] rawDataBuf; private Decompressor decompressor; @@ -223,6 +228,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { private int shuffleId; private int partitionId; private ExceptionMaker exceptionMaker; + private Optional cryptoHandler; private boolean closed = false; private boolean integrityChecked = false; private final CommitMetadata aggregatedActualCommitMetadata = new CommitMetadata(); @@ -250,7 +256,8 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { MetricsCallback metricsCallback, boolean needDecompress, int numberOfSubPartitions, - int currentIndexOfSubPartition) + int currentIndexOfSubPartition, + Optional cryptoHandler) throws IOException { this( conf, @@ -275,7 +282,8 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { metricsCallback, needDecompress, numberOfSubPartitions, - currentIndexOfSubPartition); + currentIndexOfSubPartition, + cryptoHandler); } CelebornInputStreamImpl( @@ -301,7 +309,8 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { MetricsCallback metricsCallback, boolean needDecompress, int numberOfSubPartitions, - int currentIndexOfSubPartition) + int currentIndexOfSubPartition, + Optional cryptoHandler) throws IOException { this.conf = conf; this.clientFactory = clientFactory; @@ -337,6 +346,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { this.retryWaitMs = conf.networkIoRetryWaitMs(TransportModuleConstants.DATA_MODULE); this.callback = metricsCallback; this.exceptionMaker = exceptionMaker; + this.cryptoHandler = cryptoHandler; this.partitionId = partitionId; this.appShuffleId = appShuffleId; this.shuffleId = shuffleId; @@ -730,6 +740,7 @@ public synchronized void close() { compressedBuf = null; rawDataBuf = null; + encryptedBuf = null; batchesRead = null; locations = null; attempts = null; @@ -790,9 +801,14 @@ private boolean moveToNextChunk() throws IOException { private void init() { int bufferSize = conf.clientFetchBufferSize(); + int headerLen = shouldDecompress ? Decompressor.getCompressionHeaderLength(conf) : 0; + if (cryptoHandler.isPresent()) { + // Size to match compressedBuf/rawDataBuf headroom; exact overhead depends on the + // CryptoHandler implementation (e.g. SparkCryptoHandler prepends a 4-byte length). + encryptedBuf = new byte[bufferSize + headerLen]; + } if (shouldDecompress) { - int headerLen = Decompressor.getCompressionHeaderLength(conf); bufferSize += headerLen; compressedBuf = new byte[bufferSize]; decompressor = Decompressor.getDecompressor(conf); @@ -823,65 +839,94 @@ private boolean fillBuffer() throws IOException { int batchId = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 8); int size = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 12); - if (shouldDecompress) { + // encryptedSize is the on-wire byte count (used for metrics); size will be + // reassigned to the decrypted length after decryption. + int encryptedSize = size; + + // Perform dedup/stale-attempt checks before decrypting to avoid paying the + // crypto cost for batches that will be discarded anyway. + if (attemptId != attempts[mapId]) { + currentChunk.skipBytes(size); + continue; + } + if (readSkewPartitionWithoutMapRange) { + LocationPushFailedBatches locationPushFailedBatches = + this.failedBatches.get(currentReader.getLocation().getUniqueId()); + if (null != locationPushFailedBatches) { + if (locationPushFailedBatches.contains(mapId, attemptId, batchId)) { + logger.warn( + "Skip duplicated batch: mapId={}, attemptId={}, batchId={}", + mapId, + attemptId, + batchId); + currentChunk.skipBytes(size); + continue; + } + } + } + Set batchSet = batchesRead.computeIfAbsent(mapId, k -> new HashSet<>()); + if (batchSet.contains(batchId)) { + callback.incDuplicateBytesRead(BATCH_HEADER_SIZE + encryptedSize); + logger.debug( + "Skip duplicated batch: mapId {}, attemptId {}, batchId {}.", + mapId, + attemptId, + batchId); + currentChunk.skipBytes(size); + continue; + } + + // Batch is unique and from the correct attempt — now read and optionally decrypt. + if (cryptoHandler.isPresent()) { + if (size > encryptedBuf.length) { + encryptedBuf = new byte[size]; + } + currentChunk.readBytes(encryptedBuf, 0, size); + byte[] decrypted = cryptoHandler.get().decrypt(encryptedBuf, 0, size); + if (logger.isDebugEnabled()) { + logger.debug( + "Decrypted shuffle data for shuffle {} partition {}: {} bytes -> {} bytes.", + shuffleId, + partitionId, + size, + decrypted.length); + } + size = decrypted.length; + if (shouldDecompress) { + compressedBuf = decrypted; + } else { + rawDataBuf = decrypted; + } + } else if (shouldDecompress) { if (size > compressedBuf.length) { compressedBuf = new byte[size]; } - currentChunk.readBytes(compressedBuf, 0, size); } else { if (size > rawDataBuf.length) { rawDataBuf = new byte[size]; } - currentChunk.readBytes(rawDataBuf, 0, size); } - // de-duplicate - if (attemptId == attempts[mapId]) { - if (readSkewPartitionWithoutMapRange) { - LocationPushFailedBatches locationPushFailedBatches = - this.failedBatches.get(currentReader.getLocation().getUniqueId()); - if (null != locationPushFailedBatches) { - if (locationPushFailedBatches.contains(mapId, attemptId, batchId)) { - logger.warn( - "Skip duplicated batch: mapId={}, attemptId={}, batchId={}", - mapId, - attemptId, - batchId); - continue; - } - } - } - Set batchSet = batchesRead.computeIfAbsent(mapId, k -> new HashSet<>()); - if (!batchSet.contains(batchId)) { - batchSet.add(batchId); - callback.incBytesRead(BATCH_HEADER_SIZE + size); - if (shouldDecompress) { - // decompress data - int originalLength = decompressor.getOriginalLen(compressedBuf); - if (rawDataBuf.length < originalLength) { - rawDataBuf = new byte[originalLength]; - } - limit = decompressor.decompress(compressedBuf, rawDataBuf, 0); - } else { - limit = size; - } - if (shuffleIntegrityCheckEnabled) { - aggregatedActualCommitMetadata.addDataWithOffsetAndLength(rawDataBuf, 0, limit); - } - position = 0; - hasData = true; - break; - } else { - callback.incDuplicateBytesRead(BATCH_HEADER_SIZE + size); - logger.debug( - "Skip duplicated batch: mapId {}, attemptId {}, batchId {}.", - mapId, - attemptId, - batchId); + batchSet.add(batchId); + callback.incBytesRead(BATCH_HEADER_SIZE + encryptedSize); + if (shouldDecompress) { + // decompress data + int originalLength = decompressor.getOriginalLen(compressedBuf); + if (rawDataBuf.length < originalLength) { + rawDataBuf = new byte[originalLength]; } + limit = decompressor.decompress(compressedBuf, rawDataBuf, 0); + } else { + limit = size; + } + if (shuffleIntegrityCheckEnabled) { + aggregatedActualCommitMetadata.addDataWithOffsetAndLength(rawDataBuf, 0, limit); } + position = 0; + hasData = true; + break; } if (!hasData) { diff --git a/client/src/main/java/org/apache/celeborn/client/security/CryptoHandler.java b/client/src/main/java/org/apache/celeborn/client/security/CryptoHandler.java new file mode 100644 index 00000000000..38da34fa797 --- /dev/null +++ b/client/src/main/java/org/apache/celeborn/client/security/CryptoHandler.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.client.security; + +import java.io.IOException; + +public interface CryptoHandler { + byte[] encrypt(byte[] input, int offset, int length) throws IOException; + + byte[] decrypt(byte[] input, int offset, int length) throws IOException; +} diff --git a/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamCryptoRoundTripSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamCryptoRoundTripSuiteJ.java new file mode 100644 index 00000000000..394898a7a24 --- /dev/null +++ b/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamCryptoRoundTripSuiteJ.java @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.client.read; + +import static org.junit.Assert.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.junit.Test; + +import org.apache.celeborn.client.ShuffleClient; +import org.apache.celeborn.client.compress.Compressor; +import org.apache.celeborn.client.security.CryptoHandler; +import org.apache.celeborn.common.CelebornConf; +import org.apache.celeborn.common.network.buffer.NettyManagedBuffer; +import org.apache.celeborn.common.network.client.ChunkReceivedCallback; +import org.apache.celeborn.common.network.client.TransportClient; +import org.apache.celeborn.common.network.client.TransportClientFactory; +import org.apache.celeborn.common.network.protocol.TransportMessage; +import org.apache.celeborn.common.protocol.MessageType; +import org.apache.celeborn.common.protocol.PartitionLocation; +import org.apache.celeborn.common.protocol.PbStreamHandler; +import org.apache.celeborn.common.protocol.StorageInfo; +import org.apache.celeborn.common.unsafe.Platform; + +/** + * Integration-style round-trip tests for EAR (Encryption At Rest) wiring in {@link + * CelebornInputStream}. These tests verify that the encrypt-on-write / decrypt-on-read path works + * end-to-end, including interactions with compression and the shuffle integrity check. + */ +public class CelebornInputStreamCryptoRoundTripSuiteJ { + + private static final int BATCH_HEADER_SIZE = 16; + private static final String SHUFFLE_KEY = "app-1-1"; + + /** + * A minimal CryptoHandler for testing: the encrypted format is [4-byte plaintext length + * (int)][XOR-encrypted payload]. This matches the structural contract of SparkCryptoHandler so + * the bounds check (decryptedLength > length - 4) is also exercised. + */ + static class XorCryptoHandler implements CryptoHandler { + private final byte key; + + XorCryptoHandler(byte key) { + this.key = key; + } + + @Override + public byte[] encrypt(byte[] input, int offset, int length) throws IOException { + // Prefix with 4-byte plaintext length, then XOR-encrypt the payload + byte[] out = new byte[4 + length]; + Platform.putInt(out, Platform.BYTE_ARRAY_OFFSET, length); + for (int i = 0; i < length; i++) { + out[4 + i] = (byte) (input[offset + i] ^ key); + } + return out; + } + + @Override + public byte[] decrypt(byte[] input, int offset, int length) throws IOException { + // Validate the buffer is large enough to hold the 4-byte length prefix + if (length < 4) { + throw new IOException("Encrypted buffer too short: " + length); + } + // Read the plaintext length from the 4-byte prefix + int decryptedLength = Platform.getInt(input, Platform.BYTE_ARRAY_OFFSET + offset); + // Validate bounds: the 4-byte prefix must fit inside the encrypted buffer + if (decryptedLength < 0 || decryptedLength > length - 4) { + throw new IOException( + "Invalid decrypted length: " + decryptedLength + ", encrypted length: " + length); + } + byte[] out = new byte[decryptedLength]; + for (int i = 0; i < decryptedLength; i++) { + out[i] = (byte) (input[offset + 4 + i] ^ key); + } + return out; + } + } + + /** + * Build a single batch ByteBuf as ShuffleClientImpl.pushOrMergeData does: optionally compress, + * optionally encrypt, then prepend the 16-byte batch header. + */ + private ByteBuf buildBatch( + byte[] plaintext, boolean compress, CryptoHandler cryptoHandler, CelebornConf conf) + throws IOException { + byte[] data = plaintext; + int offset = 0; + int length = plaintext.length; + + // Step 1: optionally compress (compress-then-encrypt ordering matches ShuffleClientImpl) + if (compress) { + Compressor compressor = Compressor.getCompressor(conf); + compressor.compress(data, offset, length); + data = compressor.getCompressedBuffer(); + offset = 0; + length = compressor.getCompressedTotalSize(); + } + + // Step 2: optionally encrypt the (possibly compressed) payload + if (cryptoHandler != null) { + data = cryptoHandler.encrypt(data, offset, length); + offset = 0; + length = data.length; + } + + // Step 3: prepend the 16-byte batch header [mapId|attemptId|batchId|payloadLen] + byte[] body = new byte[BATCH_HEADER_SIZE + length]; + Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET, 0); // mapId + Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 4, 0); // attemptId + Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 8, 0); // batchId + Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 12, length); // payload length + System.arraycopy(data, offset, body, BATCH_HEADER_SIZE, length); + return Unpooled.wrappedBuffer(body); + } + + /** + * Create a CelebornInputStream backed by a mock TransportClient that serves the given batchBuf as + * a single chunk. + */ + private CelebornInputStream createStream( + ByteBuf batchBuf, + boolean needDecompress, + Optional cryptoHandler, + CelebornConf conf) + throws IOException, InterruptedException { + TransportClient client = mock(TransportClient.class); + PbStreamHandler pbHandler = + PbStreamHandler.newBuilder().setStreamId(1L).setNumChunks(1).build(); + // Encode the stream handler into an RPC response that CelebornInputStream expects + ByteBuffer rpcResponse = + new TransportMessage(MessageType.STREAM_HANDLER, pbHandler.toByteArray()).toByteBuffer(); + when(client.sendRpcSync(any(ByteBuffer.class), anyLong())).thenReturn(rpcResponse); + doNothing().when(client).sendRpc(any(ByteBuffer.class)); + doAnswer( + invocation -> { + ChunkReceivedCallback cb = invocation.getArgument(3); + // Serve the pre-built batch buffer immediately as chunk 0; duplicate() shares + // the underlying data without incrementing the ref count, so the stream's + // single release correctly frees the buffer. + cb.onSuccess(0, new NettyManagedBuffer(batchBuf.duplicate())); + return null; + }) + .when(client) + .fetchChunk(anyLong(), anyInt(), anyLong(), any(ChunkReceivedCallback.class)); + + TransportClientFactory clientFactory = mock(TransportClientFactory.class); + when(clientFactory.createClient(anyString(), anyInt())).thenReturn(client); + + ShuffleClient shuffleClient = mock(ShuffleClient.class); + + // PRIMARY location pointing to a single HDD partition + PartitionLocation location = + new PartitionLocation( + 0, 0, "host1", 9001, 9002, 9003, 9004, PartitionLocation.Mode.PRIMARY); + location.setStorageInfo(new StorageInfo(StorageInfo.Type.HDD, true, "/mnt/disk1")); + + ArrayList locations = new ArrayList<>(); + locations.add(location); + ArrayList handlers = new ArrayList<>(); + handlers.add(PbStreamHandler.newBuilder().setStreamId(1L).setNumChunks(1).build()); + + return CelebornInputStream.create( + conf, + clientFactory, + SHUFFLE_KEY, + locations, + handlers, + new int[] {0}, + new HashMap<>(), + new HashMap<>(), + 0, + 1L, + 0, + 100, + new ConcurrentHashMap<>(), + shuffleClient, + 1, + 1, + 0, + null, + new MetricsCallback() { + @Override + public void incBytesRead(long bytes) {} + + @Override + public void incReadTime(long time) {} + }, + needDecompress, + cryptoHandler); + } + + private byte[] readAll(CelebornInputStream stream) throws IOException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + byte[] buf = new byte[4096]; + int n; + while ((n = stream.read(buf)) != -1) { + baos.write(buf, 0, n); + } + return baos.toByteArray(); + } + + @Test + public void testEncryptDecryptRoundTrip() throws IOException, InterruptedException { + byte[] plaintext = "hello, EAR round-trip without compression".getBytes(); + CelebornConf conf = new CelebornConf(); + XorCryptoHandler handler = new XorCryptoHandler((byte) 0x5A); + + // Build an encrypted batch and read it back through CelebornInputStream + ByteBuf batchBuf = buildBatch(plaintext, false, handler, conf); + try (CelebornInputStream stream = createStream(batchBuf, false, Optional.of(handler), conf)) { + assertArrayEquals(plaintext, readAll(stream)); + } + } + + @Test + public void testNoEncryptionRoundTrip() throws IOException, InterruptedException { + byte[] plaintext = "unencrypted shuffle data sanity check".getBytes(); + CelebornConf conf = new CelebornConf(); + + // Baseline: with no CryptoHandler the data flows through unchanged + ByteBuf batchBuf = buildBatch(plaintext, false, null, conf); + try (CelebornInputStream stream = createStream(batchBuf, false, Optional.empty(), conf)) { + assertArrayEquals(plaintext, readAll(stream)); + } + } + + @Test + public void testCompressThenEncryptRoundTrip() throws IOException, InterruptedException { + // Reproduce the compress-then-encrypt ordering used in ShuffleClientImpl. + byte[] plaintext = "shuffle data with compression and encryption enabled for EAR".getBytes(); + CelebornConf conf = new CelebornConf(); + // Use LZ4 (default) + conf.set(CelebornConf.SHUFFLE_COMPRESSION_CODEC().key(), "lz4"); + XorCryptoHandler handler = new XorCryptoHandler((byte) 0x3C); + + // Writer: LZ4-compress then XOR-encrypt; Reader: decrypt then decompress + ByteBuf batchBuf = buildBatch(plaintext, true, handler, conf); + try (CelebornInputStream stream = createStream(batchBuf, true, Optional.of(handler), conf)) { + assertArrayEquals(plaintext, readAll(stream)); + } + } + + @Test + public void testEncryptWithIntegrityCheckEnabled() throws IOException, InterruptedException { + // Verify that EAR + shuffle integrity check (celeborn.client.shuffle.integrityCheck.enabled) + // work together: the checksum is computed over plaintext, so decrypt-then-verify must hold. + byte[] plaintext = "integrity check should pass after decryption".getBytes(); + CelebornConf conf = new CelebornConf(); + conf.set(CelebornConf.CLIENT_SHUFFLE_INTEGRITY_CHECK_ENABLED().key(), "true"); + XorCryptoHandler handler = new XorCryptoHandler((byte) 0x7F); + + // The integrity metadata (checksum) is added by CelebornInputStream over the decrypted data + ByteBuf batchBuf = buildBatch(plaintext, false, handler, conf); + try (CelebornInputStream stream = createStream(batchBuf, false, Optional.of(handler), conf)) { + assertArrayEquals(plaintext, readAll(stream)); + } + } + + @Test + public void testLargePayloadEncryptDecrypt() throws IOException, InterruptedException { + // 128 KB payload exercises buffer-boundary handling in fillBuffer() + byte[] plaintext = new byte[128 * 1024]; + for (int i = 0; i < plaintext.length; i++) plaintext[i] = (byte) (i % 251); + CelebornConf conf = new CelebornConf(); + XorCryptoHandler handler = new XorCryptoHandler((byte) 0xAB); + + ByteBuf batchBuf = buildBatch(plaintext, false, handler, conf); + try (CelebornInputStream stream = createStream(batchBuf, false, Optional.of(handler), conf)) { + assertArrayEquals(plaintext, readAll(stream)); + } + } +} diff --git a/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamPeerFailoverTest.java b/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamPeerFailoverTest.java index a7dd7db3178..456486b9e8c 100644 --- a/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamPeerFailoverTest.java +++ b/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamPeerFailoverTest.java @@ -32,6 +32,7 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.HashMap; +import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; @@ -39,6 +40,7 @@ import org.junit.Test; import org.apache.celeborn.client.ShuffleClient; +import org.apache.celeborn.client.security.CryptoHandler; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.exception.CelebornIOException; import org.apache.celeborn.common.network.client.TransportClient; @@ -174,7 +176,8 @@ public void testFailureWithoutPeer() throws Exception { 0, null, new TestMetricsCallback(), - false); + false, + Optional.empty()); } private void createInputStream(String primaryHost, String replicaHost) throws IOException { @@ -209,7 +212,8 @@ private void createInputStream(String primaryHost, String replicaHost) throws IO 0, null, new TestMetricsCallback(), - false); + false, + Optional.empty()); } /**