From 333527c8a66f53aa2b62123d814fb0cf2b9fc980 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 19 May 2026 16:35:05 +0200 Subject: [PATCH 1/9] Forward rules for vector interface --- ext/TensorKitMooncakeExt/vectorinterface.jl | 67 +++++++++++++++++++-- test/mooncake/vectorinterface.jl | 34 +++++------ 2 files changed, 79 insertions(+), 22 deletions(-) diff --git a/ext/TensorKitMooncakeExt/vectorinterface.jl b/ext/TensorKitMooncakeExt/vectorinterface.jl index a6f2db85f..4bbf09d12 100644 --- a/ext/TensorKitMooncakeExt/vectorinterface.jl +++ b/ext/TensorKitMooncakeExt/vectorinterface.jl @@ -1,4 +1,4 @@ -@is_primitive DefaultCtx ReverseMode Tuple{typeof(scale!), AbstractTensorMap, Number} +@is_primitive DefaultCtx Tuple{typeof(scale!), AbstractTensorMap, Number} function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractTensorMap}, α_Δα::CoDual{<:Number}) # prepare arguments @@ -19,7 +19,22 @@ function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractTens return C_ΔC, scale_pullback end -@is_primitive DefaultCtx ReverseMode Tuple{typeof(scale!), AbstractTensorMap, AbstractTensorMap, Number} +function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual{<:AbstractTensorMap}, α_Δα::Dual{<:Number}) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + α, Δα = Mooncake.extract(α_Δα) + + if !isa(Δα, Mooncake.NoTangent) + add!(ΔC, C, Δα, α) + else + scale!(ΔC, α) + end + scale!(C, α) + + return C_ΔC +end + +@is_primitive DefaultCtx Tuple{typeof(scale!), AbstractTensorMap, AbstractTensorMap, Number} function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractTensorMap}, A_ΔA::CoDual{<:AbstractTensorMap}, α_Δα::CoDual{<:Number}) # prepare arguments @@ -42,7 +57,21 @@ function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractTens return C_ΔC, scale_pullback end -@is_primitive DefaultCtx ReverseMode Tuple{typeof(add!), AbstractTensorMap, AbstractTensorMap, Number, Number} +function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual{<:AbstractTensorMap}, A_ΔA::Dual{<:AbstractTensorMap}, α_Δα::Dual{<:Number}) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + A, ΔA = arrayify(A_ΔA) + α, Δα = Mooncake.extract(α_Δα) + + scale!(ΔC, ΔA, α) + if !isa(Δα, Mooncake.NoTangent) + add!(ΔC, A, Δα, One()) + end + scale!(C, A, α) + return C_ΔC +end + +@is_primitive DefaultCtx Tuple{typeof(add!), AbstractTensorMap, AbstractTensorMap, Number, Number} function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual{<:AbstractTensorMap}, A_ΔA::CoDual{<:AbstractTensorMap}, α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}) # prepare arguments @@ -69,7 +98,26 @@ function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual{<:AbstractTensor return C_ΔC, add_pullback end -@is_primitive DefaultCtx ReverseMode Tuple{typeof(inner), AbstractTensorMap, AbstractTensorMap} +function Mooncake.frule!!(::Dual{typeof(add!)}, C_ΔC::Dual{<:AbstractTensorMap}, A_ΔA::Dual{<:AbstractTensorMap}, α_Δα::Dual{<:Number}, β_Δβ::Dual{<:Number}) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + A, ΔA = arrayify(A_ΔA) + α, Δα = Mooncake.extract(α_Δα) + β, Δβ = Mooncake.extract(β_Δβ) + add!(ΔC, ΔA, α, β) + if isa(Δβ, Mooncake.NoTangent) && !isa(Δα, Mooncake.NoTangent) + add!(ΔC, A, Δα, One()) + elseif isa(Δα, Mooncake.NoTangent) && !isa(Δβ, Mooncake.NoTangent) + add!(ΔC, C, Δβ, One()) + elseif !isa(Δα, Mooncake.NoTangent) && !isa(Δβ, Mooncake.NoTangent) + add!(ΔC, A, Δα, One()) + add!(ΔC, C, Δβ, One()) + end + add!(C, A, α, β) + return C_ΔC +end + +@is_primitive DefaultCtx Tuple{typeof(inner), AbstractTensorMap, AbstractTensorMap} function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual{<:AbstractTensorMap}, B_ΔB::CoDual{<:AbstractTensorMap}) # prepare arguments @@ -87,3 +135,14 @@ function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual{<:AbstractTenso return CoDual(s, NoFData()), inner_pullback end + +function Mooncake.frule!!(::Dual{typeof(inner)}, A_ΔA::Dual{<:AbstractTensorMap}, B_ΔB::Dual{<:AbstractTensorMap}) + # prepare arguments + A, ΔA = arrayify(A_ΔA) + B, ΔB = arrayify(B_ΔB) + + s = inner(A, B) + Δs = inner(A, ΔB) + inner(ΔA, B) + + return Dual(s, Δs) +end diff --git a/test/mooncake/vectorinterface.jl b/test/mooncake/vectorinterface.jl index 5d10101da..7ffbf0735 100644 --- a/test/mooncake/vectorinterface.jl +++ b/test/mooncake/vectorinterface.jl @@ -3,9 +3,8 @@ using TensorKit using TensorOperations using Mooncake using Random +using VectorInterface - -mode = Mooncake.ReverseMode rng = Random.default_rng() spacelist = ad_spacelist(fast_tests) @@ -17,20 +16,19 @@ eltypes = (Float64, ComplexF64) C = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])') A = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])') - α = randn(T) - β = randn(T) - - Mooncake.TestUtils.test_rule(rng, scale!, C, α; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, scale!, C', α; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, scale!, C, A, α; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, scale!, C', A', α; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, scale!, copy(C'), A', α; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, scale!, C', copy(A'), α; atol, rtol, mode) - - Mooncake.TestUtils.test_rule(rng, add!, C, A; atol, rtol, mode, is_primitive = false) - Mooncake.TestUtils.test_rule(rng, add!, C, A, α; atol, rtol, mode, is_primitive = false) - Mooncake.TestUtils.test_rule(rng, add!, C, A, α, β; atol, rtol, mode) - - Mooncake.TestUtils.test_rule(rng, inner, C, A; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, inner, C', A'; atol, rtol, mode) + for α in (randn(T), One(), Zero()), β in (randn(T), One(), Zero()) + Mooncake.TestUtils.test_rule(rng, scale!, C, α; atol, rtol) + Mooncake.TestUtils.test_rule(rng, scale!, C', α; atol, rtol) + Mooncake.TestUtils.test_rule(rng, scale!, C, A, α; atol, rtol) + Mooncake.TestUtils.test_rule(rng, scale!, C', A', α; atol, rtol) + Mooncake.TestUtils.test_rule(rng, scale!, copy(C'), A', α; atol, rtol) + Mooncake.TestUtils.test_rule(rng, scale!, C', copy(A'), α; atol, rtol) + + Mooncake.TestUtils.test_rule(rng, add!, C, A; atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, add!, C, A, α; atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, add!, C, A, α, β; atol, rtol) + + Mooncake.TestUtils.test_rule(rng, inner, C, A; atol, rtol) + Mooncake.TestUtils.test_rule(rng, inner, C', A'; atol, rtol) + end end From 00b4c8c5e0b0f9b50a80b372750274a44b511876 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 21 May 2026 07:10:58 +0200 Subject: [PATCH 2/9] Try falling back to new VI rules --- Project.toml | 3 + .../TensorKitMooncakeExt.jl | 1 - ext/TensorKitMooncakeExt/vectorinterface.jl | 148 ------------------ 3 files changed, 3 insertions(+), 149 deletions(-) delete mode 100644 ext/TensorKitMooncakeExt/vectorinterface.jl diff --git a/Project.toml b/Project.toml index 3d0abc9b1..b9bd4180f 100644 --- a/Project.toml +++ b/Project.toml @@ -37,6 +37,9 @@ TensorKitMooncakeExt = "Mooncake" [workspace] projects = ["test", "docs"] +[sources] +VectorInterface = {url = "https://github.com/Jutho/VectorInterface.jl", rev = "main"} + [compat] Adapt = "4" AMDGPU = "2" diff --git a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl index 7c0492239..f50101c96 100644 --- a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl +++ b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl @@ -17,7 +17,6 @@ include("utility.jl") include("tangent.jl") include("linalg.jl") include("indexmanipulations.jl") -include("vectorinterface.jl") include("tensoroperations.jl") include("planaroperations.jl") include("factorizations.jl") diff --git a/ext/TensorKitMooncakeExt/vectorinterface.jl b/ext/TensorKitMooncakeExt/vectorinterface.jl deleted file mode 100644 index 4bbf09d12..000000000 --- a/ext/TensorKitMooncakeExt/vectorinterface.jl +++ /dev/null @@ -1,148 +0,0 @@ -@is_primitive DefaultCtx Tuple{typeof(scale!), AbstractTensorMap, Number} - -function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractTensorMap}, α_Δα::CoDual{<:Number}) - # prepare arguments - C, ΔC = arrayify(C_ΔC) - α = primal(α_Δα) - - # primal call - C_cache = copy(C) - scale!(C, α) - - function scale_pullback(::NoRData) - copy!(C, C_cache) - Δαr = _needs_tangent(α) ? project_scalar(α, inner(C, ΔC)) : NoRData() - scale!(ΔC, conj(α)) - return NoRData(), NoRData(), Δαr - end - - return C_ΔC, scale_pullback -end - -function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual{<:AbstractTensorMap}, α_Δα::Dual{<:Number}) - # prepare arguments - C, ΔC = arrayify(C_ΔC) - α, Δα = Mooncake.extract(α_Δα) - - if !isa(Δα, Mooncake.NoTangent) - add!(ΔC, C, Δα, α) - else - scale!(ΔC, α) - end - scale!(C, α) - - return C_ΔC -end - -@is_primitive DefaultCtx Tuple{typeof(scale!), AbstractTensorMap, AbstractTensorMap, Number} - -function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractTensorMap}, A_ΔA::CoDual{<:AbstractTensorMap}, α_Δα::CoDual{<:Number}) - # prepare arguments - C, ΔC = arrayify(C_ΔC) - A, ΔA = arrayify(A_ΔA) - α = primal(α_Δα) - - # primal call - C_cache = copy(C) - scale!(C, A, α) - - function scale_pullback(::NoRData) - copy!(C, C_cache) - add!(ΔA, ΔC, conj(α)) - Δαr = _needs_tangent(α) ? project_scalar(α, inner(A, ΔC)) : NoRData() - zerovector!(ΔC) - return NoRData(), NoRData(), NoRData(), Δαr - end - - return C_ΔC, scale_pullback -end - -function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual{<:AbstractTensorMap}, A_ΔA::Dual{<:AbstractTensorMap}, α_Δα::Dual{<:Number}) - # prepare arguments - C, ΔC = arrayify(C_ΔC) - A, ΔA = arrayify(A_ΔA) - α, Δα = Mooncake.extract(α_Δα) - - scale!(ΔC, ΔA, α) - if !isa(Δα, Mooncake.NoTangent) - add!(ΔC, A, Δα, One()) - end - scale!(C, A, α) - return C_ΔC -end - -@is_primitive DefaultCtx Tuple{typeof(add!), AbstractTensorMap, AbstractTensorMap, Number, Number} - -function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual{<:AbstractTensorMap}, A_ΔA::CoDual{<:AbstractTensorMap}, α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}) - # prepare arguments - C, ΔC = arrayify(C_ΔC) - A, ΔA = arrayify(A_ΔA) - α = primal(α_Δα) - β = primal(β_Δβ) - - # primal call - C_cache = copy(C) - add!(C, A, α, β) - - function add_pullback(::NoRData) - copy!(C, C_cache) - - Δαr = _needs_tangent(α) ? project_scalar(α, inner(A, ΔC)) : NoRData() - Δβr = _needs_tangent(β) ? project_scalar(β, inner(C, ΔC)) : NoRData() - add!(ΔA, ΔC, conj(α)) - scale!(ΔC, conj(β)) - - return NoRData(), NoRData(), NoRData(), Δαr, Δβr - end - - return C_ΔC, add_pullback -end - -function Mooncake.frule!!(::Dual{typeof(add!)}, C_ΔC::Dual{<:AbstractTensorMap}, A_ΔA::Dual{<:AbstractTensorMap}, α_Δα::Dual{<:Number}, β_Δβ::Dual{<:Number}) - # prepare arguments - C, ΔC = arrayify(C_ΔC) - A, ΔA = arrayify(A_ΔA) - α, Δα = Mooncake.extract(α_Δα) - β, Δβ = Mooncake.extract(β_Δβ) - add!(ΔC, ΔA, α, β) - if isa(Δβ, Mooncake.NoTangent) && !isa(Δα, Mooncake.NoTangent) - add!(ΔC, A, Δα, One()) - elseif isa(Δα, Mooncake.NoTangent) && !isa(Δβ, Mooncake.NoTangent) - add!(ΔC, C, Δβ, One()) - elseif !isa(Δα, Mooncake.NoTangent) && !isa(Δβ, Mooncake.NoTangent) - add!(ΔC, A, Δα, One()) - add!(ΔC, C, Δβ, One()) - end - add!(C, A, α, β) - return C_ΔC -end - -@is_primitive DefaultCtx Tuple{typeof(inner), AbstractTensorMap, AbstractTensorMap} - -function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual{<:AbstractTensorMap}, B_ΔB::CoDual{<:AbstractTensorMap}) - # prepare arguments - A, ΔA = arrayify(A_ΔA) - B, ΔB = arrayify(B_ΔB) - - # primal call - s = inner(A, B) - - function inner_pullback(Δs) - add!(ΔA, B, conj(Δs)) - add!(ΔB, A, Δs) - return NoRData(), NoRData(), NoRData() - end - - return CoDual(s, NoFData()), inner_pullback -end - -function Mooncake.frule!!(::Dual{typeof(inner)}, A_ΔA::Dual{<:AbstractTensorMap}, B_ΔB::Dual{<:AbstractTensorMap}) - # prepare arguments - A, ΔA = arrayify(A_ΔA) - B, ΔB = arrayify(B_ΔB) - - s = inner(A, B) - Δs = inner(A, ΔB) + inner(ΔA, B) - - return Dual(s, Δs) -end From f00482098bfc680cf5ca39379a2e362aee1cbd87 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 21 May 2026 11:25:51 +0200 Subject: [PATCH 3/9] Fix tangent typo --- ext/TensorKitMooncakeExt/tangent.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/TensorKitMooncakeExt/tangent.jl b/ext/TensorKitMooncakeExt/tangent.jl index 151244b94..ac63fe8c3 100644 --- a/ext/TensorKitMooncakeExt/tangent.jl +++ b/ext/TensorKitMooncakeExt/tangent.jl @@ -191,7 +191,7 @@ _field_symbol(t, ::Val{F}) where {F} = _field_symbol(t, F) # frules _frule_getfield_common(t_dt::Dual{<:DiagOrTensorMap}, field_sym::Symbol) = - Dual(getfield(primal(t), field_sym), field_sym === :data ? tangent(t).data : NoFData()) + Dual(getfield(primal(t_dt), field_sym), field_sym === :data ? tangent(t_dt).data : NoFData()) Mooncake.frule!!(::Dual{typeof(Mooncake.lgetfield)}, t_dt::Dual{<:DiagOrTensorMap}, f_df::Dual) = _frule_getfield_common(t_dt, _field_symbol(primal(t_dt), primal(f_df))) From 5f7c7717555736808c50e0609399258f9ba0d9ea Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 21 May 2026 14:49:25 +0200 Subject: [PATCH 4/9] Some small fixes --- Project.toml | 2 +- ext/TensorKitMooncakeExt/tangent.jl | 2 +- ext/TensorKitMooncakeExt/utility.jl | 1 + test/mooncake/vectorinterface.jl | 21 ++++++++++----------- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/Project.toml b/Project.toml index b9bd4180f..e71934cca 100644 --- a/Project.toml +++ b/Project.toml @@ -38,7 +38,7 @@ TensorKitMooncakeExt = "Mooncake" projects = ["test", "docs"] [sources] -VectorInterface = {url = "https://github.com/Jutho/VectorInterface.jl", rev = "main"} +VectorInterface = {url = "https://github.com/kshyatt/VectorInterface.jl", rev = "ksh/mooncake_loosen"} [compat] Adapt = "4" diff --git a/ext/TensorKitMooncakeExt/tangent.jl b/ext/TensorKitMooncakeExt/tangent.jl index ac63fe8c3..776591248 100644 --- a/ext/TensorKitMooncakeExt/tangent.jl +++ b/ext/TensorKitMooncakeExt/tangent.jl @@ -191,7 +191,7 @@ _field_symbol(t, ::Val{F}) where {F} = _field_symbol(t, F) # frules _frule_getfield_common(t_dt::Dual{<:DiagOrTensorMap}, field_sym::Symbol) = - Dual(getfield(primal(t_dt), field_sym), field_sym === :data ? tangent(t_dt).data : NoFData()) + Dual(getfield(primal(t_dt), field_sym), field_sym === :data ? tangent(t_dt).data : NoTangent()) Mooncake.frule!!(::Dual{typeof(Mooncake.lgetfield)}, t_dt::Dual{<:DiagOrTensorMap}, f_df::Dual) = _frule_getfield_common(t_dt, _field_symbol(primal(t_dt), primal(f_df))) diff --git a/ext/TensorKitMooncakeExt/utility.jl b/ext/TensorKitMooncakeExt/utility.jl index ceb32d867..64ad6520d 100644 --- a/ext/TensorKitMooncakeExt/utility.jl +++ b/ext/TensorKitMooncakeExt/utility.jl @@ -67,6 +67,7 @@ Mooncake.tangent_type(::Type{<:HomSpace}) = Mooncake.NoTangent @zero_derivative DefaultCtx Tuple{typeof(TensorKit.select), HomSpace, Index2Tuple} @zero_derivative DefaultCtx Tuple{typeof(TensorKit.flip), HomSpace, Any} +@zero_derivative DefaultCtx Tuple{typeof(TensorKit.adjoint), HomSpace} @zero_derivative DefaultCtx Tuple{typeof(TensorKit.permute), HomSpace, Index2Tuple} @zero_derivative DefaultCtx Tuple{typeof(TensorKit.braid), HomSpace, Index2Tuple, IndexTuple} @zero_derivative DefaultCtx Tuple{typeof(TensorKit.compose), HomSpace, HomSpace} diff --git a/test/mooncake/vectorinterface.jl b/test/mooncake/vectorinterface.jl index 7ffbf0735..900acbae6 100644 --- a/test/mooncake/vectorinterface.jl +++ b/test/mooncake/vectorinterface.jl @@ -1,9 +1,8 @@ using Test, TestExtras using TensorKit using TensorOperations -using Mooncake +using VectorInterface, Mooncake using Random -using VectorInterface rng = Random.default_rng() @@ -17,18 +16,18 @@ eltypes = (Float64, ComplexF64) C = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])') A = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])') for α in (randn(T), One(), Zero()), β in (randn(T), One(), Zero()) - Mooncake.TestUtils.test_rule(rng, scale!, C, α; atol, rtol) - Mooncake.TestUtils.test_rule(rng, scale!, C', α; atol, rtol) - Mooncake.TestUtils.test_rule(rng, scale!, C, A, α; atol, rtol) - Mooncake.TestUtils.test_rule(rng, scale!, C', A', α; atol, rtol) - Mooncake.TestUtils.test_rule(rng, scale!, copy(C'), A', α; atol, rtol) - Mooncake.TestUtils.test_rule(rng, scale!, C', copy(A'), α; atol, rtol) + Mooncake.TestUtils.test_rule(rng, scale!, C, α; atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, scale!, C', α; atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, scale!, C, A, α; atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, scale!, C', A', α; atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, scale!, copy(C'), A', α; atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, scale!, C', copy(A'), α; atol, rtol, is_primitive = false) Mooncake.TestUtils.test_rule(rng, add!, C, A; atol, rtol, is_primitive = false) Mooncake.TestUtils.test_rule(rng, add!, C, A, α; atol, rtol, is_primitive = false) - Mooncake.TestUtils.test_rule(rng, add!, C, A, α, β; atol, rtol) + Mooncake.TestUtils.test_rule(rng, add!, C, A, α, β; atol, rtol, is_primitive = false) - Mooncake.TestUtils.test_rule(rng, inner, C, A; atol, rtol) - Mooncake.TestUtils.test_rule(rng, inner, C', A'; atol, rtol) + Mooncake.TestUtils.test_rule(rng, inner, C, A; atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, inner, C', A'; atol, rtol, is_primitive = false) end end From 2370983bfb8c0d5ad9de650b1fcda1e50d1ee1ba Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 21 May 2026 15:42:39 +0200 Subject: [PATCH 5/9] Restore primitive markers --- ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl | 1 + ext/TensorKitMooncakeExt/vectorinterface.jl | 4 ++++ 2 files changed, 5 insertions(+) create mode 100644 ext/TensorKitMooncakeExt/vectorinterface.jl diff --git a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl index f50101c96..7c0492239 100644 --- a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl +++ b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl @@ -17,6 +17,7 @@ include("utility.jl") include("tangent.jl") include("linalg.jl") include("indexmanipulations.jl") +include("vectorinterface.jl") include("tensoroperations.jl") include("planaroperations.jl") include("factorizations.jl") diff --git a/ext/TensorKitMooncakeExt/vectorinterface.jl b/ext/TensorKitMooncakeExt/vectorinterface.jl new file mode 100644 index 000000000..260f32a01 --- /dev/null +++ b/ext/TensorKitMooncakeExt/vectorinterface.jl @@ -0,0 +1,4 @@ +@is_primitive DefaultCtx Tuple{typeof(scale!), AbstractTensorMap, Number} +@is_primitive DefaultCtx Tuple{typeof(scale!), AbstractTensorMap, AbstractTensorMap, Number} +@is_primitive DefaultCtx Tuple{typeof(add!), AbstractTensorMap, AbstractTensorMap, Number, Number} +@is_primitive DefaultCtx Tuple{typeof(inner), AbstractTensorMap, AbstractTensorMap} From bcbc1b3f285918e41a2c1000602fc7f33a4a1efe Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 21 May 2026 18:22:54 +0200 Subject: [PATCH 6/9] Update sources --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e71934cca..b3923e8c9 100644 --- a/Project.toml +++ b/Project.toml @@ -38,7 +38,7 @@ TensorKitMooncakeExt = "Mooncake" projects = ["test", "docs"] [sources] -VectorInterface = {url = "https://github.com/kshyatt/VectorInterface.jl", rev = "ksh/mooncake_loosen"} +VectorInterface = {url = "https://github.com/QuantumKitHub/VectorInterface.jl", rev = "main"} [compat] Adapt = "4" From 61df722c89966b338e6f5c53f512712433bab5d7 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 1 Jun 2026 09:35:15 +0200 Subject: [PATCH 7/9] Test stuff is primitive when it should be --- test/mooncake/vectorinterface.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/test/mooncake/vectorinterface.jl b/test/mooncake/vectorinterface.jl index 900acbae6..e6b910610 100644 --- a/test/mooncake/vectorinterface.jl +++ b/test/mooncake/vectorinterface.jl @@ -16,18 +16,18 @@ eltypes = (Float64, ComplexF64) C = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])') A = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])') for α in (randn(T), One(), Zero()), β in (randn(T), One(), Zero()) - Mooncake.TestUtils.test_rule(rng, scale!, C, α; atol, rtol, is_primitive = false) - Mooncake.TestUtils.test_rule(rng, scale!, C', α; atol, rtol, is_primitive = false) - Mooncake.TestUtils.test_rule(rng, scale!, C, A, α; atol, rtol, is_primitive = false) - Mooncake.TestUtils.test_rule(rng, scale!, C', A', α; atol, rtol, is_primitive = false) - Mooncake.TestUtils.test_rule(rng, scale!, copy(C'), A', α; atol, rtol, is_primitive = false) - Mooncake.TestUtils.test_rule(rng, scale!, C', copy(A'), α; atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, scale!, C, α; atol, rtol) + Mooncake.TestUtils.test_rule(rng, scale!, C', α; atol, rtol) + Mooncake.TestUtils.test_rule(rng, scale!, C, A, α; atol, rtol) + Mooncake.TestUtils.test_rule(rng, scale!, C', A', α; atol, rtol) + Mooncake.TestUtils.test_rule(rng, scale!, copy(C'), A', α; atol, rtol) + Mooncake.TestUtils.test_rule(rng, scale!, C', copy(A'), α; atol, rtol) Mooncake.TestUtils.test_rule(rng, add!, C, A; atol, rtol, is_primitive = false) Mooncake.TestUtils.test_rule(rng, add!, C, A, α; atol, rtol, is_primitive = false) - Mooncake.TestUtils.test_rule(rng, add!, C, A, α, β; atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, add!, C, A, α, β; atol, rtol) - Mooncake.TestUtils.test_rule(rng, inner, C, A; atol, rtol, is_primitive = false) - Mooncake.TestUtils.test_rule(rng, inner, C', A'; atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, inner, C, A; atol, rtol) + Mooncake.TestUtils.test_rule(rng, inner, C', A'; atol, rtol) end end From 2f79c694c22c1017a720d7155a51bdaac1176220 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 1 Jun 2026 09:40:12 +0200 Subject: [PATCH 8/9] Fix Project.toml now that VI is tagged --- Project.toml | 3 --- 1 file changed, 3 deletions(-) diff --git a/Project.toml b/Project.toml index b3923e8c9..3d0abc9b1 100644 --- a/Project.toml +++ b/Project.toml @@ -37,9 +37,6 @@ TensorKitMooncakeExt = "Mooncake" [workspace] projects = ["test", "docs"] -[sources] -VectorInterface = {url = "https://github.com/QuantumKitHub/VectorInterface.jl", rev = "main"} - [compat] Adapt = "4" AMDGPU = "2" From cc36f2604a147bb7edc1bb73e5f60b943cb6db0e Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 16 Jun 2026 09:26:02 +0200 Subject: [PATCH 9/9] Add some adjoint tests for add --- test/mooncake/vectorinterface.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/mooncake/vectorinterface.jl b/test/mooncake/vectorinterface.jl index e6b910610..71fce2358 100644 --- a/test/mooncake/vectorinterface.jl +++ b/test/mooncake/vectorinterface.jl @@ -26,6 +26,9 @@ eltypes = (Float64, ComplexF64) Mooncake.TestUtils.test_rule(rng, add!, C, A; atol, rtol, is_primitive = false) Mooncake.TestUtils.test_rule(rng, add!, C, A, α; atol, rtol, is_primitive = false) Mooncake.TestUtils.test_rule(rng, add!, C, A, α, β; atol, rtol) + Mooncake.TestUtils.test_rule(rng, add!, C', A'; atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, add!, copy(C'), A', α; atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, add!, C', copy(A'), α, β; atol, rtol) Mooncake.TestUtils.test_rule(rng, inner, C, A; atol, rtol) Mooncake.TestUtils.test_rule(rng, inner, C', A'; atol, rtol)