Skip to content

Commit

Permalink
Add test_bases from QuantumOpticsBase and reorganize a bit
Browse files Browse the repository at this point in the history
  • Loading branch information
akirakyle committed Dec 5, 2024
1 parent f22adb7 commit f17ee1e
Show file tree
Hide file tree
Showing 12 changed files with 203 additions and 135 deletions.
30 changes: 25 additions & 5 deletions src/QuantumInterface.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,33 @@
module QuantumInterface

import Base: ==, +, -, *, /, ^, length, one, exp, conj, conj!, transpose, copy
import LinearAlgebra: tr, ishermitian, norm, normalize, normalize!
import Base: show, summary
import SparseArrays: sparse, spzeros, AbstractSparseMatrix # TODO move to an extension
##
# Basis specific
##

"""
basis(a)
Return the basis of an object.
If it's ambiguous, e.g. if an operator has a different left and right basis,
an [`IncompatibleBases`](@ref) error is thrown.
"""
function basis end

##
# Standard methods
##

function apply! end

function dagger end

"""
directsum(x, y, z...)
Direct sum of the given objects. Alternatively, the unicode
symbol ⊕ (\\oplus) can be used.
"""
function directsum end
const = directsum
directsum() = GenericBasis(0)
Expand Down Expand Up @@ -86,8 +105,9 @@ function squeeze end
function wigner end


include("bases.jl")
include("abstract_types.jl")
include("bases.jl")
include("show.jl")

include("linalg.jl")
include("tensor.jl")
Expand Down
32 changes: 15 additions & 17 deletions src/abstract_types.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
"""
Abstract base class for all specialized bases.
The Basis class is meant to specify a basis of the Hilbert space of the
studied system. Besides basis specific information all subclasses must
implement a shape variable which indicates the dimension of the used
Hilbert space. For a spin-1/2 Hilbert space this would be the
vector `[2]`. A system composed of two spins would then have a
shape vector `[2 2]`.
Composite systems can be defined with help of the [`CompositeBasis`](@ref)
class.
"""
abstract type Basis end

"""
Abstract base class for `Bra` and `Ket` states.
Expand Down Expand Up @@ -38,20 +53,3 @@ A_{br_1,br_2} = B_{bl_1,bl_2} S_{(bl_1,bl_2) ↔ (br_1,br_2)}
```
"""
abstract type AbstractSuperOperator{B1,B2} end

function summary(stream::IO, x::AbstractOperator)
print(stream, "$(typeof(x).name.name)(dim=$(length(x.basis_l))x$(length(x.basis_r)))\n")
if samebases(x)
print(stream, " basis: ")
show(stream, basis(x))
else
print(stream, " basis left: ")
show(stream, x.basis_l)
print(stream, "\n basis right: ")
show(stream, x.basis_r)
end
end

show(stream::IO, x::AbstractOperator) = summary(stream, x)

traceout!(s::StateVector, i) = ptrace(s,i)
90 changes: 3 additions & 87 deletions src/bases.jl
Original file line number Diff line number Diff line change
@@ -1,35 +1,10 @@
"""
Abstract base class for all specialized bases.
The Basis class is meant to specify a basis of the Hilbert space of the
studied system. Besides basis specific information all subclasses must
implement a shape variable which indicates the dimension of the used
Hilbert space. For a spin-1/2 Hilbert space this would be the
vector `[2]`. A system composed of two spins would then have a
shape vector `[2 2]`.
Composite systems can be defined with help of the [`CompositeBasis`](@ref)
class.
"""
abstract type Basis end

"""
length(b::Basis)
Total dimension of the Hilbert space.
"""
Base.length(b::Basis) = prod(b.shape)

"""
basis(a)
Return the basis of an object.
If it's ambiguous, e.g. if an operator has a different left and right basis,
an [`IncompatibleBases`](@ref) error is thrown.
"""
function basis end


