Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,122 @@ public float[] getDataAsFloatArray() {
"Tensor of type " + getClass().getSimpleName() + " cannot return data as float array.");
}

/**
* Copies the tensor's data into a caller-provided {@link FloatBuffer}, avoiding the per-call
* {@code float[]} allocation that {@link #getDataAsFloatArray()} performs. The destination
* buffer's position is advanced by the number of elements written; its content from the starting
* position must have at least {@link #numel()} elements of remaining capacity.
*
* <p>Useful in steady-state inference loops where the same output tensor shape is read every
* frame: pre-allocate a {@code FloatBuffer} once (e.g. via {@link #allocateFloatBuffer(int)}) and
* reuse it across calls.
*
* <p>Supported by float32 (zero-copy bulk put) and float16 (per-element half→float widening,
* matching {@link #getDataAsFloatArray()} on that subclass). For raw fp16 bits without widening,
* use {@link #copyDataInto(ShortBuffer)}.
*
* @param dst the destination buffer; must have remaining capacity {@code >=} {@link #numel()}.
* @throws IllegalStateException if it is called for a tensor type that does not support a float
* view.
* @throws java.nio.BufferOverflowException if {@code dst} does not have enough remaining
* capacity.
*/
public void copyDataInto(FloatBuffer dst) {
throw new IllegalStateException(
"Tensor of type " + getClass().getSimpleName() + " cannot copy data into FloatBuffer.");
}

/**
* Copies the tensor's data into a caller-provided {@link ByteBuffer}, avoiding the per-call
* {@code byte[]} allocation that {@link #getDataAsByteArray()} performs.
*
* @param dst the destination buffer; must have remaining capacity {@code >=} {@link #numel()}.
* @throws IllegalStateException if it is called for a non-int8 tensor.
* @throws java.nio.BufferOverflowException if {@code dst} does not have enough remaining
* capacity.
*/
public void copyDataInto(ByteBuffer dst) {
throw new IllegalStateException(
"Tensor of type " + getClass().getSimpleName() + " cannot copy data into ByteBuffer.");
}

/**
* Copies the tensor's data into a caller-provided {@link ByteBuffer}, avoiding the per-call
* {@code byte[]} allocation that {@link #getDataAsUnsignedByteArray()} performs. The bytes carry
* the raw uint8 bits — Java's signed {@code byte} representation, with values {@code >127}
* appearing negative; reinterpret with {@code & 0xFF} when reading.
*
* @param dst the destination buffer; must have remaining capacity {@code >=} {@link #numel()}.
* @throws IllegalStateException if it is called for a non-uint8 tensor.
* @throws java.nio.BufferOverflowException if {@code dst} does not have enough remaining
* capacity.
*/
public void copyDataIntoUnsigned(ByteBuffer dst) {
throw new IllegalStateException(
"Tensor of type "
+ getClass().getSimpleName()
+ " cannot copy data into ByteBuffer (unsigned).");
}

/**
* Copies the tensor's data into a caller-provided {@link IntBuffer}, avoiding the per-call {@code
* int[]} allocation that {@link #getDataAsIntArray()} performs.
*
* @param dst the destination buffer; must have remaining capacity {@code >=} {@link #numel()}.
* @throws IllegalStateException if it is called for a non-int32 tensor.
* @throws java.nio.BufferOverflowException if {@code dst} does not have enough remaining
* capacity.
*/
public void copyDataInto(IntBuffer dst) {
throw new IllegalStateException(
"Tensor of type " + getClass().getSimpleName() + " cannot copy data into IntBuffer.");
}

/**
* Copies the tensor's data into a caller-provided {@link LongBuffer}, avoiding the per-call
* {@code long[]} allocation that {@link #getDataAsLongArray()} performs.
*
* @param dst the destination buffer; must have remaining capacity {@code >=} {@link #numel()}.
* @throws IllegalStateException if it is called for a non-int64 tensor.
* @throws java.nio.BufferOverflowException if {@code dst} does not have enough remaining
* capacity.
*/
public void copyDataInto(LongBuffer dst) {
throw new IllegalStateException(
"Tensor of type " + getClass().getSimpleName() + " cannot copy data into LongBuffer.");
}

