1- # Vector interface implementation for ' SparseArray
1+ # Vector interface implementation for SparseArray
22# #################################################################
33# zerovector & zerovector!!
44# ---------------------------
55function VectorInterface. zerovector (x:: SparseArray , :: Type{S} ) where {S<: Number }
6- T = typeof (zero (eltype (x)) * zero (S))
7- return SparseArray {T} (undef, size (x))
6+ return SparseArray {S} (undef, size (x))
87end
9-
108VectorInterface. zerovector! (x:: SparseArray ) = _zero! (x)
119VectorInterface. zerovector!! (x:: SparseArray ) = zerovector! (x)
1210
1311# scale, scale! & scale!!
1412# -------------------------
15- VectorInterface. scale (x:: SparseArray , α:: Number ) = _isone (α) ? copy (x) : x * α
16-
13+ function VectorInterface. scale (x:: SparseArray , α:: Number )
14+ α === One () && return copy (x)
15+ α === Zero () && return zerovector (x)
16+ return x * α
17+ end
1718function VectorInterface. scale! (x:: SparseArray , α:: Number )
18- _isone (α) && return x
19+ iszero (α) && return zerovector! (x)
1920 # typical occupation in a dict is about 30% from experimental testing
2021 # the benefits of scaling all values (e.g. SIMD) largely outweight the extra work
2122 scale! (x. data. vals, α)
@@ -25,24 +26,22 @@ function VectorInterface.scale!(y::SparseArray, x::SparseArray, α::Number)
2526 ax = axes (x)
2627 ay = axes (y)
2728 ax == ay || throw (DimensionMismatch (" output axes $ay differ from input axes $ax " ))
28- _zero ! (y)
29+ zerovector ! (y)
2930 for (k, v) in nonzero_pairs (x)
3031 y[k] = scale!! (v, α)
3132 end
3233 return y
3334end
34-
3535function VectorInterface. scale!! (x:: SparseArray , α:: Number )
36- T = scalartype (x)
37- if promote_type (T, typeof (α)) <: T
36+ α === One () && return x
37+ if VectorInterface . promote_scale (x, α) <: scalartype (x)
3838 return scale! (x, α)
3939 else
4040 return scale (x, α)
4141 end
4242end
4343function VectorInterface. scale!! (y:: SparseArray , x:: SparseArray , α:: Number )
44- T = scalartype (y)
45- if promote_type (T, typeof (α), scalartype (x)) <: T
44+ if VectorInterface. promote_scale (x, α) <: scalartype (y)
4645 return scale! (y, x, α)
4746 else
4847 return scale (x, α)
5150
5251# add, add! & add!!
5352# -------------------
54- function VectorInterface. add (y:: SparseArray ,
55- x:: SparseArray ,
56- α:: Number = _one,
57- β:: Number = _one)
53+ function VectorInterface. add (y:: SparseArray , x:: SparseArray , α:: Number , β:: Number )
5854 ax = axes (x)
5955 ay = axes (y)
6056 ax == ay || throw (DimensionMismatch (" output axes $ay differ from input axes $ax " ))
61- T = promote_type ( scalartype (y), scalartype (x), typeof (α), typeof (β) )
57+ T = VectorInterface . promote_add (y, x, α, β )
6258 z = SparseArray {T} (undef, size (y))
6359 scale! (z, y, β)
6460 add! (z, x, α)
6561 return z
6662end
6763
68- function VectorInterface. add! (y:: SparseArray ,
69- x:: SparseArray ,
70- α:: Number = _one,
71- β:: Number = _one)
64+ function VectorInterface. add! (y:: SparseArray , x:: SparseArray , α:: Number , β:: Number )
7265 ax = axes (x)
7366 ay = axes (y)
7467 ax == ay || throw (DimensionMismatch (" output axes $ay differ from input axes $ax " ))
75- _isone (β) || ( iszero (β) ? _zero! (y) : scale! (y, β) )
68+ scale! (y, β)
7669 for (k, v) in nonzero_pairs (x)
7770 increaseindex! (y, scale!! (v, α), k)
7871 end
7972 return y
8073end
8174
82- function VectorInterface. add!! (y:: SparseArray ,
83- x:: SparseArray ,
84- α:: Number = _one,
85- β:: Number = _one)
86- T = scalartype (y)
87- if promote_type (T, typeof (α), typeof (β), scalartype (x)) <: T
75+ function VectorInterface. add!! (y:: SparseArray , x:: SparseArray , α:: Number , β:: Number )
76+ if VectorInterface. promote_add (y, x, α, β) <: scalartype (y)
8877 return add! (y, x, α, β)
8978 else
9079 return add (y, x, α, β)
@@ -97,14 +86,14 @@ function VectorInterface.inner(x::SparseArray, y::SparseArray)
9786 ax = axes (x)
9887 ay = axes (y)
9988 ax == ay || throw (DimensionMismatch (" dot arguments have non-matching axes $ax and $ay " ))
100- s = dot ( zero (eltype (x)), zero ( eltype (y) ))
89+ s = zero (VectorInterface . promote_inner (x, y ))
10190 if nonzero_length (x) >= nonzero_length (y)
10291 @inbounds for I in nonzero_keys (x)
103- s += dot (x[I], y[I])
92+ s += VectorInterface . inner (x[I], y[I])
10493 end
10594 else
10695 @inbounds for I in nonzero_keys (y)
107- s += dot (x[I], y[I])
96+ s += VectorInterface . inner (x[I], y[I])
10897 end
10998 end
11099 return s
0 commit comments