Skip to content
4 changes: 4 additions & 0 deletions client-spark/common/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@
<artifactId>spark-sql_${scala.binary.version}</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-crypto</artifactId>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,17 @@

import java.util.Collections;
import java.util.Map;
import java.util.Optional;

import scala.Option;

import org.apache.spark.SparkConf;
import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
import org.apache.spark.internal.config.package$;
import org.apache.spark.memory.SparkOutOfMemoryError;

import org.apache.celeborn.client.security.CryptoHandler;
import org.apache.celeborn.reflect.DynConstructors;
import org.apache.celeborn.reflect.DynMethods;

Expand Down Expand Up @@ -96,4 +102,19 @@ public static void throwSparkOutOfMemoryError() {
}
}
}

/**
 * Resolves the {@link CryptoHandler} to use for shuffle data, if any.
 *
 * <p>An empty {@link Optional} is returned when Spark IO encryption is not enabled in {@code
 * conf}, when {@link SparkEnv#get()} yields {@code null}, or when the security manager exposes no
 * IO encryption key. Otherwise the result wraps a {@link SparkCryptoHandler} built from {@code
 * conf} and that key.
 *
 * @param conf Spark configuration consulted for the IO-encryption flag and handed to the handler
 * @return the handler to apply, or {@link Optional#empty()} when none can be created
 */
