VCTR
Loading...
Searching...
No Matches
AVXRegister.h
1/*
2 ==============================================================================
3 DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4
5 Copyright 2022- by sonible GmbH.
6
7 This file is part of VCTR - Versatile Container Templates Reconceptualized.
8
9 VCTR is free software: you can redistribute it and/or modify
10 it under the terms of the GNU Lesser General Public License version 3
11 only, as published by the Free Software Foundation.
12
13 VCTR is distributed in the hope that it will be useful,
14 but WITHOUT ANY WARRANTY; without even the implied warranty of
15 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 GNU Lesser General Public License version 3 for more details.
17
18 You should have received a copy of the GNU Lesser General Public License
19 version 3 along with VCTR. If not, see <https://www.gnu.org/licenses/>.
20 ==============================================================================
21*/
22
23namespace vctr
24{
25
/** Fallback AVXRegister for element types that have no dedicated specialisation.

    It deliberately implements only a no-op broadcast so that generic code
    referencing AVXRegister<T> still compiles for unsupported element types;
    numElements and the real SIMD operations exist only on the specialisations
    below. (The scrape of this header had dropped the `struct AVXRegister`
    line between `template <class T>` and the opening brace — restored here.)
 */
template <class T>
struct AVXRegister
{
    static constexpr AVXRegister broadcast (const T&) { return {}; }
};
31
32#if VCTR_X64
33template <>
34struct AVXRegister<float>
35{
36 static constexpr size_t numElements = 8;
37
38 using NativeType = __m256;
39 __m256 value;
40
41 //==============================================================================
42 // Loading
43 // clang-format off
44 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister loadUnaligned (const float* d) { return { _mm256_loadu_ps (d) }; }
45 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister loadAligned (const float* d) { return { _mm256_load_ps (d) }; }
46 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister broadcast (float x) { return { _mm256_broadcast_ss (&x) }; }
47 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister fromSSE (SSERegister<float> a, SSERegister<float> b) { return { _mm256_set_m128 (a.value, b.value) }; }
48
49 //==============================================================================
50 // Storing
51 VCTR_FORCEDINLINE VCTR_TARGET ("avx") void storeUnaligned (float* d) const { _mm256_storeu_ps (d, value); }
52 VCTR_FORCEDINLINE VCTR_TARGET ("avx") void storeAligned (float* d) const { _mm256_store_ps (d, value); }
53
54 //==============================================================================
55 // Generate Compare Masks
56 template <CompareOp Op>
57 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister compare (AVXRegister a, AVXRegister b) { return { _mm256_cmp_ps (a.value, b.value, int (Op)) }; }
58
59 //==============================================================================
60 // Bit Operations
62 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister bitwiseAndNot (AVXRegister a, AVXRegister b) { return { _mm256_andnot_ps (b.value, a.value) }; }
63 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister bitwiseAnd (AVXRegister a, AVXRegister b) { return { _mm256_and_ps (a.value, b.value) }; }
64 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister bitwiseBlend (AVXRegister a, AVXRegister b, AVXRegister mask) { return { _mm256_blendv_ps (a.value, b.value, mask.value) }; }
65
66 //==============================================================================
67 // Math
68 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister floor (AVXRegister x) { return { _mm256_floor_ps (x.value) }; }
69 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister ceil (AVXRegister x) { return { _mm256_ceil_ps (x.value) }; }
70 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister mul (AVXRegister a, AVXRegister b) { return { _mm256_mul_ps (a.value, b.value) }; }
71 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister add (AVXRegister a, AVXRegister b) { return { _mm256_add_ps (a.value, b.value) }; }
72 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister sub (AVXRegister a, AVXRegister b) { return { _mm256_sub_ps (a.value, b.value) }; }
73 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister div (AVXRegister a, AVXRegister b) { return { _mm256_div_ps (a.value, b.value) }; }
74 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister max (AVXRegister a, AVXRegister b) { return { _mm256_max_ps (a.value, b.value) }; }
75 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister min (AVXRegister a, AVXRegister b) { return { _mm256_min_ps (a.value, b.value) }; }
76 VCTR_FORCEDINLINE VCTR_TARGET ("fma") static AVXRegister fma (AVXRegister a, AVXRegister b, AVXRegister c) { return { _mm256_fmadd_ps (a.value, b.value, c.value) }; }
77 VCTR_FORCEDINLINE VCTR_TARGET ("fma") static AVXRegister fms (AVXRegister a, AVXRegister b, AVXRegister c) { return { _mm256_fnmadd_ps (a.value, b.value, c.value) }; }
78
79 //==============================================================================
80 // Type conversion
81 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister<int32_t> convertToInt (AVXRegister x);
82 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister<int32_t> reinterpretAsInt (AVXRegister x);
83 // clang-format on
84};
85
86template <>
87struct AVXRegister<double>
88{
89 static constexpr size_t numElements = 4;
90
91 using NativeType = __m256d;
92 __m256d value;
93
94 //==============================================================================
95 // Loading
96 // clang-format off
97 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister loadUnaligned (const double* d) { return { _mm256_loadu_pd (d) }; }
98 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister loadAligned (const double* d) { return { _mm256_load_pd (d) }; }
99 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister broadcast (double x) { return { _mm256_broadcast_sd (&x) }; }
100 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister fromSSE (SSERegister<double> a, SSERegister<double> b) { return { _mm256_set_m128d (a.value, b.value) }; }
101
102 //==============================================================================
103 // Storing
104 VCTR_FORCEDINLINE VCTR_TARGET ("avx") void storeUnaligned (double* d) const { _mm256_storeu_pd (d, value); }
105 VCTR_FORCEDINLINE VCTR_TARGET ("avx") void storeAligned (double* d) const { _mm256_store_pd (d, value); }
106
107 //==============================================================================
108 // Generate Compare Masks
109 template <CompareOp Op>
110 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister compare (AVXRegister a, AVXRegister b) { return { _mm256_cmp_pd (a.value, b.value, int (Op)) }; }
111
112 //==============================================================================
113 // Bit Operations
115 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister bitwiseAndNot (AVXRegister a, AVXRegister b) { return { _mm256_andnot_pd (b.value, a.value) }; }
116 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister bitwiseAnd (AVXRegister a, AVXRegister b) { return { _mm256_and_pd (a.value, b.value) }; }
117 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister bitwiseBlend (AVXRegister a, AVXRegister b, AVXRegister mask) { return { _mm256_blendv_pd (a.value, b.value, mask.value) }; }
118
119 //==============================================================================
120 // Math
121 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister floor (AVXRegister x) { return { _mm256_floor_pd (x.value) }; }
122 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister ceil (AVXRegister x) { return { _mm256_ceil_pd (x.value) }; }
123 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister mul (AVXRegister a, AVXRegister b) { return { _mm256_mul_pd (a.value, b.value) }; }
124 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister add (AVXRegister a, AVXRegister b) { return { _mm256_add_pd (a.value, b.value) }; }
125 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister sub (AVXRegister a, AVXRegister b) { return { _mm256_sub_pd (a.value, b.value) }; }
126 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister div (AVXRegister a, AVXRegister b) { return { _mm256_div_pd (a.value, b.value) }; }
127 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister max (AVXRegister a, AVXRegister b) { return { _mm256_max_pd (a.value, b.value) }; }
128 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister min (AVXRegister a, AVXRegister b) { return { _mm256_min_pd (a.value, b.value) }; }
129 VCTR_FORCEDINLINE VCTR_TARGET ("fma") static AVXRegister fma (AVXRegister a, AVXRegister b, AVXRegister c) { return { _mm256_fmadd_pd (a.value, b.value, c.value) }; }
130 VCTR_FORCEDINLINE VCTR_TARGET ("fma") static AVXRegister fms (AVXRegister a, AVXRegister b, AVXRegister c) { return { _mm256_fnmadd_pd (a.value, b.value, c.value) }; }
131
132 //==============================================================================
133 // Type conversion
134 VCTR_FORCEDINLINE VCTR_TARGET ("avx512vl") VCTR_TARGET ("avx512dq") static AVXRegister<int64_t> convertToInt (AVXRegister x);
135 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister<int64_t> reinterpretAsInt (AVXRegister x);
136 // clang-format on
137};
138
139template <>
140struct AVXRegister<int32_t>
141{
142 static constexpr size_t numElements = 8;
143
144 using NativeType = __m256i;
145 __m256i value;
146
147 //==============================================================================
148 // Loading
149 // clang-format off
150 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister loadUnaligned (const int32_t* d) { return { _mm256_loadu_si256 (reinterpret_cast<const __m256i*> (d)) }; }
151 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister loadAligned (const int32_t* d) { return { _mm256_load_si256 (reinterpret_cast<const __m256i*> (d)) }; }
152 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister broadcast (int32_t x) { return { _mm256_set1_epi32 (x) }; }
153 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister fromSSE (SSERegister<int32_t> a, SSERegister<int32_t> b) { return { _mm256_set_m128i (a.value, b.value) }; }
154
155 //==============================================================================
156 // Storing
157 VCTR_FORCEDINLINE VCTR_TARGET ("avx") void storeUnaligned (int32_t* d) const { _mm256_storeu_si256 (reinterpret_cast<__m256i*> (d), value); }
158 VCTR_FORCEDINLINE VCTR_TARGET ("avx") void storeAligned (int32_t* d) const { _mm256_store_si256 (reinterpret_cast<__m256i*> (d), value); }
159
160 //==============================================================================
161 // Bit Operations
162 VCTR_FORCEDINLINE VCTR_TARGET ("avx2") static AVXRegister bitwiseAnd (AVXRegister a, AVXRegister b) { return { _mm256_and_si256 (a.value, b.value) }; }
163 VCTR_FORCEDINLINE VCTR_TARGET ("avx2") static AVXRegister bitwiseOr (AVXRegister a, AVXRegister b) { return { _mm256_or_si256 (a.value, b.value) }; }
164 // These are non AVX2 variants that might be used in functions that are not targeted AVX2 at the expense of slightly worse performance
165 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister bitwiseAndLegacy (AVXRegister a, AVXRegister b) { return { _mm256_castps_si256 (_mm256_and_ps (_mm256_castsi256_ps (a.value), _mm256_castsi256_ps (b.value))) }; }
166 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister bitwiseOrLegacy (AVXRegister a, AVXRegister b) { return { _mm256_castps_si256 (_mm256_or_ps (_mm256_castsi256_ps (a.value), _mm256_castsi256_ps (b.value))) }; }
167
168
169 //==============================================================================
170 // Math
171 VCTR_FORCEDINLINE VCTR_TARGET ("avx2") static AVXRegister abs (AVXRegister x) { return { _mm256_abs_epi32 (x.value) }; }
172 VCTR_FORCEDINLINE VCTR_TARGET ("avx2") static AVXRegister add (AVXRegister a, AVXRegister b) { return { _mm256_add_epi32 (a.value, b.value) }; }
173 VCTR_FORCEDINLINE VCTR_TARGET ("avx2") static AVXRegister sub (AVXRegister a, AVXRegister b) { return { _mm256_sub_epi32 (a.value, b.value) }; }
174 VCTR_FORCEDINLINE VCTR_TARGET ("avx2") static AVXRegister max (AVXRegister a, AVXRegister b) { return { _mm256_max_epi32 (a.value, b.value) }; }
175 VCTR_FORCEDINLINE VCTR_TARGET ("avx2") static AVXRegister min (AVXRegister a, AVXRegister b) { return { _mm256_min_epi32 (a.value, b.value) }; }
176
177 //==============================================================================
178 // Type conversion
179 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister<float> convertToFp (AVXRegister x) { return { _mm256_cvtepi32_ps (x.value) }; }
180 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister<float> reinterpretAsFp (AVXRegister x) { return { _mm256_castsi256_ps (x.value) }; }
181 // clang-format on
182};
183
184template <>
185struct AVXRegister<uint32_t>
186{
187 static constexpr size_t numElements = 8;
188
189 using NativeType = __m256i;
190 __m256i value;
191
192 //==============================================================================
193 // Loading
194 // clang-format off
195 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister loadUnaligned (const uint32_t* d) { return { _mm256_loadu_si256 (reinterpret_cast<const __m256i*> (d)) }; }
196 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister loadAligned (const uint32_t* d) { return { _mm256_load_si256 (reinterpret_cast<const __m256i*> (d)) }; }
197 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister broadcast (uint32_t x) { return { _mm256_set1_epi32 ((int32_t) x) }; }
198 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister fromSSE (SSERegister<uint32_t> a, SSERegister<uint32_t> b) { return { _mm256_set_m128i (a.value, b.value) }; }
199
200 //==============================================================================
201 // Storing
202 VCTR_FORCEDINLINE VCTR_TARGET ("avx") void storeUnaligned (uint32_t* d) const { _mm256_storeu_si256 (reinterpret_cast<__m256i*> (d), value); }
203 VCTR_FORCEDINLINE VCTR_TARGET ("avx") void storeAligned (uint32_t* d) const { _mm256_store_si256 (reinterpret_cast<__m256i*> (d), value); }
204
205 //==============================================================================
206 // Bit Operations
207
208 //==============================================================================
209 // Math
210 VCTR_FORCEDINLINE VCTR_TARGET ("avx2") static AVXRegister add (AVXRegister a, AVXRegister b) { return { _mm256_add_epi32 (a.value, b.value) }; }
211 VCTR_FORCEDINLINE VCTR_TARGET ("avx2") static AVXRegister sub (AVXRegister a, AVXRegister b) { return { _mm256_sub_epi32 (a.value, b.value) }; }
212 VCTR_FORCEDINLINE VCTR_TARGET ("avx2") static AVXRegister max (AVXRegister a, AVXRegister b) { return { _mm256_max_epu32 (a.value, b.value) }; }
213 VCTR_FORCEDINLINE VCTR_TARGET ("avx2") static AVXRegister min (AVXRegister a, AVXRegister b) { return { _mm256_min_epu32 (a.value, b.value) }; }
214 // clang-format on
215};
216
217template <>
218struct AVXRegister<int64_t>
219{
220 static constexpr size_t numElements = 4;
221
222 using NativeType = __m256i;
223 __m256i value;
224
225 //==============================================================================
226 // Loading
227 // clang-format off
228 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister loadUnaligned (const int64_t* d) { return { _mm256_loadu_si256 (reinterpret_cast<const __m256i*> (d)) }; }
229 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister loadAligned (const int64_t* d) { return { _mm256_load_si256 (reinterpret_cast<const __m256i*> (d)) }; }
230 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister broadcast (int64_t x) { return { _mm256_set1_epi64x (x) }; }
231 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister fromSSE (SSERegister<int64_t> a, SSERegister<int64_t> b) { return { _mm256_set_m128i (a.value, b.value) }; }
232
233 //==============================================================================
234 // Storing
235 VCTR_FORCEDINLINE VCTR_TARGET ("avx") void storeUnaligned (int64_t* d) const { _mm256_storeu_si256 (reinterpret_cast<__m256i*> (d), value); }
236 VCTR_FORCEDINLINE VCTR_TARGET ("avx") void storeAligned (int64_t* d) const { _mm256_store_si256 (reinterpret_cast<__m256i*> (d), value); }
237
238 //==============================================================================
239 // Bit Operations
240 VCTR_FORCEDINLINE VCTR_TARGET ("avx2") static AVXRegister bitwiseAnd (AVXRegister a, AVXRegister b) { return { _mm256_and_si256 (a.value, b.value) }; }
241 VCTR_FORCEDINLINE VCTR_TARGET ("avx2") static AVXRegister bitwiseOr (AVXRegister a, AVXRegister b) { return { _mm256_or_si256 (a.value, b.value) }; }
242 // These are non AVX2 variants that might be used in functions that are not targeted AVX2 at the expense of slightly worse performance
243 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister bitwiseAndLegacy (AVXRegister a, AVXRegister b) { return { _mm256_castpd_si256 (_mm256_and_pd (_mm256_castsi256_pd (a.value), _mm256_castsi256_pd (b.value))) }; }
244 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister bitwiseOrLegacy (AVXRegister a, AVXRegister b) { return { _mm256_castpd_si256 (_mm256_or_pd (_mm256_castsi256_pd (a.value), _mm256_castsi256_pd (b.value))) }; }
245
246 //==============================================================================
247 // Math
248 VCTR_FORCEDINLINE VCTR_TARGET ("avx2") static AVXRegister add (AVXRegister a, AVXRegister b) { return { _mm256_add_epi64 (a.value, b.value) }; }
249 VCTR_FORCEDINLINE VCTR_TARGET ("avx2") static AVXRegister sub (AVXRegister a, AVXRegister b) { return { _mm256_sub_epi64 (a.value, b.value) }; }
250
251 //==============================================================================
252 // Type conversion
253 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister<double> convertToFp (AVXRegister x) { return { _mm256_cvtepi64_pd (x.value) }; }
254 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister<double> reinterpretAsFp (AVXRegister x) { return { _mm256_castsi256_pd (x.value) }; }
255 // clang-format on
256};
257
258template <>
259struct AVXRegister<uint64_t>
260{
261 static constexpr size_t numElements = 4;
262
263 using NativeType = __m256i;
264 __m256i value;
265
266 //==============================================================================
267 // Loading
268 // clang-format off
269 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister loadUnaligned (const uint64_t* d) { return { _mm256_loadu_si256 (reinterpret_cast<const __m256i*> (d)) }; }
270 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister loadAligned (const uint64_t* d) { return { _mm256_load_si256 (reinterpret_cast<const __m256i*> (d)) }; }
271 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister broadcast (uint64_t x) { return { _mm256_set1_epi64x ((int64_t) x) }; }
272 VCTR_FORCEDINLINE VCTR_TARGET ("avx") static AVXRegister fromSSE (SSERegister<uint64_t> a, SSERegister<uint64_t> b) { return { _mm256_set_m128i (a.value, b.value) }; }
273
274 //==============================================================================
275 // Storing
276 VCTR_FORCEDINLINE VCTR_TARGET ("avx") void storeUnaligned (uint64_t* d) const { _mm256_storeu_si256 (reinterpret_cast<__m256i*> (d), value); }
277 VCTR_FORCEDINLINE VCTR_TARGET ("avx") void storeAligned (uint64_t* d) const { _mm256_store_si256 (reinterpret_cast<__m256i*> (d), value); }
278
279 //==============================================================================
280 // Bit Operations
281
282 //==============================================================================
283 // Math
284 VCTR_FORCEDINLINE VCTR_TARGET ("avx2") static AVXRegister add (AVXRegister a, AVXRegister b) { return { _mm256_add_epi64 (a.value, b.value) }; }
285 VCTR_FORCEDINLINE VCTR_TARGET ("avx2") static AVXRegister sub (AVXRegister a, AVXRegister b) { return { _mm256_sub_epi64 (a.value, b.value) }; }
286 // clang-format on
287};
288
289inline AVXRegister<int32_t> AVXRegister<float>::convertToInt (AVXRegister x) { return { _mm256_cvtps_epi32 (x.value) }; }
290inline AVXRegister<int32_t> AVXRegister<float>::reinterpretAsInt (AVXRegister x) { return { _mm256_castps_si256 (x.value) }; }
291inline AVXRegister<int64_t> AVXRegister<double>::convertToInt (AVXRegister x) { return { _mm256_cvtpd_epi64 (x.value) }; }
292inline AVXRegister<int64_t> AVXRegister<double>::reinterpretAsInt (AVXRegister x) { return { _mm256_castpd_si256 (x.value) }; }
293#endif
294
295} // namespace vctr
constexpr ExpressionChainBuilder< expressions::Max > max
Computes the maximum value of the source values.
Definition: Max.h:198
constexpr ExpressionChainBuilder< expressions::Abs > abs
Computes the absolute value of the source values.
Definition: Abs.h:135
constexpr ExpressionChainBuilder< expressions::Min > min
Computes the minimum value of the source values.
Definition: Min.h:198
The main namespace of the VCTR project.
Definition: Array.h:24
Definition: AVXRegister.h:28