"""
GenericBasis(N)
Expand Down Expand Up @@ -366,9 +341,9 @@ SumBasis(shape, bases::Vector) = (tmp = (bases...,); SumBasis(shape, tmp))
SumBasis(bases::Vector) = SumBasis((bases...,))
SumBasis(bases::Basis...) = SumBasis((bases...,))

==(b1::T, b2::T) where T<:SumBasis = equal_shape(b1.shape, b2.shape)
==(b1::SumBasis, b2::SumBasis) = false
length(b::SumBasis) = sum(b.shape)
Base.:(==)(b1::T, b2::T) where T<:SumBasis = equal_shape(b1.shape, b2.shape)
Base.:(==)(b1::SumBasis, b2::SumBasis) = false
Base.length(b::SumBasis) = sum(b.shape)

Check warning on line 346 in src/bases.jl

View check run for this annotation

Codecov / codecov/patch

src/bases.jl#L344-L346

Added lines #L344 - L346 were not covered by tests

"""
directsum(b1::Basis, b2::Basis)
Expand All @@ -393,62 +368,3 @@ function directsum(b1::SumBasis, b2::SumBasis)
bases = [b1.bases...;b2.bases...]
return SumBasis(shape, (bases...,))
end

embed(b::SumBasis, indices, ops) = embed(b, b, indices, ops)

##
# show methods
##

function show(stream::IO, x::GenericBasis)
if length(x.shape) == 1
write(stream, "Basis(dim=$(x.shape[1]))")
else
s = replace(string(x.shape), " " => "")
write(stream, "Basis(shape=$s)")
end
end

function show(stream::IO, x::CompositeBasis)
write(stream, "[")
for i in 1:length(x.bases)
show(stream, x.bases[i])
if i != length(x.bases)
write(stream, "")
end
end
write(stream, "]")
end

function show(stream::IO, x::SpinBasis)
d = denominator(x.spinnumber)
n = numerator(x.spinnumber)
if d == 1
write(stream, "Spin($n)")
else
write(stream, "Spin($n/$d)")
end
end

function show(stream::IO, x::FockBasis)
if iszero(x.offset)
write(stream, "Fock(cutoff=$(x.N))")
else
write(stream, "Fock(cutoff=$(x.N), offset=$(x.offset))")
end
end

function show(stream::IO, x::NLevelBasis)
write(stream, "NLevel(N=$(x.N))")
end

