
AIMIGRAPHX-578 Reintroduce blaze for better ref gemm performance#4619

Closed
kahmed10 wants to merge 10 commits into develop from blaze_gemm_impl

Conversation

@kahmed10 (Collaborator)

Motivation

Our naive gemm implementation is much slower than using an optimized CPU library.

Technical Details

Reintroduces blaze; if blaze does not support a type, the inputs are first converted to a higher precision.

Changelog Category

Add a CHANGELOG.md entry for any option other than Not Applicable

    • Added: New functionality.
    • Changed: Changes to existing functionality.
    • Removed: Functionality or support that has been removed. (Compared to a previous release)
    • Optimized: Component performance that has been optimized or improved.
    • Resolved Issues: Known issues from a previous version that have been resolved.
    • Not Applicable: This PR is not to be included in the changelog.

@kahmed10 kahmed10 changed the title Reintroduce blaze for better ref gemm performance [CI Test] Reintroduce blaze for better ref gemm performance Feb 18, 2026
#endif

template <class T, class U, class F>
void gemm(tensor_view<T> cmat, tensor_view<U> amat, tensor_view<U> bmat, F alpha, F beta)
Review comment (Collaborator):

This should all be moved to the .cpp file. The gemm function should be changed to take arguments as inputs, and the alpha and beta parameters should be removed since they are not used.

std::cerr << "[blaze gemm] " << m_size << "x" << k_size << "x" << n_size
<< " batches=" << num_batches << " (native)" << std::endl;

for(std::size_t batch = 0; batch < num_batches; batch++)
Review comment (Collaborator):

This should probably be a par_for.

{
offset += (remaining % s.lens()[d - 1]) * s.strides()[d - 1];
remaining /= s.lens()[d - 1];
}
Review comment (Collaborator):

Use the index method from the shape class to compute the offset.

}

template <class T, class Func>
void with_blaze_mat(T* ptr, std::size_t rows, std::size_t cols, mat_order order, Func&& func)
Review comment (Collaborator):

This should take a tensor_view: void with_blaze_mat(tensor_view<T> view, mat_order order, Func&& func). The rows and cols can come from the shape.

A class can be created to do the slicing of the tensor:

struct batch_slicer
{
    batch_slicer(const shape& mat_shape)
    {
        // Compute the inner_shape and outer_shape
    }

    template<class T>
    tensor_view<T> extract(tensor_view<T> view, std::size_t batch)
    {
        auto offset = outer_shape.index(batch);
        return make_view(inner_shape, view.data()+offset);
    }

    shape inner_shape;
    shape outer_shape;
};

}

template <class T, class Func>
void with_blaze_mat(T* ptr, std::size_t rows, std::size_t cols, mat_order order, Func&& func)
Review comment (Collaborator):

Actually, it would be simpler to just use the code from here for make_mat and visit_mat:

static void visit_mat(tensor_view<T> x, F f)

@codecov

codecov bot commented Feb 18, 2026

Codecov Report

❌ Patch coverage is 84.94624% with 14 lines in your changes missing coverage. Please review.

| File with missing lines | Patch % | Lines |
|---|---|---|
| src/include/migraphx/gemm.hpp | 84.95% | 14 Missing ⚠️ |
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #4619      +/-   ##
===========================================
- Coverage    92.22%   92.16%   -0.06%     
===========================================
  Files          567      567              
  Lines        27910    28002      +92     
===========================================
+ Hits         25739    25806      +67     
- Misses        2171     2196      +25     
| File with missing lines | Coverage Δ |
|---|---|
| src/include/migraphx/gemm.hpp | 77.68% <84.95%> (-22.32%) ⬇️ |

