11module MPSKitAdaptExt
22
3- using TensorKit: space, spacetype
3+ using TensorKit: space, spacetype, similarstoragetype, scalartype
44using MPSKit
55using BlockTensorKit: nonzero_pairs
66using Adapt
3232
3333Adapt. adapt_structure (to, mpo:: MPO ) = MPO (map (adapt (to), mpo. O))
3434
35- function Adapt. adapt_structure (:: Type{TorA} , W:: MPSKit.JordanMPOTensor ) where {TorA <: Union{Number, DenseVector{<:Number}} }
35+ function Adapt. adapt_structure (to, W:: MPSKit.JordanMPOTensorMap )
36+ TorA = similarstoragetype (to, scalartype (W))
3637 TT = MPSKit. jordanmpotensortype (spacetype (W), TorA)
3738 W′ = TT (undef, space (W))
3839 ad = adapt (TorA)
@@ -49,26 +50,9 @@ function Adapt.adapt_structure(::Type{TorA}, W::MPSKit.JordanMPOTensor) where {T
4950 for (k, v) in nonzero_pairs (W. D)
5051 W′. D[k] = ad (v)
5152 end
52-
5353 return W′
5454end
5555Adapt. adapt_structure (to, mpo:: MPOHamiltonian ) = MPOHamiltonian (map (adapt (to), mpo. W))
56-
57- function Adapt. adapt_structure (to, x:: MPSKit.MPOHamiltonian{TO} ) where {TO}
58- terms′ = map (w -> adapt (to, w), x. W)
59- return MPSKit. MPOHamiltonian (terms′)
60- end
61-
62- function Adapt. adapt_structure (to, x:: MPSKit.PeriodicArray )
63- return MPSKit. PeriodicArray (map (x_ -> adapt (to, x_), x. data))
64- end
65-
66- function Adapt. adapt_structure (to, x:: MPSKit.InfiniteMPS{A, B} ) where {A, S, B <: MPSKit.MPSBondTensor{S} }
67- AL′ = adapt (to, x. AL)
68- AR′ = adapt (to, x. AR)
69- AC′ = adapt (to, x. AC)
70- C′ = adapt (to, x. C)
71- return MPSKit. InfiniteMPS (AL′, AR′, C′, AC′)
72- end
56+ Adapt. adapt_structure (to, x:: MPSKit.PeriodicArray ) = MPSKit. PeriodicArray (map (adapt (to), x. data))
7357
7458end
0 commit comments