function show(stream::IO, x::SumBasis)
write(stream, "[")
for i in 1:length(x.bases)
show(stream, x.bases[i])
if i != length(x.bases)
write(stream, "")
end
end
write(stream, "]")
end
6 changes: 4 additions & 2 deletions src/embed_permute.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ function embed(basis_l::CompositeBasis, basis_r::CompositeBasis,
ops_sb = [x[2] for x in idxop_sb]

for (idxsb, opsb) in zip(indices_sb, ops_sb)
(opsb.basis_l == basis_l.bases[idxsb]) || throw(IncompatibleBases())
(opsb.basis_r == basis_r.bases[idxsb]) || throw(IncompatibleBases())
(opsb.basis_l == basis_l.bases[idxsb]) || throw(IncompatibleBases()) # FIXME issue #12
(opsb.basis_r == basis_r.bases[idxsb]) || throw(IncompatibleBases()) # FIXME issue #12

Check warning on line 71 in src/embed_permute.jl

View check run for this annotation

Codecov / codecov/patch

src/embed_permute.jl#L70-L71

Added lines #L70 - L71 were not covered by tests
end

S = length(operators) > 0 ? mapreduce(eltype, promote_type, operators) : Any
Expand All @@ -83,6 +83,8 @@ function embed(basis_l::CompositeBasis, basis_r::CompositeBasis,
return embed_op
end

embed(b::SumBasis, indices, ops) = embed(b, b, indices, ops)

Check warning on line 86 in src/embed_permute.jl

View check run for this annotation

Codecov / codecov/patch

src/embed_permute.jl#L86

Added line #L86 was not covered by tests

permutesystems(a::AbstractOperator, perm) = arithmetic_unary_error("Permutations of subsystems", a)

nsubsystems(s::AbstractKet) = nsubsystems(basis(s))
Expand Down
4 changes: 2 additions & 2 deletions src/identityoperator.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
one(x::Union{<:Basis,<:AbstractOperator}) = identityoperator(x)
Base.one(x::Union{<:Basis,<:AbstractOperator}) = identityoperator(x)

Check warning on line 1 in src/identityoperator.jl

View check run for this annotation

Codecov / codecov/patch

src/identityoperator.jl#L1

Added line #L1 was not covered by tests

"""
identityoperator(a::Basis[, b::Basis])
Expand All @@ -22,4 +22,4 @@ identityoperator(::Type{T}, ::Type{Any}, b1::Basis, b2::Basis) where T<:Abstract
identityoperator(b1::Basis, b2::Basis) = identityoperator(ComplexF64, b1, b2)

"""Prepare the identity superoperator over a given space."""
function identitysuperoperator end
function identitysuperoperator end
37 changes: 21 additions & 16 deletions src/julia_base.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import Base: +, -, *, /, ^, length, exp, conj, conj!, adjoint, transpose, copy

# Common error messages
arithmetic_unary_error(funcname, x::AbstractOperator) = throw(ArgumentError("$funcname is not defined for this type of operator: $(typeof(x)).\nTry to convert to another operator type first with e.g. dense() or sparse()."))
arithmetic_binary_error(funcname, a::AbstractOperator, b::AbstractOperator) = throw(ArgumentError("$funcname is not defined for this combination of types of operators: $(typeof(a)), $(typeof(b)).\nTry to convert to a common operator type first with e.g. dense() or sparse()."))
Expand All @@ -8,33 +10,33 @@ addnumbererror() = throw(ArgumentError("Can't add or subtract a number and an op
# States
##

-(a::T) where {T<:StateVector} = T(a.basis, -a.data)
-(a::T) where {T<:StateVector} = T(a.basis, -a.data) # FIXME issue #12

Check warning on line 13 in src/julia_base.jl

View check run for this annotation

Codecov / codecov/patch

src/julia_base.jl#L13

Added line #L13 was not covered by tests
*(a::StateVector, b::Number) = b*a
copy(a::T) where {T<:StateVector} = T(a.basis, copy(a.data))
length(a::StateVector) = length(a.basis)::Int
basis(a::StateVector) = a.basis
copy(a::T) where {T<:StateVector} = T(a.basis, copy(a.data)) # FIXME issue #12
length(a::StateVector) = length(a.basis)::Int # FIXME issue #12
basis(a::StateVector) = a.basis # FIXME issue #12

Check warning on line 17 in src/julia_base.jl

View check run for this annotation

Codecov / codecov/patch

src/julia_base.jl#L15-L17

Added lines #L15 - L17 were not covered by tests
directsum(x::StateVector...) = reduce(directsum, x)
adjoint(a::StateVector) = dagger(a)

Check warning on line 19 in src/julia_base.jl

View check run for this annotation

Codecov / codecov/patch

src/julia_base.jl#L19

Added line #L19 was not covered by tests



# Array-like functions
Base.size(x::StateVector) = size(x.data)
@inline Base.axes(x::StateVector) = axes(x.data)
Base.size(x::StateVector) = size(x.data) # FIXME issue #12
@inline Base.axes(x::StateVector) = axes(x.data) # FIXME issue #12

Check warning on line 25 in src/julia_base.jl

View check run for this annotation

Codecov / codecov/patch

src/julia_base.jl#L24-L25

Added lines #L24 - L25 were not covered by tests
Base.ndims(x::StateVector) = 1
Base.ndims(::Type{<:StateVector}) = 1
Base.eltype(x::StateVector) = eltype(x.data)
Base.eltype(x::StateVector) = eltype(x.data) # FIXME issue #12

Check warning on line 28 in src/julia_base.jl

View check run for this annotation

Codecov / codecov/patch

src/julia_base.jl#L28

Added line #L28 was not covered by tests

# Broadcasting
Base.broadcastable(x::StateVector) = x

Base.adjoint(a::StateVector) = dagger(a)


##
# Operators
##

length(a::AbstractOperator) = length(a.basis_l)::Int*length(a.basis_r)::Int
basis(a::AbstractOperator) = (check_samebases(a); a.basis_l)
basis(a::AbstractSuperOperator) = (check_samebases(a); a.basis_l[1])
length(a::AbstractOperator) = length(a.basis_l)::Int*length(a.basis_r)::Int # FIXME issue #12
basis(a::AbstractOperator) = (check_samebases(a); a.basis_l) # FIXME issue #12
basis(a::AbstractSuperOperator) = (check_samebases(a); a.basis_l[1]) # FIXME issue #12

Check warning on line 39 in src/julia_base.jl

View check run for this annotation

Codecov / codecov/patch

src/julia_base.jl#L37-L39

Added lines #L37 - L39 were not covered by tests

# Ensure scalar broadcasting
Base.broadcastable(x::AbstractOperator) = Ref(x)
Expand All @@ -60,14 +62,17 @@ Operator exponential.
"""
exp(op::AbstractOperator) = throw(ArgumentError("exp() is not defined for this type of operator: $(typeof(op)).\nTry to convert to dense operator first with dense()."))

Base.size(op::AbstractOperator) = (length(op.basis_l),length(op.basis_r))
Base.size(op::AbstractOperator) = (length(op.basis_l),length(op.basis_r)) # FIXME issue #12

Check warning on line 65 in src/julia_base.jl

View check run for this annotation

Codecov / codecov/patch

src/julia_base.jl#L65

Added line #L65 was not covered by tests
function Base.size(op::AbstractOperator, i::Int)
i < 1 && throw(ErrorException("dimension index is < 1"))
i > 2 && return 1
i==1 ? length(op.basis_l) : length(op.basis_r)
i==1 ? length(op.basis_l) : length(op.basis_r) # FIXME issue #12

Check warning on line 69 in src/julia_base.jl

View check run for this annotation

Codecov / codecov/patch

src/julia_base.jl#L69

Added line #L69 was not covered by tests
end

Base.adjoint(a::AbstractOperator) = dagger(a)
adjoint(a::AbstractOperator) = dagger(a)

Check warning on line 72 in src/julia_base.jl

View check run for this annotation

Codecov / codecov/patch

src/julia_base.jl#L72

Added line #L72 was not covered by tests

transpose(a::AbstractOperator) = arithmetic_unary_error("Transpose", a)

Check warning on line 74 in src/julia_base.jl

View check run for this annotation

Codecov / codecov/patch

src/julia_base.jl#L74

Added line #L74 was not covered by tests


conj(a::AbstractOperator) = arithmetic_unary_error("Complex conjugate", a)
conj!(a::AbstractOperator) = conj(a::AbstractOperator)
2 changes: 2 additions & 0 deletions src/julia_linalg.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import LinearAlgebra: tr, ishermitian, norm, normalize, normalize!

"""
ishermitian(op::AbstractOperator)
Expand Down
10 changes: 5 additions & 5 deletions src/linalg.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
samebases(a::AbstractOperator) = samebases(a.basis_l, a.basis_r)::Bool
samebases(a::AbstractOperator, b::AbstractOperator) = samebases(a.basis_l, b.basis_l)::Bool && samebases(a.basis_r, b.basis_r)::Bool
check_samebases(a::Union{AbstractOperator, AbstractSuperOperator}) = check_samebases(a.basis_l, a.basis_r)
multiplicable(a::AbstractOperator, b::AbstractOperator) = multiplicable(a.basis_r, b.basis_l)
samebases(a::AbstractOperator) = samebases(a.basis_l, a.basis_r)::Bool # FIXME issue #12
samebases(a::AbstractOperator, b::AbstractOperator) = samebases(a.basis_l, b.basis_l)::Bool && samebases(a.basis_r, b.basis_r)::Bool # FIXME issue #12
check_samebases(a::Union{AbstractOperator, AbstractSuperOperator}) = check_samebases(a.basis_l, a.basis_r) # FIXME issue #12
multiplicable(a::AbstractOperator, b::AbstractOperator) = multiplicable(a.basis_r, b.basis_l) # FIXME issue #12

Check warning on line 4 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L1-L4

Added lines #L1 - L4 were not covered by tests
dagger(a::AbstractOperator) = arithmetic_unary_error("Hermitian conjugate", a)
transpose(a::AbstractOperator) = arithmetic_unary_error("Transpose", a)
directsum(a::AbstractOperator...) = reduce(directsum, a)
ptrace(a::AbstractOperator, index) = arithmetic_unary_error("Partial trace", a)
_index_complement(b::CompositeBasis, indices) = complement(length(b.bases), indices)
reduced(a, indices) = ptrace(a, _index_complement(basis(a), indices))
traceout!(s::StateVector, i) = ptrace(s,i)

Check warning on line 10 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L10

Added line #L10 was not covered by tests
69 changes: 69 additions & 0 deletions src/show.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import Base: show, summary

function summary(stream::IO, x::AbstractOperator)
print(stream, "$(typeof(x).name.name)(dim=$(length(x.basis_l))x$(length(x.basis_r)))\n")
if samebases(x)
print(stream, " basis: ")
show(stream, basis(x))

Check warning on line 7 in src/show.jl

View check run for this annotation

Codecov / codecov/patch

src/show.jl#L3-L7

Added lines #L3 - L7 were not covered by tests
else
print(stream, " basis left: ")
show(stream, x.basis_l)
print(stream, "\n basis right: ")
show(stream, x.basis_r)

Check warning on line 12 in src/show.jl

View check run for this annotation

Codecov / codecov/patch

src/show.jl#L9-L12

Added lines #L9 - L12 were not covered by tests
end
end

show(stream::IO, x::AbstractOperator) = summary(stream, x)

Check warning on line 16 in src/show.jl

View check run for this annotation

Codecov / codecov/patch

src/show.jl#L16

Added line #L16 was not covered by tests

function show(stream::IO, x::GenericBasis)
if length(x.shape) == 1
write(stream, "Basis(dim=$(x.shape[1]))")

Check warning on line 20 in src/show.jl

View check run for this annotation

Codecov / codecov/patch

src/show.jl#L18-L20

Added lines #L18 - L20 were not covered by tests
else
s = replace(string(x.shape), " " => "")
write(stream, "Basis(shape=$s)")

Check warning on line 23 in src/show.jl

View check run for this annotation

Codecov / codecov/patch

src/show.jl#L22-L23

Added lines #L22 - L23 were not covered by tests
end
end

function show(stream::IO, x::CompositeBasis)
write(stream, "[")
for i in 1:length(x.bases)
show(stream, x.bases[i])
if i != length(x.bases)
write(stream, "")

Check warning on line 32 in src/show.jl

View check run for this annotation

Codecov / codecov/patch

src/show.jl#L27-L32

Added lines #L27 - L32 were not covered by tests
end
end
write(stream, "]")

Check warning on line 35 in src/show.jl

View check run for this annotation

Codecov / codecov/patch

src/show.jl#L34-L35

Added lines #L34 - L35 were not covered by tests
end

function show(stream::IO, x::SpinBasis)
d = denominator(x.spinnumber)
n = numerator(x.spinnumber)
if d == 1
write(stream, "Spin($n)")

Check warning on line 42 in src/show.jl

View check run for this annotation

Codecov / codecov/patch

src/show.jl#L38-L42

Added lines #L38 - L42 were not covered by tests
else
write(stream, "Spin($n/$d)")

Check warning on line 44 in src/show.jl

View check run for this annotation

Codecov / codecov/patch

src/show.jl#L44

Added line #L44 was not covered by tests
end
end

function show(stream::IO, x::FockBasis)
if iszero(x.offset)
write(stream, "Fock(cutoff=$(x.N))")

Check warning on line 50 in src/show.jl

View check run for this annotation

Codecov / codecov/patch

src/show.jl#L48-L50

Added lines #L48 - L50 were not covered by tests
else
write(stream, "Fock(cutoff=$(x.N), offset=$(x.offset))")

Check warning on line 52 in src/show.jl

View check run for this annotation

Codecov / codecov/patch

src/show.jl#L52

Added line #L52 was not covered by tests
end
end

function show(stream::IO, x::NLevelBasis)
write(stream, "NLevel(N=$(x.N))")

Check warning on line 57 in src/show.jl

View check run for this annotation

Codecov / codecov/patch

src/show.jl#L56-L57

Added lines #L56 - L57 were not covered by tests
end

function show(stream::IO, x::SumBasis)
write(stream, "[")
for i in 1:length(x.bases)
show(stream, x.bases[i])
if i != length(x.bases)
write(stream, "")

Check warning on line 65 in src/show.jl

View check run for this annotation

Codecov / codecov/patch

src/show.jl#L60-L65

Added lines #L60 - L65 were not covered by tests
end
end
write(stream, "]")

Check warning on line 68 in src/show.jl

View check run for this annotation

Codecov / codecov/patch

src/show.jl#L67-L68

Added lines #L67 - L68 were not covered by tests
end
Loading

0 comments on commit f17ee1e

Please sign in to comment.