@kahmed10 kahmed10 changed the title [CI Test] Reintroduce blaze for better ref gemm performance AIMIGRAPHX-578 Reintroduce blaze for better ref gemm performance Feb 18, 2026
@kahmed10 kahmed10 marked this pull request as ready for review February 24, 2026 20:24
@kahmed10 kahmed10 requested a review from causten as a code owner February 24, 2026 20:24
is_mat_layout_supported(bmat.get_shape()) and
is_mat_layout_supported(cmat.get_shape()))
{
for(std::size_t batch = 0; batch < num_batches; batch++)
Review comment (Collaborator):

This should use par_for. Also move this to another function so it can be reused in the copy case.

copy_2d(a_buf.data(), k_size, std::size_t{1},
amat.data() + a_batch.index(batch), a_row_stride, a_col_stride,
m_size, k_size);
copy_2d(b_buf.data(), n_size, std::size_t{1},
Review comment (Collaborator):

The copy should happen outside of the batch loop. Just use std::copy:

std::vector<float> a_buf(amat.elements());
std::copy(amat.begin(), amat.end(), a_buf.begin());
auto amat_flat = make_view(amat.get_shape().as_standard().with_type(shape::float_type), a_buf.data());

And then call into the function that handles is_mat_layout_supported.

@causten (Collaborator) commented Feb 24, 2026

| Test | Batch | Rate new (2d1d0d) | Rate old (3486a8) | Diff |
|---|---|---|---|---|
| torchvision-resnet50 | 64 | 3,106.62 | 3,108.62 | -0.06% |
| torchvision-resnet50_fp16 | 64 | 6,484.03 | 6,488.35 | -0.07% |
| torchvision-densenet121 | 32 | 2,390.24 | 2,392.18 | -0.08% |
| torchvision-densenet121_fp16 | 32 | 4,034.00 | 4,024.45 | 0.24% |
| torchvision-inceptionv3 | 32 | 1,653.58 | 1,653.77 | -0.01% |
| torchvision-inceptionv3_fp16 | 32 | 2,503.13 | 2,506.44 | -0.13% |
| cadene-inceptionv4 | 16 | 788.80 | 788.85 | -0.01% |
| cadene-resnext64x4 | 16 | 774.09 | 773.71 | 0.05% |
| slim-mobilenet | 64 | 8,035.45 | 8,033.27 | 0.03% |
| slim-nasnetalarge | 64 | 215.88 | 215.84 | 0.02% |
| slim-resnet50v2 | 64 | 3,238.70 | 3,238.88 | -0.01% |
| bert-mrpc-onnx | 8 | nan | 1,123.79 | nan% |
| bert-mrpc-tf | 1 | 455.78 | 458.15 | -0.52% |
| pytorch-examples-wlang-gru | 1 | 329.39 | 336.60 | -2.14% |
| pytorch-examples-wlang-lstm | 1 | 419.26 | 420.47 | -0.29% |
| torchvision-resnet50_1 | 1 | 743.21 | 736.77 | 0.87% |
| cadene-dpn92_1 | 1 | 407.05 | 410.69 | -0.89% |
| cadene-resnext101_1 | 1 | 352.81 | 352.58 | 0.07% |
| onnx-taau-downsample | 1 | 391.33 | 389.77 | 0.40% |
| dlrm-criteoterabyte | 1 | 32.49 | 32.48 | 0.04% |
| dlrm-criteoterabyte_fp16 | 1 | 50.33 | 50.30 | 0.06% |
| agentmodel | 1 | 9,908.54 | 10,058.97 | -1.50% |
| unet_fp16 | 2 | nan | 55.88 | nan% |
| resnet50v1_fp16 | 1 | 917.31 | 925.46 | -0.88% |
| resnet50v1_int8 | 1 | 880.77 | 882.65 | -0.21% |
| bert_base_cased_fp16 | 64 | nan | 1,088.52 | nan% |
| bert_large_uncased_fp16 | 32 | nan | 343.42 | nan% |
| bert_large_fp16 | 1 | nan | 199.40 | nan% |
| distilgpt2_fp16 | 16 | nan | 2,066.91 | nan% |
| yolov5s | 1 | 546.04 | 545.30 | 0.14% |
| tinyllama | 1 | 43.98 | 43.95 | 0.06% |
| vicuna-fastchat | 1 | 42.76 | 42.79 | -0.09% |
| whisper-tiny-encoder | 1 | nan | 404.95 | nan% |
| whisper-tiny-decoder | 1 | nan | 402.05 | nan% |
| llama2_7b | 1 | 19.16 | 19.16 | 0.00% |
| qwen1.5-7b | 1 | 23.47 | 23.46 | 0.05% |
| phi3-3.8b | 1 | 26.65 | 26.69 | -0.17% |
| llama3-8b | 1 | 21.71 | 21.73 | -0.10% |
| whisper-large-encoder | 1 | nan | 10.11 | nan% |
| whisper-large-decoder | 1 | nan | 101.33 | nan% |
| mistral-7b | 1 | 23.71 | 23.70 | 0.02% |
| FLUX.1-schnell | 1 | nan | 728.23 | nan% |

This build is not recommended to merge 🔴

@causten (Collaborator) commented Feb 24, 2026


❌ bert-mrpc-onnx: ERROR - check error output: terminate called recursively
terminate called recursively
terminate called recursively


     ✅ bert-mrpc-tf: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance

     ✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance

     ✅ agentmodel: PASSED: MIGraphX meets tolerance

❌ unet: ERROR - check error output: terminate called after throwing an instance of 'std::runtime_error'
terminate called recursively
terminate called recursively
terminate called recursively
terminate called recursively
terminate called recursively


     ✅ resnet50v1: PASSED: MIGraphX meets tolerance

❌ bert_base_cased_fp16: ERROR - check error output: terminate called after throwing an instance of 'std::runtime_error'
terminate called recursively
terminate called recursively
terminate called recursively
terminate called recursively


❌ bert_large_uncased_fp16: ERROR - check error output: terminate called after throwing an instance of 'std::runtime_error'
terminate called recursively
terminate called recursively
terminate called recursively
terminate called recursively
terminate called recursively
terminate called recursively


❌ bert_large: ERROR - check error output: terminate called recursively
terminate called recursively
terminate called recursively


     ✅ yolov5s: PASSED: MIGraphX meets tolerance

     ✅ tinyllama: PASSED: MIGraphX meets tolerance

     ✅ vicuna-fastchat: PASSED: MIGraphX meets tolerance

❌ whisper-tiny-encoder: ERROR - check error output: terminate called recursively
terminate called after throwing an instance of 'std::runtime_error'
terminate called recursively


❌ whisper-tiny-decoder: ERROR - check error output: terminate called recursively
terminate called recursively
terminate called after throwing an instance of 'std::runtime_error'
terminate called recursively


❌ distilgpt2_fp16: ERROR - check error output: terminate called recursively
terminate called after throwing an instance of 'std::runtime_error'
terminate called recursively
terminate called recursively
terminate called recursively


     ✅ llama2_7b: PASSED: MIGraphX meets tolerance

     ✅ qwen1.5-7b: PASSED: MIGraphX meets tolerance

     ✅ phi3-3.8b: PASSED: MIGraphX meets tolerance

     ✅ llama3-8b: PASSED: MIGraphX meets tolerance

❌ whisper-large-decoder: ERROR - check error output: terminate called recursively
terminate called recursively


     ✅ mistral-7b: PASSED: MIGraphX meets tolerance

❌ FLUX.1-schnell: ERROR - check error output: terminate called recursively
terminate called recursively

@kahmed10 (Collaborator, Author)

Will work on #4631 instead

@kahmed10 kahmed10 closed this Feb 25, 2026


3 participants