AIMIGRAPHX-578 Reintroduce blaze for better ref gemm performance #4619
Conversation
src/include/migraphx/gemm.hpp (outdated)

```cpp
#endif

template <class T, class U, class F>
void gemm(tensor_view<T> cmat, tensor_view<U> amat, tensor_view<U> bmat, F alpha, F beta)
```
This should all be moved to the .cpp file. The gemm function should be changed to take arguments as inputs, and the alpha and beta parameters should be removed since they are unused.
src/include/migraphx/gemm.hpp (outdated)

```cpp
std::cerr << "[blaze gemm] " << m_size << "x" << k_size << "x" << n_size
          << " batches=" << num_batches << " (native)" << std::endl;

for(std::size_t batch = 0; batch < num_batches; batch++)
```
This should probably be a par_for.
src/include/migraphx/gemm.hpp (outdated)

```cpp
{
    offset += (remaining % s.lens()[d - 1]) * s.strides()[d - 1];
    remaining /= s.lens()[d - 1];
}
```
Use the index method from the shape class to compute the offset.
src/include/migraphx/gemm.hpp (outdated)

```cpp
}

template <class T, class Func>
void with_blaze_mat(T* ptr, std::size_t rows, std::size_t cols, mat_order order, Func&& func)
```
This should take a tensor_view: `void with_blaze_mat(tensor_view<T> view, mat_order order, Func&& func)`. The rows and cols can come from the shape.
A class can be created to do the slicing of the tensor:

```cpp
struct batch_slicer
{
    batch_slicer(const shape& mat_shape)
    {
        // Compute the inner_shape and outer_shape
    }
    template <class T>
    tensor_view<T> extract(tensor_view<T> view, std::size_t batch)
    {
        auto offset = outer_shape.index(batch);
        return make_view(inner_shape, view.data() + offset);
    }
    shape inner_shape;
    shape outer_shape;
};
```
src/include/migraphx/gemm.hpp (outdated)

```cpp
}

template <class T, class Func>
void with_blaze_mat(T* ptr, std::size_t rows, std::size_t cols, mat_order order, Func&& func)
```
Actually, it would be simpler to just use the code from here for `make_mat` and `visit_mat`:
AMDMIGraphX/src/targets/ref/gemm.cpp, line 51 at fefbe99
Codecov Report
❌ Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff            @@
##           develop    #4619    +/-   ##
==========================================
- Coverage    92.22%   92.16%    -0.06%
==========================================
  Files          567      567
  Lines        27910    28002      +92
==========================================
+ Hits         25739    25806      +67
- Misses        2171     2196      +25
```
```cpp
   is_mat_layout_supported(bmat.get_shape()) and
   is_mat_layout_supported(cmat.get_shape()))
{
    for(std::size_t batch = 0; batch < num_batches; batch++)
```
This should use par_for. Also move this to another function so it can be reused in the copy case.
```cpp
copy_2d(a_buf.data(), k_size, std::size_t{1},
        amat.data() + a_batch.index(batch), a_row_stride, a_col_stride,
        m_size, k_size);
copy_2d(b_buf.data(), n_size, std::size_t{1},
```
The copy should happen outside of the batch loop. Just use std::copy:

```cpp
std::vector<float> a_buf(amat.elements());
std::copy(amat.begin(), amat.end(), a_buf.begin());
auto amat_flat = make_view(amat.get_shape().as_standard().with_type(shape::float_type), a_buf.begin());
```

And then call into the function that handles is_mat_layout_supported.
This build is not recommended to merge 🔴

The following models failed the accuracy check (`terminate called after throwing an instance of 'std::runtime_error'`, followed by repeated `terminate called recursively`):

- ❌ bert-mrpc-onnx: ERROR - check error output
- ❌ unet: ERROR - check error output
- ❌ bert_base_cased_fp16: ERROR - check error output
- ❌ bert_large_uncased_fp16: ERROR - check error output
- ❌ bert_large: ERROR - check error output
- ❌ whisper-tiny-encoder: ERROR - check error output
- ❌ whisper-tiny-decoder: ERROR - check error output
- ❌ distilgpt2_fp16: ERROR - check error output
- ❌ whisper-large-decoder: ERROR - check error output
- ❌ FLUX.1-schnell: ERROR - check error output
Will work on #4631 instead
Motivation
Our naive gemm implementation is much slower than using an optimized CPU library.
Technical Details
Reintroduces blaze; if a data type is not supported by blaze, the inputs are converted to a higher precision first.
Changelog Category
Add a `CHANGELOG.md` entry for any option other than `Not Applicable`.