-
Notifications
You must be signed in to change notification settings - Fork 5
Updates for TensorKit compatibility #49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov Report❌ Patch coverage is
🚀 New features to boost your workflow:
|
| include("yacusolver.jl") | ||
|
|
||
| function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T<:StridedCuMatrix} | ||
| function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {TT<:BlasFloat, T<:StridedCuMatrix{TT}} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are probably a couple more of these somewhat complex wrapper types that can still be handled by these algorithms, how do you feel about doing something like
for MatType in [...]
@eval ...
endThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds fine to me, do we have a list of the ones we want?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Was this particular change also induced by TensorKit requirements, or simply more strictness (which I fully support)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it be equivalent to defining a new type constant
const StridedCuBLASMatrix{T} = StridedCuMatrix{T} where {T<:BlasFloat}and then using default_xxx_algorithm(::Type{<:StridedCuBLASMatrix}; kwargs...) everywhere?
src/implementations/lq.jl
Outdated
| m, n = size(A) | ||
| minmn = min(m, n) | ||
| At = adjoint!(similar(A'), A)::AbstractMatrix | ||
| At = min(m, n) > 0 ? adjoint!(similar(A'), A)::AbstractMatrix : similar(A') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this ::AbstractMatrix type assert useful?
|
Changed the format of some of the |
src/implementations/svd.jl
Outdated
| Ut = similar(U') | ||
| Vᴴt = similar(Vᴴ') | ||
| if size(U) == (m, m) | ||
| _gpu_gesvd!(At, view(S, 1:minmn, 1), Vᴴt, Ut) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Didn't think long about it, but it was not immediately clear to me why this was necessary. Isn't S always of length minmn?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here I'm following what the CPU bindings have done, but I suppose we could be reusing an S over and over between differently sized arrays?
|
Your PR no longer requires formatting changes. Thank you for your contribution! |
1a9ace8 to
3ea43a3
Compare
lkdvos
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left two minor comments, and this PR needs a formatter run, otherwise this is good to merge!
|
OK fixed both comments, formatted. Thanks for the look! |
* Bump v0.6 * rename `gaugefix` -> `fixgauge` * reduce unnecessary warnings * fix `copy_input` signatures in Mooncake tests * Add changelog to docs
The
ReshapedArrayoverrides are needed to dispatch to the correct GPU algorithms. Needed to modify the type signature for the default algorithms to avoid ambiguities. Also it's nice to give some more info about dimension mismatches.