@@ -7,6 +7,7 @@ mod math {
77 function:: { ArgIntoFloat , ArgIterable , Either , OptionalArg , PosArgs } ,
88 identifier, PyObject , PyObjectRef , PyRef , PyResult , VirtualMachine ,
99 } ;
10+ use itertools:: Itertools ;
1011 use num_bigint:: BigInt ;
1112 use num_rational:: Ratio ;
1213 use num_traits:: { One , Signed , ToPrimitive , Zero } ;
@@ -283,24 +284,73 @@ mod math {
283284 if has_nan {
284285 return f64:: NAN ;
285286 }
286- vector_norm ( & coordinates, max)
287+ coordinates. sort_unstable_by ( |x, y| x. total_cmp ( y) . reverse ( ) ) ;
288+ vector_norm ( & coordinates)
287289 }
288290
289- fn vector_norm ( v : & [ f64 ] , max : f64 ) -> f64 {
290- if max == 0.0 || v. len ( ) <= 1 {
291+ /// Implementation of accurate hypotenuse algorithm from Borges 2019.
292+ /// See https://arxiv.org/abs/1904.09481.
293+ /// This assumes that its arguments are positive finite and have been scaled to avoid overflow
294+ /// and underflow.
295+ fn accurate_hypot ( max : f64 , min : f64 ) -> f64 {
296+ if min <= max * ( f64:: EPSILON / 2.0 ) . sqrt ( ) {
291297 return max;
292298 }
293- let mut csum = 1.0 ;
294- let mut frac = 0.0 ;
295- for & f in v {
296- let f = f / max;
297- let f = f * f;
298- let old = csum;
299- csum += f;
300- // this seemingly redundant operation is to reduce float rounding errors/inaccuracy
301- frac += ( old - csum) + f;
302- }
303- max * f64:: sqrt ( csum - 1.0 + frac)
299+ let hypot = max. mul_add ( max, min * min) . sqrt ( ) ;
300+ let hypot_sq = hypot * hypot;
301+ let max_sq = max * max;
302+ let correction = ( -min) . mul_add ( min, hypot_sq - max_sq) + hypot. mul_add ( hypot, -hypot_sq)
303+ - max. mul_add ( max, -max_sq) ;
304+ hypot - correction / ( 2.0 * hypot)
305+ }
306+
307+ /// Calculates the norm of the vector given by `v`.
308+ /// `v` is assumed to be a list of non-negative finite floats, sorted in descending order.
309+ fn vector_norm ( v : & [ f64 ] ) -> f64 {
310+ // Drop zeros from the vector.
311+ let zero_count = v. iter ( ) . rev ( ) . cloned ( ) . take_while ( |x| * x == 0.0 ) . count ( ) ;
312+ let v = & v[ ..v. len ( ) - zero_count] ;
313+ if v. is_empty ( ) {
314+ return 0.0 ;
315+ }
316+ if v. len ( ) == 1 {
317+ return v[ 0 ] ;
318+ }
319+ // Calculate scaling to avoid overflow / underflow.
320+ let max = * v. first ( ) . unwrap ( ) ;
321+ let min = * v. last ( ) . unwrap ( ) ;
322+ let scale = if max > ( f64:: MAX / v. len ( ) as f64 ) . sqrt ( ) {
323+ max
324+ } else if min < f64:: MIN_POSITIVE . sqrt ( ) {
325+ // ^ This can be an `else if`, because if the max is near f64::MAX and the min is near
326+ // f64::MIN_POSITIVE, then the min is relatively unimportant and will be effectively
327+ // ignored.
328+ min
329+ } else {
330+ 1.0
331+ } ;
332+ let mut norm = v
333+ . iter ( )
334+ . copied ( )
335+ . map ( |x| x / scale)
336+ . reduce ( accurate_hypot)
337+ . unwrap_or_default ( ) ;
338+ if v. len ( ) > 2 {
339+ // For larger lists of numbers, we can accumulate a rounding error, so a correction is
340+ // needed, similar to that in `accurate_hypot()`.
341+ // First, we estimate [sum of squares - norm^2], then we add the first-order
342+ // approximation of the square root of that to `norm`.
343+ let correction = v
344+ . iter ( )
345+ . copied ( )
346+ . map ( |x| ( x / scale) . powi ( 2 ) )
347+ . chain ( std:: iter:: once ( -norm * norm) )
348+ // Pairwise summation of floats gives less rounding error than a naive sum.
349+ . tree_fold1 ( std:: ops:: Add :: add)
350+ . expect ( "expected at least 1 element" ) ;
351+ norm = norm + correction / ( 2.0 * norm) ;
352+ }
353+ norm * scale
304354 }
305355
306356 #[ pyfunction]
@@ -339,7 +389,8 @@ mod math {
339389 if has_nan {
340390 return Ok ( f64:: NAN ) ;
341391 }
342- Ok ( vector_norm ( & diffs, max) )
392+ diffs. sort_unstable_by ( |x, y| x. total_cmp ( y) . reverse ( ) ) ;
393+ Ok ( vector_norm ( & diffs) )
343394 }
344395
345396 #[ pyfunction]
0 commit comments