/**
* Copies the tensor's data into a caller-provided {@link DoubleBuffer}, avoiding the per-call
* {@code double[]} allocation that {@link #getDataAsDoubleArray()} performs.
*
* @param dst the destination buffer; must have remaining capacity {@code >=} {@link #numel()}.
* @throws IllegalStateException if it is called for a non-float64 tensor.
* @throws java.nio.BufferOverflowException if {@code dst} does not have enough remaining
* capacity.
*/
public void copyDataInto(DoubleBuffer dst) {
throw new IllegalStateException(
"Tensor of type " + getClass().getSimpleName() + " cannot copy data into DoubleBuffer.");
}

/**
* Copies the tensor's data into a caller-provided {@link ShortBuffer}, avoiding the per-call
* {@code short[]} allocation that {@link #getDataAsShortArray()} performs. For float16 tensors
* this writes the raw 16-bit half-precision bits with no widening; use {@link
* #copyDataInto(FloatBuffer)} if you want the values widened to fp32.
*
* @param dst the destination buffer; must have remaining capacity {@code >=} {@link #numel()}.
* @throws IllegalStateException if it is called for a tensor type whose backing storage is not a
* {@code ShortBuffer}.
* @throws java.nio.BufferOverflowException if {@code dst} does not have enough remaining
* capacity.
*/
public void copyDataInto(ShortBuffer dst) {
throw new IllegalStateException(
"Tensor of type " + getClass().getSimpleName() + " cannot copy data into ShortBuffer.");
}

/**
* @return a Java long array that contains the tensor data. This may be a copy or reference.
* @throws IllegalStateException if it is called for a non-int64 tensor.
Expand Down Expand Up @@ -604,6 +720,12 @@ public byte[] getDataAsUnsignedByteArray() {
return arr;
}

@Override
public void copyDataIntoUnsigned(ByteBuffer dst) {
data.rewind();
dst.put(data);
}

@Override
public String toString() {
return String.format("Tensor(%s, dtype=torch.uint8)", Arrays.toString(shape));
Expand Down Expand Up @@ -636,6 +758,12 @@ public byte[] getDataAsByteArray() {
return arr;
}

@Override
public void copyDataInto(ByteBuffer dst) {
data.rewind();
dst.put(data);
}

@Override
public String toString() {
return String.format("Tensor(%s, dtype=torch.int8)", Arrays.toString(shape));
Expand Down Expand Up @@ -668,6 +796,12 @@ public int[] getDataAsIntArray() {
return arr;
}

@Override
public void copyDataInto(IntBuffer dst) {
data.rewind();
dst.put(data);
}

@Override
public String toString() {
return String.format("Tensor(%s, dtype=torch.int32)", Arrays.toString(shape));
Expand All @@ -690,6 +824,12 @@ public float[] getDataAsFloatArray() {
return arr;
}

@Override
public void copyDataInto(FloatBuffer dst) {
data.rewind();
dst.put(data);
}

@Override
public DType dtype() {
return DType.FLOAT;
Expand Down Expand Up @@ -732,6 +872,12 @@ public short[] getDataAsShortArray() {
return arr;
}

@Override
public void copyDataInto(ShortBuffer dst) {
data.rewind();
dst.put(data);
}

@Override
public float[] getDataAsFloatArray() {
data.rewind();
Expand All @@ -743,6 +889,21 @@ public float[] getDataAsFloatArray() {
return arr;
}

@Override
public void copyDataInto(FloatBuffer dst) {
data.rewind();
int remaining = data.remaining();
// Match the all-or-nothing semantics of bulk FloatBuffer.put(FloatBuffer):
// verify capacity up front so an undersized destination throws before any
// partial widening is observed in dst.
if (dst.remaining() < remaining) {
throw new java.nio.BufferOverflowException();
}
for (int i = 0; i < remaining; i++) {
dst.put(halfBitsToFloat(data.get()));
}
}

@Override
public String toString() {
return String.format("Tensor(%s, dtype=torch.float16)", Arrays.toString(shape));
Expand Down Expand Up @@ -800,6 +961,12 @@ public long[] getDataAsLongArray() {
return arr;
}

@Override
public void copyDataInto(LongBuffer dst) {
data.rewind();
dst.put(data);
}

@Override
public String toString() {
return String.format("Tensor(%s, dtype=torch.int64)", Arrays.toString(shape));
Expand Down Expand Up @@ -832,6 +999,12 @@ public double[] getDataAsDoubleArray() {
return arr;
}

@Override
public void copyDataInto(DoubleBuffer dst) {
data.rewind();
dst.put(data);
}

@Override
public String toString() {
return String.format("Tensor(%s, dtype=torch.float64)", Arrays.toString(shape));
Expand Down
Loading