VCTR
Loading...
Searching...
No Matches
FastExp.h
1/*
2 ==============================================================================
3 DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4
5 Copyright 2022- by sonible GmbH.
6
7 This file is part of VCTR - Versatile Container Templates Reconceptualized.
8
9 VCTR is free software: you can redistribute it and/or modify
10 it under the terms of the GNU Lesser General Public License version 3
11 only, as published by the Free Software Foundation.
12
13 VCTR is distributed in the hope that it will be useful,
14 but WITHOUT ANY WARRANTY; without even the implied warranty of
15 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 GNU Lesser General Public License version 3 for more details.
17
18 You should have received a copy of the GNU Lesser General Public License
19 version 3 along with VCTR. If not, see <https://www.gnu.org/licenses/>.
20 ==============================================================================
21*/
22
23namespace vctr::expressions
24{
25
26//==============================================================================
28template <size_t extent, class SrcType>
29requires is::realOrComplexFloatNumber<ValueType<SrcType>>
31{
32public:
33 using value_type = ValueType<SrcType>;
34
35 VCTR_COMMON_UNARY_EXPRESSION_MEMBERS (FastExp, src)
36
37 VCTR_FORCEDINLINE constexpr value_type operator[] (size_t i) const
38 {
39 return (Const1680 + src[i] * (Const840 + src[i] * (Const180 + src[i] * (Const20 + src[i])))) / (Const1680 + src[i] * (ConstMinus840 + src[i] * (Const180 + src[i] * (ConstMinus20 + src[i]))));
40 }
41
42 //==============================================================================
43 // AVX Implementation
44 VCTR_FORCEDINLINE VCTR_TARGET ("avx") void prepareAVXEvaluation() const
45 requires (has::prepareAVXEvaluation<SrcType> && Expression::CommonElement::isRealFloat)
46 {
47 src.prepareAVXEvaluation();
48
49 SIMDConst20.avx = Expression::AVX::broadcast (Const20);
50 SIMDConst180.avx = Expression::AVX::broadcast (Const180);
51 SIMDConst840.avx = Expression::AVX::broadcast (Const840);
52 SIMDConst1680.avx = Expression::AVX::broadcast (Const1680);
53 SIMDConstMinus20.avx = Expression::AVX::broadcast (ConstMinus20);
54 SIMDConstMinus840.avx = Expression::AVX::broadcast (ConstMinus840);
55 }
56
57 VCTR_FORCEDINLINE VCTR_TARGET ("fma") AVXRegister<value_type> getAVX (size_t i) const
58 requires (archX64 && has::getAVX<SrcType> && Expression::allElementTypesSame && Expression::CommonElement::isRealFloat)
59 {
60 const auto in = src.getAVX (i);
61
62 auto numerator = Expression::AVX::add (in, SIMDConst20.avx);
63 numerator = Expression::AVX::mul (numerator, in);
64 numerator = Expression::AVX::add (numerator, SIMDConst180.avx);
65 numerator = Expression::AVX::mul (numerator, in);
66 numerator = Expression::AVX::add (numerator, SIMDConst840.avx);
67 numerator = Expression::AVX::mul (numerator, in);
68 numerator = Expression::AVX::add (numerator, SIMDConst1680.avx);
69
70 auto denominator = Expression::AVX::add (in, SIMDConstMinus20.avx);
71 denominator = Expression::AVX::mul (denominator, in);
72 denominator = Expression::AVX::add (denominator, SIMDConst180.avx);
73 denominator = Expression::AVX::mul (denominator, in);
74 denominator = Expression::AVX::add (denominator, SIMDConstMinus840.avx);
75 denominator = Expression::AVX::mul (denominator, in);
76 denominator = Expression::AVX::add (denominator, SIMDConst1680.avx);
77
78 return Expression::AVX::div (numerator, denominator);
79 }
80
81 //==============================================================================
82 // SSE Implementation
83 VCTR_FORCEDINLINE VCTR_TARGET ("sse4.1") void prepareSSEEvaluation() const
84 requires (has::prepareSSEEvaluation<SrcType> && Expression::CommonElement::isRealFloat)
85 {
86 src.prepareSSEEvaluation();
87
88 SIMDConst20.sse = Expression::SSE::broadcast (Const20);
89 SIMDConst180.sse = Expression::SSE::broadcast (Const180);
90 SIMDConst840.sse = Expression::SSE::broadcast (Const840);
91 SIMDConst1680.sse = Expression::SSE::broadcast (Const1680);
92 SIMDConstMinus20.sse = Expression::SSE::broadcast (ConstMinus20);
93 SIMDConstMinus840.sse = Expression::SSE::broadcast (ConstMinus840);
94 }
95
96 VCTR_FORCEDINLINE VCTR_TARGET ("sse4.1") SSERegister<value_type> getSSE (size_t i) const
97 requires (archX64 && has::getSSE<SrcType> && Expression::allElementTypesSame && Expression::CommonElement::isRealFloat)
98 {
99 const auto in = src.getSSE (i);
100
101 auto numerator = Expression::SSE::add (in, SIMDConst20.sse);
102 numerator = Expression::SSE::mul (numerator, in);
103 numerator = Expression::SSE::add (numerator, SIMDConst180.sse);
104 numerator = Expression::SSE::mul (numerator, in);
105 numerator = Expression::SSE::add (numerator, SIMDConst840.sse);
106 numerator = Expression::SSE::mul (numerator, in);
107 numerator = Expression::SSE::add (numerator, SIMDConst1680.sse);
108
109 auto denominator = Expression::SSE::add (in, SIMDConstMinus20.sse);
110 denominator = Expression::SSE::mul (denominator, in);
111 denominator = Expression::SSE::add (denominator, SIMDConst180.sse);
112 denominator = Expression::SSE::mul (denominator, in);
113 denominator = Expression::SSE::add (denominator, SIMDConstMinus840.sse);
114 denominator = Expression::SSE::mul (denominator, in);
115 denominator = Expression::SSE::add (denominator, SIMDConst1680.sse);
116
117 return Expression::SSE::div (numerator, denominator);
118 }
119
120 //==============================================================================
121 // Neon Implementation
122 void prepareNeonEvaluation() const
123 requires (archARM && has::prepareNeonEvaluation<SrcType> && Expression::CommonElement::isRealFloat)
124 {
125 src.prepareNeonEvaluation();
126
127 SIMDConst20.neon = Expression::Neon::broadcast (Const20);
128 SIMDConst180.neon = Expression::Neon::broadcast (Const180);
129 SIMDConst840.neon = Expression::Neon::broadcast (Const840);
130 SIMDConst1680.neon = Expression::Neon::broadcast (Const1680);
131 SIMDConstMinus20.neon = Expression::Neon::broadcast (ConstMinus20);
132 SIMDConstMinus840.neon = Expression::Neon::broadcast (ConstMinus840);
133 }
134
135 NeonRegister<value_type> getNeon (size_t i) const
136 requires (archARM && has::getNeon<SrcType> && Expression::allElementTypesSame && Expression::CommonElement::isRealFloat)
137 {
138 const auto in = src.getNeon (i);
139
140 auto numerator = Expression::Neon::add (in, SIMDConst20.neon);
141 numerator = Expression::Neon::mul (numerator, in);
142 numerator = Expression::Neon::add (numerator, SIMDConst180.neon);
143 numerator = Expression::Neon::mul (numerator, in);
144 numerator = Expression::Neon::add (numerator, SIMDConst840.neon);
145 numerator = Expression::Neon::mul (numerator, in);
146 numerator = Expression::Neon::add (numerator, SIMDConst1680.neon);
147
148 auto denominator = Expression::Neon::add (in, SIMDConstMinus20.neon);
149 denominator = Expression::Neon::mul (denominator, in);
150 denominator = Expression::Neon::add (denominator, SIMDConst180.neon);
151 denominator = Expression::Neon::mul (denominator, in);
152 denominator = Expression::Neon::add (denominator, SIMDConstMinus840.neon);
153 denominator = Expression::Neon::mul (denominator, in);
154 denominator = Expression::Neon::add (denominator, SIMDConst1680.neon);
155
156 return Expression::Neon::div (numerator, denominator);
157 }
158
159private:
// Scalar polynomial coefficients. operator[] uses them directly; the
// prepare*Evaluation functions broadcast them into the SIMD caches below.
// The negated values are kept as separate constants so that numerator and
// denominator can both be evaluated with additions only.
160 static constexpr value_type Const20 = value_type (20);
161 static constexpr value_type Const180 = value_type (180);
162 static constexpr value_type Const840 = value_type (840);
163 static constexpr value_type Const1680 = value_type (1680);
164 static constexpr value_type ConstMinus20 = value_type (-20);
165 static constexpr value_type ConstMinus840 = value_type (-840);
166
// SIMD copies of the coefficients above. They are mutable because they are
// written by the const prepareAVXEvaluation / prepareSSEEvaluation /
// prepareNeonEvaluation member functions and read by getAVX / getSSE / getNeon.
167 mutable SIMDRegisterUnion<Expression> SIMDConst20 {};
168 mutable SIMDRegisterUnion<Expression> SIMDConst180 {};
169 mutable SIMDRegisterUnion<Expression> SIMDConst840 {};
170 mutable SIMDRegisterUnion<Expression> SIMDConst1680 {};
171 mutable SIMDRegisterUnion<Expression> SIMDConstMinus20 {};
172 mutable SIMDRegisterUnion<Expression> SIMDConstMinus840 {};
173};
174
// Implementation details of FastExp2: per floating point type constants.
175namespace detail
176{
177
// Primary template, intentionally empty: only the float and double
// specialisations below supply constants.
178template <std::floating_point>
179struct FastExp2Constants {};
180
// Constants for float. a, b, c and d are the coefficients of the term
// a + b / (c - z) - d * z that FastExp2 evaluates on the fractional part z.
// e equals 2^mantissaBits; FastExp2 multiplies by it before converting the
// value to an integer and reinterpreting that integer as a float.
181template <>
182struct FastExp2Constants<float>
183{
184 static constexpr int mantissaBits = 23;
185 static constexpr float minExpo = -126.0f; // exponent of minimum binary32 normal
186 static constexpr float expoBias = 127.0f; // binary32 exponent bias
187 static constexpr float a = -0x1.6e7592p+2f;
188 static constexpr float b = 0x1.bba764p+4f;
189 static constexpr float c = 0x1.35ed00p+2f;
190 static constexpr float d = 0x1.f5e546p-2f;
191 static constexpr float e = 1 << mantissaBits;
192};
193
// Constants for double, with the same roles as in the float specialisation.
// e is computed from a 64 bit literal because a shift by 52 does not fit
// a 32 bit int.
194template <>
195struct FastExp2Constants<double>
196{
197 static constexpr int mantissaBits = 52;
198 static constexpr double minExpo = -1022.0; // exponent of minimum binary64 normal
199 static constexpr double expoBias = 1023.0; // binary64 exponent bias
200 static constexpr double a = -0x1.6e75d58p+2;
201 static constexpr double b = 0x1.bba7414p+4;
202 static constexpr double c = 0x1.35eccbap+2;
203 static constexpr double d = 0x1.f5e53c2p-2;
204 static constexpr double e = 1LL << mantissaBits;
205};
206}
207
208//==============================================================================
215template <size_t extent, class SrcType>
216requires is::realFloatNumber<ValueType<SrcType>>
218{
219public:
220 using value_type = ValueType<SrcType>;
221
222 using Constants = detail::FastExp2Constants<value_type>;
223
224 VCTR_COMMON_UNARY_EXPRESSION_MEMBERS (FastExp2, src)
225
226 VCTR_FORCEDINLINE value_type operator[] (size_t i) const
227 {
228 auto p = src[i];
229
230 p = std::max (p, Constants::minExpo);
231
232 auto w = std::floor (p);
233 auto z = p - w;
234
235 auto approx = Constants::a + Constants::b / (Constants::c - z) - Constants::d * z;
236
237 auto resi = IntType (Constants::e * (w + Constants::expoBias + approx));
238
239
240 return bitCast<value_type> (resi);
241 }
242
243 //==============================================================================
244 // AVX Implementation
245 VCTR_FORCEDINLINE VCTR_TARGET ("avx") void prepareAVXEvaluation() const
246 requires (has::prepareAVXEvaluation<SrcType> && Expression::CommonElement::isFloat)
247 {
248 src.prepareAVXEvaluation();
249
250 minExpo.avx = Expression::AVX::broadcast (Constants::minExpo);
251 expoBias.avx = Expression::AVX::broadcast (Constants::expoBias);
252 c_a.avx = Expression::AVX::broadcast (Constants::a);
253 c_b.avx = Expression::AVX::broadcast (Constants::b);
254 c_c.avx = Expression::AVX::broadcast (Constants::c);
255 c_d.avx = Expression::AVX::broadcast (Constants::d);
256 c_e.avx = Expression::AVX::broadcast (Constants::e);
257 }
258
259 VCTR_FORCEDINLINE VCTR_TARGET ("fma") AVXRegister<value_type> getAVX (size_t i) const
260 requires (archX64 && has::getAVX<SrcType> && Expression::allElementTypesSame && Expression::CommonElement::isFloat)
261 {
262 auto in = src.getAVX (i);
263
264 in = Expression::AVX::max (in, minExpo.avx);
265
266 auto w = Expression::AVX::floor (in);
267 auto z = Expression::AVX::sub (in, w);
268
269 auto approx = Expression::AVX::sub (Expression::AVX::add (c_a.avx ,
270 Expression::AVX::div (c_b.avx,
271 Expression::AVX::sub (c_c.avx, z))),
272 Expression::AVX::mul (c_d.avx,
273 z)
274 );
275
276 auto resi = Expression::AVX::mul (c_e.avx, Expression::AVX::add (w, Expression::AVX::add (expoBias.avx, approx)));
277
278 // ConvertToInt requires AVX512 features for double registers, so this implementation is constrained to float
279 return AVXRegister<IntType>::reinterpretAsFp (Expression::AVX::convertToInt (resi));
280 }
281
282 //==============================================================================
283 // SSE Implementation
284 VCTR_FORCEDINLINE VCTR_TARGET ("sse4.1") void prepareSSEEvaluation() const
285 requires (has::prepareSSEEvaluation<SrcType> && Expression::CommonElement::isFloat)
286 {
287 src.prepareSSEEvaluation();
288
289 minExpo.sse = Expression::SSE::broadcast (Constants::minExpo);
290 expoBias.sse = Expression::SSE::broadcast (Constants::expoBias);
291 c_a.sse = Expression::SSE::broadcast (Constants::a);
292 c_b.sse = Expression::SSE::broadcast (Constants::b);
293 c_c.sse = Expression::SSE::broadcast (Constants::c);
294 c_d.sse = Expression::SSE::broadcast (Constants::d);
295 c_e.sse = Expression::SSE::broadcast (Constants::e);
296 }
297
298 VCTR_FORCEDINLINE VCTR_TARGET ("sse4.1") SSERegister<value_type> getSSE (size_t i) const
299 requires (archX64 && has::getSSE<SrcType> && Expression::allElementTypesSame && Expression::CommonElement::isFloat)
300 {
301 auto in = src.getSSE (i);
302
303 in = Expression::SSE::max (in, minExpo.sse);
304
305 auto w = Expression::SSE::floor (in);
306 auto z = Expression::SSE::sub (in, w);
307
308 auto approx = Expression::SSE::sub (Expression::SSE::add (c_a.sse ,
309 Expression::SSE::div (c_b.sse,
310 Expression::SSE::sub (c_c.sse, z))),
311 Expression::SSE::mul (c_d.sse,
312 z)
313 );
314
315 auto resi = Expression::SSE::mul (c_e.sse, Expression::SSE::add (w, Expression::SSE::add (expoBias.sse, approx)));
316
317 // ConvertToInt requires AVX512 features for double registers, so this implementation is constrained to float
318 return SSERegister<IntType>::reinterpretAsFp (Expression::SSE::convertToInt (resi));
319 }
320
321 //==============================================================================
322 // Neon Implementation
323 void prepareNeonEvaluation() const
324 requires (archARM && has::prepareNeonEvaluation<SrcType> && Expression::CommonElement::isRealFloat)
325 {
326 src.prepareNeonEvaluation();
327
328 minExpo.neon = Expression::Neon::broadcast (Constants::minExpo);
329 expoBias.neon = Expression::Neon::broadcast (Constants::expoBias);
330 c_a.neon = Expression::Neon::broadcast (Constants::a);
331 c_b.neon = Expression::Neon::broadcast (Constants::b);
332 c_c.neon = Expression::Neon::broadcast (Constants::c);
333 c_d.neon = Expression::Neon::broadcast (Constants::d);
334 c_e.neon = Expression::Neon::broadcast (Constants::e);
335 }
336
337 NeonRegister<value_type> getNeon (size_t i) const
338 requires (archARM && has::getNeon<SrcType> && Expression::allElementTypesSame && Expression::CommonElement::isRealFloat)
339 {
340 auto in = src.getNeon (i);
341
342 in = Expression::Neon::max (in, minExpo.neon);
343
344 auto w = Expression::Neon::floor (in);
345 auto z = Expression::Neon::sub (in, w);
346
347 auto approx = Expression::Neon::sub (Expression::Neon::add (c_a.neon ,
348 Expression::Neon::div (c_b.neon,
349 Expression::Neon::sub (c_c.neon, z))),
350 Expression::Neon::mul (c_d.neon,
351 z)
352 );
353
354 auto resi = Expression::Neon::mul (c_e.neon, Expression::Neon::add (w, Expression::Neon::add (expoBias.neon, approx)));
355
356 return NeonRegister<IntType>::reinterpretAsFp (Expression::Neon::convertToInt (resi));
357 }
358
359private:
// Signed integer type the scaled result is converted to before its bits are
// reinterpreted as value_type: int32_t if value_type is float, int64_t otherwise.
360 using IntType = std::conditional_t<std::same_as<float, value_type>, int32_t, int64_t>;
361
// SIMD copies of Constants::minExpo and Constants::expoBias. They are mutable
// because they are written by the const prepare*Evaluation member functions
// and read by getAVX / getSSE / getNeon.
362 mutable SIMDRegisterUnion<Expression> minExpo;
363 mutable SIMDRegisterUnion<Expression> expoBias;
369};
370
371} // namespace vctr::expressions
372
373namespace vctr
374{
375
383
393
394} // namespace vctr
Calculates a fast approximation for the exp2 function.
Definition: FastExp.h:218
Calculates a fast approximation for the exp function.
Definition: FastExp.h:31
Constrains a type to have a member function getAVX (size_t) const.
Definition: ContainerAndExpressionConcepts.h:92
Constrains a type to have a member function getNeon (size_t) const.
Definition: ContainerAndExpressionConcepts.h:84
Constrains a type to have a member function getSSE (size_t) const.
Definition: ContainerAndExpressionConcepts.h:100
Constrains a type to have a member function prepareAVXEvaluation() const.
Definition: ContainerAndExpressionConcepts.h:88
Constrains a type to have a member function prepareNeonEvaluation() const.
Definition: ContainerAndExpressionConcepts.h:80
Constrains a type to have a member function prepareSSEEvaluation() const.
Definition: ContainerAndExpressionConcepts.h:96
constexpr ExpressionChainBuilder< expressions::FastExp > fastExp
A fast approximation of the exp function, using only basic algebraic operations in a continued fracti...
Definition: FastExp.h:382
constexpr ExpressionChainBuilder< expressions::FastExp2 > fastExp2
A fast approximation of the exp2 function (e.g.
Definition: FastExp.h:392
The main namespace of the VCTR project.
Definition: Array.h:24
typename detail::ValueType< std::remove_cvref_t< T > >::Type ValueType
If T is an expression template, it equals its return type, if it's a type that defines value_type as ...
Definition: Traits.h:201
Definition: AVXRegister.h:28
An expression chain builder is an object which supplies various operator<< overloads which build chai...
Definition: ExpressionChainBuilder.h:157
The base class to every expression template.
Definition: ExpressionTemplate.h:37
Definition: NeonRegister.h:28
Definition: SSERegister.h:28
Helper template to define a union of all supported SIMD types.
Definition: ExpressionTemplate.h:123