Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2013-2024, APT Group, Department of Computer Science,
* Copyright (c) 2013-2025, APT Group, Department of Computer Science,
* The University of Manchester.
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -19,6 +19,7 @@

import static java.lang.foreign.ValueLayout.JAVA_BYTE;
import static java.lang.foreign.ValueLayout.JAVA_INT;
import static java.lang.foreign.ValueLayout.JAVA_SHORT;

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
Expand All @@ -27,6 +28,7 @@

import uk.ac.manchester.tornado.api.annotations.Parallel;
import uk.ac.manchester.tornado.api.internal.annotations.SegmentElementSize;
import uk.ac.manchester.tornado.api.types.HalfFloat;

/**
* This class represents an array of bytes stored in native memory.
Expand Down Expand Up @@ -198,6 +200,30 @@ public void set(int index, byte value) {
segment.setAtIndex(JAVA_BYTE, baseIndex + index, value);
}

/**
* Sets the half-float value at the specified byte index within the {@link ByteArray} instance.
*
* The specified {@code byteIndex} must be aligned to a 2-byte boundary; if it is not,
* an {@link IllegalArgumentException} will be thrown. The method internally calculates
* the appropriate short index for storage and updates the underlying memory segment.
*
* @param byteIndex
* The byte index at which to set the half-float value. Must be aligned to a 2-byte boundary.
* @param value
* The {@link HalfFloat} value to be stored at the specified index.
* @throws IllegalArgumentException
* If the {@code byteIndex} is not aligned to a 2-byte boundary.
*/
public void setHalfFloat(int byteIndex, HalfFloat value) {
if (byteIndex % 2 != 0) {
throw new IllegalArgumentException("Half-float must be aligned to 2-byte boundary");
}
// Convert byte index to short index for the segment
// arrayHeaderSize (8 bytes) + byteIndex, then divide by 2 for short indexing
int shortIndex = (arrayHeaderSize + byteIndex) / 2;
segment.setAtIndex(JAVA_SHORT, shortIndex, value.getHalfFloatValue());
}

/**
* Gets the byte value stored at the specified index of the {@link ByteArray} instance.
*
Expand All @@ -209,6 +235,31 @@ public byte get(int index) {
return segment.getAtIndex(JAVA_BYTE, baseIndex + index);
}

/**
* Gets the half-float value stored at the specified byte index within the {@link ByteArray} instance.
*
* The specified {@code byteIndex} must be aligned to a 2-byte boundary; if it is not,
* an {@link IllegalArgumentException} will be thrown. The method internally calculates
* the appropriate short index for storage and retrieves the value from the underlying memory segment.
*
* @param byteIndex
* The byte index from which to retrieve the half-float value. Must be aligned to a 2-byte boundary.
* @return A {@link HalfFloat} instance containing the value stored at the specified index.
* @throws IllegalArgumentException
* If the {@code byteIndex} is not aligned to a 2-byte boundary.
*/
public HalfFloat getHalfFloat(int byteIndex) {
if (byteIndex % 2 != 0) {
throw new IllegalArgumentException("Half-float must be aligned to 2-byte boundary");
}
// Convert byte index to short index for the segment
// arrayHeaderSize (8 bytes) + byteIndex, then divide by 2 for short indexing
int shortIndex = (arrayHeaderSize + byteIndex) / 2;
short halfFloatValue = segment.getAtIndex(JAVA_SHORT, shortIndex);
return new HalfFloat(halfFloatValue);

}

/**
* Sets all the values of the {@link ByteArray} instance to zero.
*/
Expand Down
1 change: 1 addition & 0 deletions tornado-assembly/src/bin/tornado-test
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ __TEST_THE_WORLD__ = [
TestEntry("uk.ac.manchester.tornado.unittests.pointers.TestCopyDevicePointers"),
TestEntry("uk.ac.manchester.tornado.unittests.tensors.TestTensorAPIWithOnnx"),
TestEntry("uk.ac.manchester.tornado.unittests.memory.MemoryConsumptionTest"),
TestEntry("uk.ac.manchester.tornado.unittests.api.TestByteArrayTypedAccess"),

## Test for function calls - We force not to inline methods
TestEntry(testName="uk.ac.manchester.tornado.unittests.tasks.TestMultipleFunctions",
Expand Down
Loading