diff --git a/src/Simd/SimdAvx2SynetSoftmax16b.cpp b/src/Simd/SimdAvx2SynetSoftmax16b.cpp index fd98154850..622022c0db 100644 --- a/src/Simd/SimdAvx2SynetSoftmax16b.cpp +++ b/src/Simd/SimdAvx2SynetSoftmax16b.cpp @@ -77,208 +77,286 @@ namespace Simd //-------------------------------------------------------------------------------------------------- - //SIMD_INLINE void SynetSoftmax16b31(const Avx2::Exp& exp, __m256 buf[3]) - //{ - // __m256 max = _mm256_max_ps(buf[0], _mm256_max_ps(buf[1], buf[2])); - // buf[0] = exp.Exponent(_mm256_sub_ps(buf[0], max)); - // buf[1] = exp.Exponent(_mm256_sub_ps(buf[1], max)); - // buf[2] = exp.Exponent(_mm256_sub_ps(buf[2], max)); - // __m256 sum = _mm256_add_ps(buf[0], _mm256_add_ps(buf[1], buf[2])); - // buf[0] = _mm256_div_ps(buf[0], sum); - // buf[1] = _mm256_div_ps(buf[1], sum); - // buf[2] = _mm256_div_ps(buf[2], sum); - //} + SIMD_INLINE void SynetSoftmax16b31Load(const uint16_t* src, __m256 buf[3]) + { + static const __m256i SFL00 = SIMD_MM256_SETR_EPI8( + -1, -1, 0x0, 0x1, -1, -1, 0x6, 0x7, -1, -1, 0xC, 0xD, -1, -1, -1, -1, + -1, -1, 0x8, 0x9, -1, -1, 0xE, 0xF, -1, -1, -1, -1, -1, -1, -1, -1); + static const __m256i SFL01 = SIMD_MM256_SETR_EPI8( + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0x2, 0x3, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0x4, 0x5, -1, -1, 0xA, 0xB); + static const __m256i SFL10 = SIMD_MM256_SETR_EPI8( + -1, -1, 0x2, 0x3, -1, -1, 0x8, 0x9, -1, -1, 0xE, 0xF, -1, -1, -1, -1, + -1, -1, 0xA, 0xB, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1); + static const __m256i SFL11 = SIMD_MM256_SETR_EPI8( + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0x4, 0x5, + -1, -1, -1, -1, -1, -1, 0x0, 0x1, -1, -1, 0x6, 0x7, -1, -1, 0xC, 0xD); + static const __m256i SFL20 = SIMD_MM256_SETR_EPI8( + -1, -1, 0x4, 0x5, -1, -1, 0xA, 0xB, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, 0xC, 0xD, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1); + static const __m256i SFL21 = SIMD_MM256_SETR_EPI8( + -1, -1, -1, -1, -1, 
-1, -1, -1, -1, -1, 0x0, 0x1, -1, -1, 0x6, 0x7, + -1, -1, -1, -1, -1, -1, 0x2, 0x3, -1, -1, 0x8, 0x9, -1, -1, 0xE, 0xF); + __m256i s01 = _mm256_loadu_si256((__m256i*)(src + 0)); + __m256i s12 = _mm256_loadu_si256((__m256i*)(src + 8)); + buf[0] = _mm256_castsi256_ps(_mm256_or_si256(_mm256_shuffle_epi8(s01, SFL00), _mm256_shuffle_epi8(s12, SFL01))); + buf[1] = _mm256_castsi256_ps(_mm256_or_si256(_mm256_shuffle_epi8(s01, SFL10), _mm256_shuffle_epi8(s12, SFL11))); + buf[2] = _mm256_castsi256_ps(_mm256_or_si256(_mm256_shuffle_epi8(s01, SFL20), _mm256_shuffle_epi8(s12, SFL21))); + } - //void SynetSoftmax16b31(const uint16_t* src, size_t outer, uint16_t* dst) - //{ - // Avx2::Exp exp; - // __m256 buf[3]; - // size_t aligned = Simd::AlignLo(outer, F); - // for (size_t o = 0; o < aligned; o += F) - // { - // buf[0] = Avx2::Gather<3>(src + 0); - // buf[1] = Avx2::Gather<3>(src + 1); - // buf[2] = Avx2::Gather<3>(src + 2); - // SynetSoftmax16b31(exp, buf); - // Scater<3>(dst + 0, buf[0]); - // Scater<3>(dst + 1, buf[1]); - // Scater<3>(dst + 2, buf[2]); - // src += 3 * F; - // dst += 3 * F; - // } - // if (aligned < outer) - // { - // size_t tail = outer - aligned; - // buf[0] = Gather<3>(src + 0, tail); - // buf[1] = Gather<3>(src + 1, tail); - // buf[2] = Gather<3>(src + 2, tail); - // SynetSoftmax16b31(exp, buf); - // Scater<3>(dst + 0, buf[0], tail); - // Scater<3>(dst + 1, buf[1], tail); - // Scater<3>(dst + 2, buf[2], tail); - // } - //} + SIMD_INLINE void SynetSoftmax16b31(const Avx2::Exp& exp, __m256 buf[3]) + { + __m256 max = _mm256_max_ps(buf[0], _mm256_max_ps(buf[1], buf[2])); + buf[0] = exp.Exponent(_mm256_sub_ps(buf[0], max)); + buf[1] = exp.Exponent(_mm256_sub_ps(buf[1], max)); + buf[2] = exp.Exponent(_mm256_sub_ps(buf[2], max)); + __m256 sum = _mm256_add_ps(buf[0], _mm256_add_ps(buf[1], buf[2])); + buf[0] = _mm256_div_ps(buf[0], sum); + buf[1] = _mm256_div_ps(buf[1], sum); + buf[2] = _mm256_div_ps(buf[2], sum); + } - //SIMD_INLINE void LoadTansp8x8(const 
uint16_t* src, size_t count, float* dst, __m256& max) - //{ - // __m256 a0, a1, a2, a3, a4, a5, a6, a7, b0, b1, b2, b3, b4, b5, b6, b7; + SIMD_INLINE void SynetSoftmax16b31Load(const uint16_t* src, size_t size, __m256 dst[3]) + { + SIMD_ALIGNED(32) uint16_t buf[A] = { 0 }; + for (size_t i = 0; i < size; i += 1) + { + buf[0 * 8 + i] = src[i * 3 + 0]; + buf[1 * 8 + i] = src[i * 3 + 1]; + buf[2 * 8 + i] = src[i * 3 + 2]; + } + __m128i b01 = _mm_loadu_si128((__m128i*)buf + 0); + dst[0] = BFloat16ToFloat32(b01); + dst[1] = BFloat16ToFloat32(_mm_loadu_si128((__m128i*)buf + 1)); + dst[2] = BFloat16ToFloat32(_mm_loadu_si128((__m128i*)buf + 2)); + } + + SIMD_INLINE void SynetSoftmax16b31Save(const __m256 src[3], uint16_t* dst) + { + __m256i s01 = Float32ToBFloat16Interlived(src[0], src[1]); + __m256i s2 = Float32ToBFloat16(src[2]); - // a0 = _mm256_loadu_ps(src + 0 * count); - // a1 = _mm256_loadu_ps(src + 1 * count); - // a2 = _mm256_loadu_ps(src + 2 * count); - // a3 = _mm256_loadu_ps(src + 3 * count); - // a4 = _mm256_loadu_ps(src + 4 * count); - // a5 = _mm256_loadu_ps(src + 5 * count); - // a6 = _mm256_loadu_ps(src + 6 * count); - // a7 = _mm256_loadu_ps(src + 7 * count); + static const __m256i SFL020 = SIMD_MM256_SETR_EPI8( + 0x0, 0x1, 0x2, 0x3, -1, -1, 0x4, 0x5, 0x6, 0x7, -1, -1, 0x8, 0x9, 0xA, 0xB, + 0x6, 0x7, -1, -1, 0x8, 0x9, 0xA, 0xB, -1, -1, 0xC, 0xD, 0xE, 0xF, -1, -1); + static const __m256i SFL021 = SIMD_MM256_SETR_EPI8( + -1, -1, -1, -1, 0x0, 0x1, -1, -1, -1, -1, 0x4, 0x5, -1, -1, -1, -1, + -1, -1, 0x4, 0x5, -1, -1, -1, -1, 0x8, 0x9, -1, -1, -1, -1, 0xC, 0xD); + __m256i d02 = _mm256_or_si256(_mm256_shuffle_epi8(s01, SFL020), _mm256_shuffle_epi8(s2, SFL021)); - // b0 = _mm256_unpacklo_ps(a0, a2); - // b1 = _mm256_unpacklo_ps(a1, a3); - // b2 = _mm256_unpackhi_ps(a0, a2); - // b3 = _mm256_unpackhi_ps(a1, a3); - // b4 = _mm256_unpacklo_ps(a4, a6); - // b5 = 
_mm256_unpacklo_ps(a5, a7); - // b6 = _mm256_unpackhi_ps(a4, a6); - // b7 = _mm256_unpackhi_ps(a5, a7); + static const __m256i SFL10 = SIMD_MM256_SETR_EPI8( + -1, -1, 0xC, 0xD, 0xE, 0xF, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, 0x0, 0x1, 0x2, 0x3, -1, -1, 0x4, 0x5); + static const __m256i SFL11 = SIMD_MM256_SETR_EPI8( + 0x8, 0x9, -1, -1, -1, -1, 0xC, 0xD, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0x0, 0x1, -1, -1); + __m256i d1 = _mm256_or_si256(_mm256_shuffle_epi8(s01, SFL10), _mm256_shuffle_epi8(s2, SFL11)); - // a0 = _mm256_unpacklo_ps(b0, b1); - // a1 = _mm256_unpackhi_ps(b0, b1); - // a2 = _mm256_unpacklo_ps(b2, b3); - // a3 = _mm256_unpackhi_ps(b2, b3); - // a4 = _mm256_unpacklo_ps(b4, b5); - // a5 = _mm256_unpackhi_ps(b4, b5); - // a6 = _mm256_unpacklo_ps(b6, b7); - // a7 = _mm256_unpackhi_ps(b6, b7); + _mm_storeu_si128((__m128i*)dst + 0, _mm256_extractf128_si256(d02, 0)); + _mm_storeu_si128((__m128i*)dst + 1, _mm_or_si128(_mm256_extractf128_si256(d1, 0), _mm256_extractf128_si256(d1, 1))); + _mm_storeu_si128((__m128i*)dst + 2, _mm256_extractf128_si256(d02, 1)); + } + + SIMD_INLINE void SynetSoftmax16b31Save(const __m256 src[3], size_t size, uint16_t* dst) + { + SIMD_ALIGNED(16) uint16_t buf[A]; + _mm256_storeu_si256((__m256i*)buf + 0, Float32ToBFloat16(src[0], src[1])); + _mm256_storeu_si256((__m256i*)buf + 1, Float32ToBFloat16(src[2], src[2])); + for (size_t i = 0; i < size; i += 1) + { + dst[i * 3 + 0] = buf[0 * 8 + i]; + dst[i * 3 + 1] = buf[1 * 8 + i]; + dst[i * 3 + 2] = buf[2 * 8 + i]; + } + } - // b0 = _mm256_permute2f128_ps(a0, a4, 0x20); - // b1 = _mm256_permute2f128_ps(a1, a5, 0x20); - // b2 = _mm256_permute2f128_ps(a2, a6, 0x20); - // b3 = _mm256_permute2f128_ps(a3, a7, 0x20); - // b4 = _mm256_permute2f128_ps(a0, a4, 0x31); - // b5 = _mm256_permute2f128_ps(a1, a5, 0x31); - // b6 = _mm256_permute2f128_ps(a2, a6, 0x31); - // b7 = _mm256_permute2f128_ps(a3, a7, 0x31); + void 
SynetSoftmax16b31(const uint16_t* src, size_t outer, uint16_t* dst) + { + Avx2::Exp exp; + __m256 buf[3]; + size_t aligned = Simd::AlignLo(outer, F); + for (size_t o = 0; o < aligned; o += F) + { + SynetSoftmax16b31Load(src, buf); + SynetSoftmax16b31(exp, buf); + SynetSoftmax16b31Save(buf, dst); + src += 3 * F; + dst += 3 * F; + } + if (aligned < outer) + { + size_t tail = outer - aligned; + SynetSoftmax16b31Load(src, tail, buf); + SynetSoftmax16b31(exp, buf); + SynetSoftmax16b31Save(buf, tail, dst); + } + } - // max = _mm256_max_ps(max, b0); - // max = _mm256_max_ps(max, b1); - // max = _mm256_max_ps(max, b2); - // max = _mm256_max_ps(max, b3); - // max = _mm256_max_ps(max, b4); - // max = _mm256_max_ps(max, b5); - // max = _mm256_max_ps(max, b6); - // max = _mm256_max_ps(max, b7); + //-------------------------------------------------------------------------------------------------- - // _mm256_storeu_ps(dst + 0 * F, b0); - // _mm256_storeu_ps(dst + 1 * F, b1); - // _mm256_storeu_ps(dst + 2 * F, b2); - // _mm256_storeu_ps(dst + 3 * F, b3); - // _mm256_storeu_ps(dst + 4 * F, b4); - // _mm256_storeu_ps(dst + 5 * F, b5); - // _mm256_storeu_ps(dst + 6 * F, b6); - // _mm256_storeu_ps(dst + 7 * F, b7); - //} + SIMD_INLINE void LoadTansp8x8(const uint16_t* src, size_t count, float* dst, __m256& max) + { + __m256 a0, a1, a2, a3, a4, a5, a6, a7, b0, b1, b2, b3, b4, b5, b6, b7; - //SIMD_INLINE void StoreTansp8x8(const float* src, __m256 k, uint16_t* dst, size_t count) - //{ - // __m256 a0, a1, a2, a3, a4, a5, a6, a7, b0, b1, b2, b3, b4, b5, b6, b7; + a0 = BFloat16ToFloat32(_mm_loadu_si128((__m128i*)(src + 0 * count))); + a1 = BFloat16ToFloat32(_mm_loadu_si128((__m128i*)(src + 1 * count))); + a2 = BFloat16ToFloat32(_mm_loadu_si128((__m128i*)(src + 2 * count))); + a3 = BFloat16ToFloat32(_mm_loadu_si128((__m128i*)(src + 3 * count))); + a4 = BFloat16ToFloat32(_mm_loadu_si128((__m128i*)(src + 4 * count))); + a5 = BFloat16ToFloat32(_mm_loadu_si128((__m128i*)(src + 5 * count))); + 
a6 = BFloat16ToFloat32(_mm_loadu_si128((__m128i*)(src + 6 * count))); + a7 = BFloat16ToFloat32(_mm_loadu_si128((__m128i*)(src + 7 * count))); - // a0 = _mm256_mul_ps(_mm256_loadu_ps(src + 0 * F), k); - // a1 = _mm256_mul_ps(_mm256_loadu_ps(src + 1 * F), k); - // a2 = _mm256_mul_ps(_mm256_loadu_ps(src + 2 * F), k); - // a3 = _mm256_mul_ps(_mm256_loadu_ps(src + 3 * F), k); - // a4 = _mm256_mul_ps(_mm256_loadu_ps(src + 4 * F), k); - // a5 = _mm256_mul_ps(_mm256_loadu_ps(src + 5 * F), k); - // a6 = _mm256_mul_ps(_mm256_loadu_ps(src + 6 * F), k); - // a7 = _mm256_mul_ps(_mm256_loadu_ps(src + 7 * F), k); + b0 = _mm256_unpacklo_ps(a0, a2); + b1 = _mm256_unpacklo_ps(a1, a3); + b2 = _mm256_unpackhi_ps(a0, a2); + b3 = _mm256_unpackhi_ps(a1, a3); + b4 = _mm256_unpacklo_ps(a4, a6); + b5 = _mm256_unpacklo_ps(a5, a7); + b6 = _mm256_unpackhi_ps(a4, a6); + b7 = _mm256_unpackhi_ps(a5, a7); - // b0 = _mm256_unpacklo_ps(a0, a2); - // b1 = _mm256_unpacklo_ps(a1, a3); - // b2 = _mm256_unpackhi_ps(a0, a2); - // b3 = _mm256_unpackhi_ps(a1, a3); - // b4 = _mm256_unpacklo_ps(a4, a6); - // b5 = _mm256_unpacklo_ps(a5, a7); - // b6 = _mm256_unpackhi_ps(a4, a6); - // b7 = _mm256_unpackhi_ps(a5, a7); + a0 = _mm256_unpacklo_ps(b0, b1); + a1 = _mm256_unpackhi_ps(b0, b1); + a2 = _mm256_unpacklo_ps(b2, b3); + a3 = _mm256_unpackhi_ps(b2, b3); + a4 = _mm256_unpacklo_ps(b4, b5); + a5 = _mm256_unpackhi_ps(b4, b5); + a6 = _mm256_unpacklo_ps(b6, b7); + a7 = _mm256_unpackhi_ps(b6, b7); - // a0 = _mm256_unpacklo_ps(b0, b1); - // a1 = _mm256_unpackhi_ps(b0, b1); - // a2 = _mm256_unpacklo_ps(b2, b3); - // a3 = _mm256_unpackhi_ps(b2, b3); - // a4 = _mm256_unpacklo_ps(b4, b5); - // a5 = _mm256_unpackhi_ps(b4, b5); - // a6 = _mm256_unpacklo_ps(b6, b7); - // a7 = _mm256_unpackhi_ps(b6, b7); + b0 = _mm256_permute2f128_ps(a0, a4, 0x20); + b1 = _mm256_permute2f128_ps(a1, a5, 0x20); + b2 = _mm256_permute2f128_ps(a2, a6, 0x20); + b3 = _mm256_permute2f128_ps(a3, a7, 0x20); + b4 = _mm256_permute2f128_ps(a0, a4, 0x31); 
+ b5 = _mm256_permute2f128_ps(a1, a5, 0x31); + b6 = _mm256_permute2f128_ps(a2, a6, 0x31); + b7 = _mm256_permute2f128_ps(a3, a7, 0x31); - // b0 = _mm256_permute2f128_ps(a0, a4, 0x20); - // b1 = _mm256_permute2f128_ps(a1, a5, 0x20); - // b2 = _mm256_permute2f128_ps(a2, a6, 0x20); - // b3 = _mm256_permute2f128_ps(a3, a7, 0x20); - // b4 = _mm256_permute2f128_ps(a0, a4, 0x31); - // b5 = _mm256_permute2f128_ps(a1, a5, 0x31); - // b6 = _mm256_permute2f128_ps(a2, a6, 0x31); - // b7 = _mm256_permute2f128_ps(a3, a7, 0x31); + max = _mm256_max_ps(max, b0); + max = _mm256_max_ps(max, b1); + max = _mm256_max_ps(max, b2); + max = _mm256_max_ps(max, b3); + max = _mm256_max_ps(max, b4); + max = _mm256_max_ps(max, b5); + max = _mm256_max_ps(max, b6); + max = _mm256_max_ps(max, b7); - // _mm256_storeu_ps(dst + 0 * count, b0); - // _mm256_storeu_ps(dst + 1 * count, b1); - // _mm256_storeu_ps(dst + 2 * count, b2); - // _mm256_storeu_ps(dst + 3 * count, b3); - // _mm256_storeu_ps(dst + 4 * count, b4); - // _mm256_storeu_ps(dst + 5 * count, b5); - // _mm256_storeu_ps(dst + 6 * count, b6); - // _mm256_storeu_ps(dst + 7 * count, b7); - //} + _mm256_storeu_ps(dst + 0 * F, b0); + _mm256_storeu_ps(dst + 1 * F, b1); + _mm256_storeu_ps(dst + 2 * F, b2); + _mm256_storeu_ps(dst + 3 * F, b3); + _mm256_storeu_ps(dst + 4 * F, b4); + _mm256_storeu_ps(dst + 5 * F, b5); + _mm256_storeu_ps(dst + 6 * F, b6); + _mm256_storeu_ps(dst + 7 * F, b7); + } - //void SynetSoftmax16bX1(const uint16_t* src, size_t outer, size_t count, uint16_t* dst) - //{ - // size_t o = 0, c = 0, outerF = AlignLo(outer, F), countF = AlignLo(count, F); - // Array32f buf(AlignHi(count, F) * F); - // Exp exp; - // for (; o < outerF; o += F) - // { - // __m256 _max = _mm256_set1_ps(-FLT_MAX); - // for (c = 0; c < countF; c += F) - // LoadTansp8x8(src + c, count, buf.data + c * F, _max); - // if (c < count) - // { - // c = count - F; - // LoadTansp8x8(src + c, count, buf.data + c * F, _max); - // } - // __m256 _sum = 
_mm256_setzero_ps(); - // for (size_t c = 0; c < count; ++c) - // { - // __m256 _exp = exp.Exponent(_mm256_sub_ps(_mm256_loadu_ps(buf.data + c * F), _max)); - // _sum = _mm256_add_ps(_sum, _exp); - // _mm256_storeu_ps(buf.data + c * F, _exp); - // } - // __m256 _k = _mm256_div_ps(_mm256_set1_ps(1.0f), _sum); - // for (c = 0; c < countF; c += F) - // StoreTansp8x8(buf.data + c * F, _k, dst + c, count); - // if (c < count) - // { - // c = count - F; - // StoreTansp8x8(buf.data + c * F, _k, dst + c, count); - // } - // src += count * F; - // dst += count * F; - // } - // for (; o < outer; ++o) - // { - // float max = src[0]; - // for (size_t c = 1; c < count; ++c) - // max = Simd::Max(max, src[c]); - // float sum = 0; - // for (size_t c = 0; c < count; ++c) - // { - // dst[c] = ::exp(src[c] - max); - // sum += dst[c]; - // } - // float k = 1.0f / sum; - // for (size_t c = 0; c < count; ++c) - // dst[c] *= k; - // src += count; - // dst += count; - // } - //} + SIMD_INLINE void StoreTansp8x8(const float* src, __m256 k, uint16_t* dst, size_t count) + { + __m256 a0, a1, a2, a3, a4, a5, a6, a7, b0, b1, b2, b3, b4, b5, b6, b7; + + a0 = _mm256_mul_ps(_mm256_loadu_ps(src + 0 * F), k); + a1 = _mm256_mul_ps(_mm256_loadu_ps(src + 1 * F), k); + a2 = _mm256_mul_ps(_mm256_loadu_ps(src + 2 * F), k); + a3 = _mm256_mul_ps(_mm256_loadu_ps(src + 3 * F), k); + a4 = _mm256_mul_ps(_mm256_loadu_ps(src + 4 * F), k); + a5 = _mm256_mul_ps(_mm256_loadu_ps(src + 5 * F), k); + a6 = _mm256_mul_ps(_mm256_loadu_ps(src + 6 * F), k); + a7 = _mm256_mul_ps(_mm256_loadu_ps(src + 7 * F), k); + + b0 = _mm256_unpacklo_ps(a0, a2); + b1 = _mm256_unpacklo_ps(a1, a3); + b2 = _mm256_unpackhi_ps(a0, a2); + b3 = _mm256_unpackhi_ps(a1, a3); + b4 = _mm256_unpacklo_ps(a4, a6); + b5 = _mm256_unpacklo_ps(a5, a7); + b6 = _mm256_unpackhi_ps(a4, a6); + b7 = _mm256_unpackhi_ps(a5, a7); + + a0 = _mm256_unpacklo_ps(b0, b1); + a1 = _mm256_unpackhi_ps(b0, b1); + a2 = _mm256_unpacklo_ps(b2, b3); + a3 = _mm256_unpackhi_ps(b2, 
b3); + a4 = _mm256_unpacklo_ps(b4, b5); + a5 = _mm256_unpackhi_ps(b4, b5); + a6 = _mm256_unpacklo_ps(b6, b7); + a7 = _mm256_unpackhi_ps(b6, b7); + + b0 = _mm256_permute2f128_ps(a0, a4, 0x20); + b1 = _mm256_permute2f128_ps(a1, a5, 0x20); + b2 = _mm256_permute2f128_ps(a2, a6, 0x20); + b3 = _mm256_permute2f128_ps(a3, a7, 0x20); + b4 = _mm256_permute2f128_ps(a0, a4, 0x31); + b5 = _mm256_permute2f128_ps(a1, a5, 0x31); + b6 = _mm256_permute2f128_ps(a2, a6, 0x31); + b7 = _mm256_permute2f128_ps(a3, a7, 0x31); + + _mm_storeu_si128((__m128i*)(dst + 0 * count), PackFloat32ToBFloat16(b0)); + _mm_storeu_si128((__m128i*)(dst + 1 * count), PackFloat32ToBFloat16(b1)); + _mm_storeu_si128((__m128i*)(dst + 2 * count), PackFloat32ToBFloat16(b2)); + _mm_storeu_si128((__m128i*)(dst + 3 * count), PackFloat32ToBFloat16(b3)); + _mm_storeu_si128((__m128i*)(dst + 4 * count), PackFloat32ToBFloat16(b4)); + _mm_storeu_si128((__m128i*)(dst + 5 * count), PackFloat32ToBFloat16(b5)); + _mm_storeu_si128((__m128i*)(dst + 6 * count), PackFloat32ToBFloat16(b6)); + _mm_storeu_si128((__m128i*)(dst + 7 * count), PackFloat32ToBFloat16(b7)); + } + + void SynetSoftmax16bX1(const uint16_t* src, size_t outer, size_t count, uint16_t* dst) + { + size_t o = 0, c = 0, outerF = AlignLo(outer, F), countF = AlignLo(count, F); + Array32f buf(AlignHi(count, F) * F); + Exp exp; + for (; o < outerF; o += F) + { + __m256 _max = _mm256_set1_ps(-FLT_MAX); + for (c = 0; c < countF; c += F) + LoadTansp8x8(src + c, count, buf.data + c * F, _max); + if (c < count) + { + c = count - F; + LoadTansp8x8(src + c, count, buf.data + c * F, _max); + } + __m256 _sum = _mm256_setzero_ps(); + for (size_t c = 0; c < count; ++c) + { + __m256 _exp = exp.Exponent(_mm256_sub_ps(_mm256_loadu_ps(buf.data + c * F), _max)); + _sum = _mm256_add_ps(_sum, _exp); + _mm256_storeu_ps(buf.data + c * F, _exp); + } + __m256 _k = _mm256_div_ps(_mm256_set1_ps(1.0f), _sum); + for (c = 0; c < countF; c += F) + StoreTansp8x8(buf.data + c * F, _k, dst + c, 
count); + if (c < count) + { + c = count - F; + StoreTansp8x8(buf.data + c * F, _k, dst + c, count); + } + src += count * F; + dst += count * F; + } + for (; o < outer; ++o) + { + for (size_t c = 0; c < count; ++c) + buf[c] = Base::BFloat16ToFloat32(src[c]); + + float max = buf[0]; + for (size_t c = 1; c < count; ++c) + max = Simd::Max(max, buf[c]); + float sum = 0; + for (size_t c = 0; c < count; ++c) + { + buf[c] = ::exp(buf[c] - max); + sum += buf[c]; + } + float k = 1.0f / sum; + for (size_t c = 0; c < count; ++c) + dst[c] = Base::Float32ToBFloat16(buf[c] * k); + src += count; + dst += count; + } + } void SynetSoftmax16b(const uint16_t* src, size_t outer, size_t count, size_t inner, uint16_t* dst) { @@ -286,14 +364,12 @@ namespace Simd { if (count == 2) SynetSoftmax16b21(src, outer, dst); - //else if (count == 3) - // SynetSoftmax16b31(src, outer, dst); - //else if(count < 8) - // Sse41::SynetSoftmax16bX1(src, outer, count, dst); - //else - // SynetSoftmax16bX1(src, outer, count, dst); + else if (count == 3) + SynetSoftmax16b31(src, outer, dst); + else if(count < 8) + Sse41::SynetSoftmax16bX1(src, outer, count, dst); else - Sse41::SynetSoftmax16b(src, outer, count, inner, dst); + SynetSoftmax16bX1(src, outer, count, dst); } else { diff --git a/src/Simd/SimdBFloat16.h b/src/Simd/SimdBFloat16.h index 787e203c81..de07bbbea6 100644 --- a/src/Simd/SimdBFloat16.h +++ b/src/Simd/SimdBFloat16.h @@ -176,11 +176,22 @@ namespace Simd return _mm256_castsi256_ps(UnpackU16(K_ZERO, value)); } + SIMD_INLINE __m256 BFloat16ToFloat32(__m128i value) + { + return _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(value), Base::Bf16::SHIFT)); + } + SIMD_INLINE __m256i Float32ToBFloat16(__m256 lo, __m256 hi) { return _mm256_permute4x64_epi64(_mm256_packus_epi32(Float32ToBFloat16(lo), Float32ToBFloat16(hi)), 0xD8); } + SIMD_INLINE __m128i PackFloat32ToBFloat16(__m256 f32) + { + __m256i b16 = Float32ToBFloat16(f32); + return 
_mm256_castsi256_si128(_mm256_permute4x64_epi64(_mm256_packus_epi32(b16, b16), 0xD8)); + } + SIMD_INLINE __m256 BFloat16ToFloat32Even(__m256i value) { return _mm256_castsi256_ps(_mm256_slli_epi32(value, Base::Bf16::SHIFT)); diff --git a/src/Test/TestSynetSoftmax.cpp b/src/Test/TestSynetSoftmax.cpp index d7b90737a1..70c4ac27d8 100644 --- a/src/Test/TestSynetSoftmax.cpp +++ b/src/Test/TestSynetSoftmax.cpp @@ -172,7 +172,7 @@ namespace Test Tensor32f src32f(shape, format); Tensor16u src16b(shape, format); - FillRandom(src32f.Data(), src32f.Size(), -10.0, 10.0); + FillRandom(src32f.Data(), src32f.Size(), -1.0, 1.0); SimdFloat32ToBFloat16(src32f.Data(), src32f.Size(), src16b.Data()); Tensor32f dst32f1(shape, format);