public static Optional<CryptoHandler> getCryptoHandler(SparkConf conf) {
  boolean encryptionEnabled = (Boolean) conf.get(package$.MODULE$.IO_ENCRYPTION_ENABLED());
  // SparkEnv is only looked up once encryption is known to be switched on.
  SparkEnv sparkEnv = encryptionEnabled ? SparkEnv.get() : null;
  if (sparkEnv == null) {
    return Optional.empty();
  }
  Option<byte[]> ioKey = sparkEnv.securityManager().getIOEncryptionKey();
  return ioKey.isDefined()
      ? Optional.of(new SparkCryptoHandler(conf, ioKey.get()))
      : Optional.empty();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.celeborn;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.OutputStream;

import org.apache.spark.SparkConf;
import org.apache.spark.security.CryptoStreamUtils;

import org.apache.celeborn.client.security.CryptoHandler;

public class SparkCryptoHandler implements CryptoHandler {
private final SparkConf sparkConf;
private final byte[] key;

public SparkCryptoHandler(SparkConf sparkConf, byte[] key) {
this.sparkConf = sparkConf;
this.key = key;
}

@Override
public byte[] encrypt(byte[] input, int offset, int length) throws IOException {
Comment thread
akpatnam25 marked this conversation as resolved.
ByteArrayOutputStream baos = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(baos);
dos.writeInt(length);
try (OutputStream cos = CryptoStreamUtils.createCryptoOutputStream(dos, sparkConf, key)) {
cos.write(input, offset, length);
}
return baos.toByteArray();
}

@Override
public byte[] decrypt(byte[] input, int offset, int length) throws IOException {
ByteArrayInputStream bais = new ByteArrayInputStream(input, offset, length);
DataInputStream dis = new DataInputStream(bais);
int decryptedLength = dis.readInt();
// The encrypted payload format is: [4-byte plaintext length][ciphertext...].
// So the maximum valid decrypted length is length - 4 (the ciphertext portion).
// A value outside this range indicates corruption or a wrong key.
if (decryptedLength < 0 || decryptedLength > length - 4) {
Comment thread
akpatnam25 marked this conversation as resolved.
throw new IOException(
"Invalid decrypted length: " + decryptedLength + ", encrypted length: " + length);
}
try (DataInputStream cis =
new DataInputStream(CryptoStreamUtils.createCryptoInputStream(dis, sparkConf, key))) {
byte[] decrypted = new byte[decryptedLength];
cis.readFully(decrypted);
return decrypted;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.celeborn;

import static org.junit.Assert.*;

import java.io.IOException;
import java.security.SecureRandom;
import java.util.Arrays;

import org.apache.spark.SparkConf;
import org.apache.spark.internal.config.package$;
import org.junit.Before;
import org.junit.Test;

import org.apache.celeborn.client.security.CryptoHandler;

/** Round-trip and failure-mode tests for {@link SparkCryptoHandler}. */
public class SparkCryptoHandlerSuiteJ {

  private byte[] key;
  private CryptoHandler handler;

  /** Returns {@code size} bytes drawn from a fresh {@link SecureRandom}. */
  private static byte[] randomBytes(int size) {
    byte[] bytes = new byte[size];
    new SecureRandom().nextBytes(bytes);
    return bytes;
  }

  /** Builds a handler over a defaults-free SparkConf that has IO encryption switched on. */
  private static CryptoHandler newHandler(byte[] secret) {
    SparkConf encryptingConf = new SparkConf(false);
    encryptingConf.set(package$.MODULE$.IO_ENCRYPTION_ENABLED(), true);
    return new SparkCryptoHandler(encryptingConf, secret);
  }

  /** Encrypts the whole array with the handler under test. */
  private byte[] encryptAll(byte[] data) throws IOException {
    return handler.encrypt(data, 0, data.length);
  }

  /** Decrypts the whole array with the handler under test. */
  private byte[] decryptAll(byte[] data) throws IOException {
    return handler.decrypt(data, 0, data.length);
  }

  /** Creates a handler around a fresh random 16-byte key before each test. */
  @Before
  public void setUp() {
    key = randomBytes(16);
    handler = newHandler(key);
  }

  /** Encrypting then decrypting yields the input, and the ciphertext is not the plaintext. */
  @Test
  public void testRoundTrip() throws IOException {
    byte[] original = "hello world, this is a test of encryption".getBytes();

    byte[] cipherBytes = encryptAll(original);
    assertFalse(
        "Encrypted output should differ from plaintext", Arrays.equals(original, cipherBytes));

    assertArrayEquals(original, decryptAll(cipherBytes));
  }

  /** The encrypted array must not equal the plaintext array. */
  @Test
  public void testEncryptedDiffersFromPlaintext() throws IOException {
    byte[] original = "deterministic test data for comparison".getBytes();

    boolean unchanged = Arrays.equals(original, encryptAll(original));
    assertFalse("Encrypted output should differ from plaintext", unchanged);
  }

  /** Two separate encryptions of the same input both decrypt back to that input. */
  @Test
  public void testSameDataEncryptsThenDecrypts() throws IOException {
    byte[] original = "same data encrypted twice".getBytes();

    byte[] firstCipher = encryptAll(original);
    byte[] secondCipher = encryptAll(original);

    // Both should decrypt to the same plaintext
    byte[] firstPlain = decryptAll(firstCipher);
    byte[] secondPlain = decryptAll(secondCipher);

    assertArrayEquals(original, firstPlain);
    assertArrayEquals(original, secondPlain);
  }

  /** Only the requested slice of a larger buffer is encrypted. */
  @Test
  public void testEncryptWithOffset() throws IOException {
    byte[] payload = "offset test data".getBytes();
    int leadingPad = 10;
    int trailingPad = 20;
    byte[] buffer = new byte[leadingPad + payload.length + trailingPad];
    System.arraycopy(payload, 0, buffer, leadingPad, payload.length);

    byte[] cipherBytes = handler.encrypt(buffer, leadingPad, payload.length);

    assertArrayEquals(payload, decryptAll(cipherBytes));
  }

  /** A handler with a different key must either throw or return something else. */
  @Test
  public void testDecryptWithWrongKeyFails() throws IOException {
    byte[] original = "secret data".getBytes();
    byte[] cipherBytes = encryptAll(original);

    CryptoHandler wrongHandler = newHandler(randomBytes(16));

    try {
      byte[] recovered = wrongHandler.decrypt(cipherBytes, 0, cipherBytes.length);
      // CryptoStreamUtils may return garbage instead of throwing
      assertFalse(
          "Decryption with wrong key should not produce original plaintext",
          Arrays.equals(original, recovered));
    } catch (IOException expected) {
      // acceptable — some implementations throw on wrong key
    }
  }

  /** A 64KB random payload survives the round trip. */
  @Test
  public void testLargeData() throws IOException {
    byte[] original = randomBytes(64 * 1024); // 64KB

    byte[] roundTripped = decryptAll(encryptAll(original));

    assertArrayEquals(original, roundTripped);
  }

  /** A zero-length payload survives the round trip as a zero-length array. */
  @Test
  public void testEmptyData() throws IOException {
    byte[] cipherBytes = handler.encrypt(new byte[0], 0, 0);

    assertEquals(0, decryptAll(cipherBytes).length);
  }
}
1 change: 1 addition & 0 deletions client-spark/spark-2-shaded/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
<include>com.google.guava:failureaccess</include>
<include>io.netty:*</include>
<include>org.apache.commons:commons-lang3</include>
<include>org.apache.commons:commons-crypto</include>
<include>org.roaringbitmap:RoaringBitmap</include>
<include>commons-io:commons-io</include>
</includes>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.io.IOException;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

import scala.Int;
Expand All @@ -35,6 +36,7 @@

import org.apache.celeborn.client.LifecycleManager;
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.client.security.CryptoHandler;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.protocol.ShuffleMode;
import org.apache.celeborn.reflect.DynMethods;
Expand Down Expand Up @@ -65,6 +67,17 @@ public class SparkShuffleManager implements ShuffleManager {

private ExecutorShuffleIdTracker shuffleIdTracker = new ExecutorShuffleIdTracker();

// The IO encryption key is fixed for the app lifetime. Lazily initialized on first
// writer/reader call (not in the constructor) to ensure SparkEnv is available.
private volatile Optional<CryptoHandler> cryptoHandler = null;

/**
 * Returns the cached crypto handler, resolving it through {@link
 * SparkCommonUtils#getCryptoHandler} the first time the cache is found unset.
 *
 * <p>A {@code null} field means "not resolved yet"; once resolved, the value (which may be an
 * empty {@link Optional}) is stored and reused. No lock is taken here.
 */
private Optional<CryptoHandler> getCryptoHandler() {
  Optional<CryptoHandler> resolved = cryptoHandler;
  if (resolved != null) {
    return resolved;
  }
  resolved = SparkCommonUtils.getCryptoHandler(conf);
  cryptoHandler = resolved;
  return resolved;
}

public SparkShuffleManager(SparkConf conf, boolean isDriver) {
SparkCommonUtils.validateAttemptConfig(conf);
this.conf = conf;
Expand Down Expand Up @@ -208,7 +221,8 @@ public <K, V> ShuffleWriter<K, V> getWriter(
h.lifecycleManagerPort(),
celebornConf,
h.userIdentifier(),
h.extension());
h.extension(),
getCryptoHandler());
if (h.stageRerunEnabled()) {
SparkUtils.addFailureListenerIfBarrierTask(client, context, h);
}
Expand Down Expand Up @@ -260,7 +274,8 @@ public <K, C> ShuffleReader<K, C> getReader(
Int.MaxValue(),
context,
celebornConf,
shuffleIdTracker);
shuffleIdTracker,
getCryptoHandler());
}
checkUserClassPathFirst(handle);
return _sortShuffleManager.getReader(handle, startPartition, endPartition, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.shuffle.celeborn

import java.io.IOException
import java.util.Optional
import java.util.concurrent.{ThreadPoolExecutor, TimeUnit}
import java.util.concurrent.atomic.AtomicReference
import java.util.function.BiFunction
Expand All @@ -32,6 +33,7 @@ import org.apache.spark.util.collection.ExternalSorter
import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.client.read.CelebornInputStream
import org.apache.celeborn.client.read.MetricsCallback
import org.apache.celeborn.client.security.CryptoHandler
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.exception.{CelebornBroadcastException, CelebornIOException, CelebornRuntimeException, PartitionUnRetryAbleException}
import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
Expand All @@ -45,7 +47,8 @@ class CelebornShuffleReader[K, C](
endMapIndex: Int = Int.MaxValue,
context: TaskContext,
conf: CelebornConf,
shuffleIdTracker: ExecutorShuffleIdTracker)
shuffleIdTracker: ExecutorShuffleIdTracker,
cryptoHandler: Optional[CryptoHandler] = Optional.empty())
extends ShuffleReader[K, C] with Logging {

private val dep = handle.dependency
Expand All @@ -55,7 +58,8 @@ class CelebornShuffleReader[K, C](
handle.lifecycleManagerPort,
conf,
handle.userIdentifier,
handle.extension)
handle.extension,
cryptoHandler)

private val exceptionRef = new AtomicReference[IOException]
private val encodedAttemptId = SparkCommonUtils.getEncodedAttemptNumber(context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@

package org.apache.spark.shuffle.celeborn

import java.util.Optional

import org.apache.spark.{ShuffleDependency, TaskContext}
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.shuffle.ShuffleReadMetricsReporter
import org.apache.spark.sql.execution.UnsafeRowSerializer
import org.apache.spark.sql.execution.columnar.{CelebornBatchBuilder, CelebornColumnarBatchSerializer}

import org.apache.celeborn.client.security.CryptoHandler
import org.apache.celeborn.common.CelebornConf

class CelebornColumnarShuffleReader[K, C](
Expand All @@ -34,7 +37,8 @@ class CelebornColumnarShuffleReader[K, C](
context: TaskContext,
conf: CelebornConf,
metrics: ShuffleReadMetricsReporter,
shuffleIdTracker: ExecutorShuffleIdTracker)
shuffleIdTracker: ExecutorShuffleIdTracker,
cryptoHandler: Optional[CryptoHandler] = Optional.empty())
extends CelebornShuffleReader[K, C](
handle,
startPartition,
Expand All @@ -44,7 +48,8 @@ class CelebornColumnarShuffleReader[K, C](
context,
conf,
metrics,
shuffleIdTracker) {
shuffleIdTracker,
cryptoHandler) {

override def newSerializerInstance(dep: ShuffleDependency[K, _, C]): SerializerInstance = {
val schema = CustomShuffleDependencyUtils.getSchema(dep)
Expand Down
Loading
Loading