Skip to content

Commit 0dee6e5

Browse files
committed
Fix Adapt ext
1 parent d4e1a03 commit 0dee6e5

1 file changed

Lines changed: 4 additions & 20 deletions

File tree

ext/MPSKitAdaptExt.jl

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module MPSKitAdaptExt
22

3-
using TensorKit: space, spacetype
3+
using TensorKit: space, spacetype, similarstoragetype, scalartype
44
using MPSKit
55
using BlockTensorKit: nonzero_pairs
66
using Adapt
@@ -32,7 +32,8 @@ end
3232

3333
Adapt.adapt_structure(to, mpo::MPO) = MPO(map(adapt(to), mpo.O))
3434

35-
function Adapt.adapt_structure(::Type{TorA}, W::MPSKit.JordanMPOTensor) where {TorA <: Union{Number, DenseVector{<:Number}}}
35+
function Adapt.adapt_structure(to, W::MPSKit.JordanMPOTensorMap)
36+
TorA = similarstoragetype(to, scalartype(W))
3637
TT = MPSKit.jordanmpotensortype(spacetype(W), TorA)
3738
W′ = TT(undef, space(W))
3839
ad = adapt(TorA)
@@ -49,26 +50,9 @@ function Adapt.adapt_structure(::Type{TorA}, W::MPSKit.JordanMPOTensor) where {T
4950
for (k, v) in nonzero_pairs(W.D)
5051
W′.D[k] = ad(v)
5152
end
52-
5353
return W′
5454
end
5555
Adapt.adapt_structure(to, mpo::MPOHamiltonian) = MPOHamiltonian(map(adapt(to), mpo.W))
56-
57-
function Adapt.adapt_structure(to, x::MPSKit.MPOHamiltonian{TO}) where {TO}
58-
terms′ = map(w -> adapt(to, w), x.W)
59-
return MPSKit.MPOHamiltonian(terms′)
60-
end
61-
62-
function Adapt.adapt_structure(to, x::MPSKit.PeriodicArray)
63-
return MPSKit.PeriodicArray(map(x_ -> adapt(to, x_), x.data))
64-
end
65-
66-
function Adapt.adapt_structure(to, x::MPSKit.InfiniteMPS{A, B}) where {A, S, B <: MPSKit.MPSBondTensor{S}}
67-
AL′ = adapt(to, x.AL)
68-
AR′ = adapt(to, x.AR)
69-
AC′ = adapt(to, x.AC)
70-
C′ = adapt(to, x.C)
71-
return MPSKit.InfiniteMPS(AL′, AR′, C′, AC′)
72-
end
56+
Adapt.adapt_structure(to, x::MPSKit.PeriodicArray) = MPSKit.PeriodicArray(map(adapt(to), x.data))
7357

7458
end

0 commit comments

Comments
 (0)