Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions tools/clang/unittests/HLSLExec/LinAlgTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
#include <optional>
#include <sstream>
#include <string>

#define STREAM_FLOAT(stream, name, value) \
stream << std::showpoint << " -D" << name << "=" << value << "F" \
<< std::noshowpoint
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While this is fine, I feel like the name makes it sound more generic, as if it might be defined like so:

#define STREAM_FLOAT(value) std::showpoint << value << "F" << std::noshowpoint

And used like so:

ExtraDefs << "-DFILL_VALUE=" << STREAM_FLOAT(FillValue);

#include <variant>
#include <vector>

Expand Down Expand Up @@ -505,7 +509,7 @@ static void runSplatStore(ID3D12Device *Device,
const size_t BufferSize = Params.totalBytes();

std::stringstream ExtraDefs;
ExtraDefs << "-DFILL_VALUE=" << FillValue;
STREAM_FLOAT(ExtraDefs, "FILL_VALUE", FillValue);

std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str());

Expand Down Expand Up @@ -624,7 +628,7 @@ static const char ElementAccessShader[] = R"(
// flatten the 2D index into a 1D index then scale by element size
// Always store row-major and work it out in the test runner
uint coordToByteOffset(uint2 coord) {
return (coord.y * M_DIM + coord.x) * ELEM_SIZE;
return (coord.x * N_DIM + coord.y) * ELEM_SIZE;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. It sure can be confusing for x to be row and y to be column, but makes most sense in context.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was also vague in the spec until very recently. Thanks for the fix!

}

[WaveSize(4, 64)]
Expand Down Expand Up @@ -936,8 +940,8 @@ static void runMatMatMul(ID3D12Device *Device,

std::stringstream ExtraDefs;
ExtraDefs << " -DK_DIM=" << K;
ExtraDefs << " -DA_FILL=" << AFill;
ExtraDefs << " -DB_FILL=" << BFill;
STREAM_FLOAT(ExtraDefs, "A_FILL", AFill);
STREAM_FLOAT(ExtraDefs, "B_FILL", BFill);

std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str());

Expand Down Expand Up @@ -1018,9 +1022,9 @@ static void runMatMatMulAccum(ID3D12Device *Device,

std::stringstream ExtraDefs;
ExtraDefs << " -DK_DIM=" << K;
ExtraDefs << " -DA_FILL=" << AFill;
ExtraDefs << " -DB_FILL=" << BFill;
ExtraDefs << " -DC_FILL=" << CFill;
STREAM_FLOAT(ExtraDefs, "A_FILL", AFill);
STREAM_FLOAT(ExtraDefs, "B_FILL", BFill);
STREAM_FLOAT(ExtraDefs, "C_FILL", CFill);

std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str());

Expand Down Expand Up @@ -1094,8 +1098,8 @@ static void runMatAccum(ID3D12Device *Device,
const size_t BufferSize = Params.totalBytes();

std::stringstream ExtraDefs;
ExtraDefs << " -DLHS_FILL=" << LHSFill;
ExtraDefs << " -DRHS_FILL=" << RHSFill;
STREAM_FLOAT(ExtraDefs, "LHS_FILL", LHSFill);
STREAM_FLOAT(ExtraDefs, "RHS_FILL", RHSFill);

std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str());

Expand Down Expand Up @@ -1552,7 +1556,7 @@ static void runStoreMemory(ID3D12Device *Device,

std::stringstream ExtraDefs;
ExtraDefs << " -DOFFSET=" << 0;
ExtraDefs << " -DFILL_VALUE=" << FillValue;
STREAM_FLOAT(ExtraDefs, "FILL_VALUE", FillValue);

std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str());

Expand Down Expand Up @@ -1632,7 +1636,7 @@ static void runAccumulateMemory(ID3D12Device *Device,

std::stringstream ExtraDefs;
ExtraDefs << " -DOFFSET=" << 0;
ExtraDefs << " -DFILL_VALUE=" << FillValue;
STREAM_FLOAT(ExtraDefs, "FILL_VALUE", FillValue);

std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str());

Expand Down
Loading