diff --git a/ext/TensorOperationsBumperExt.jl b/ext/TensorOperationsBumperExt.jl index 12791765..b6713666 100644 --- a/ext/TensorOperationsBumperExt.jl +++ b/ext/TensorOperationsBumperExt.jl @@ -27,16 +27,14 @@ function TensorOperations._butensor(src, ex...) buf_sym = gensym("buffer") # TODO: there is no check for doubled tensor kwargs - newex = quote - $buf_sym = $(Expr(:call, GlobalRef(Bumper, :default_buffer))) - $( - Expr( - :macrocall, GlobalRef(TensorOperations, Symbol("@tensor")), - src, :(allocator = $buf_sym), ex... - ) + return Expr( + :block, + Expr(:(=), buf_sym, Expr(:call, GlobalRef(Bumper, :default_buffer))), + Expr( + :macrocall, GlobalRef(TensorOperations, Symbol("@tensor")), + src, :(allocator = $buf_sym), ex... ) - end - return Base.remove_linenums!(newex) + ) end end diff --git a/src/indexnotation/contractiontrees.jl b/src/indexnotation/contractiontrees.jl index 77c4ccb8..e6af1a6c 100644 --- a/src/indexnotation/contractiontrees.jl +++ b/src/indexnotation/contractiontrees.jl @@ -163,7 +163,7 @@ function insertcontractiontrees!( end ) end - push!(postexprs, removelinenumbernode(costcompareex)) + push!(postexprs, removeinternallinenumbernodes(costcompareex)) return treeex end diff --git a/src/indexnotation/parser.jl b/src/indexnotation/parser.jl index 6bdd55ca..e8505564 100644 --- a/src/indexnotation/parser.jl +++ b/src/indexnotation/parser.jl @@ -12,7 +12,7 @@ mutable struct TensorParser contractiontreebuilder = defaulttreebuilder contractiontreesorter = defaulttreesorter contractioncostcheck = nothing - postprocessors = [_flatten, removelinenumbernode, addtensoroperations] + postprocessors = [_flatten, addtensoroperations] return new( preprocessors, contractiontreebuilder, contractiontreesorter, contractioncostcheck, @@ -34,6 +34,7 @@ function (parser::TensorParser)(ex::Expr) for p in parser.postprocessors ex = p(ex)::Expr end + ex = removeinternallinenumbernodes(ex)::Expr return ex end diff --git a/src/indexnotation/postprocessors.jl b/src/indexnotation/postprocessors.jl index 25b4297a..21bc613b 100644 --- a/src/indexnotation/postprocessors.jl +++ b/src/indexnotation/postprocessors.jl @@ -30,10 +30,45 @@ function _flatten(ex) end end +# package source directory (with trailing separator), used to recognize `LineNumberNode`s that +# point into the parser's own `quote` blocks rather than into user code. +const _PARSER_SRCDIR = joinpath(dirname(@__DIR__), "") + +_isinternallinenumber(@nospecialize(x)) = + x isa LineNumberNode && startswith(String(x.file), _PARSER_SRCDIR) + +""" + removeinternallinenumbernodes(ex) + +Remove all `LineNumberNode`s that point into the TensorOperations source tree, i.e. the ones +introduced by the parser's own `quote` blocks. `LineNumberNode`s originating from user code are +kept, so that the generated code remains attributable to the user's source lines (e.g. for code +coverage). +""" +function removeinternallinenumbernodes(ex) + if isexpr(ex, :block) + # within a block, `LineNumberNode`s are statement markers: drop the internal ones + args = Any[removeinternallinenumbernodes(e) for e in ex.args + if !_isinternallinenumber(e)] + return Expr(:block, args...) + elseif isa(ex, Expr) + # elsewhere (e.g. the mandatory 2nd argument of a `:macrocall`) a `LineNumberNode` may + # be structurally required, so keep all positions and only recurse into nested blocks + return Expr(ex.head, Any[removeinternallinenumbernodes(e) for e in ex.args]...) + else + return ex + end +end + """ removelinenumbernode(ex) Remove all `LineNumberNode`s from an expression. + +!!! note + Kept for backwards compatibility. The parser now uses + [`removeinternallinenumbernodes`](@ref), which preserves user `LineNumberNode`s so that + generated code stays attributable to the user's source lines (e.g. for code coverage). """ function removelinenumbernode(ex) if isexpr(ex, :block) diff --git a/test/butensor.jl b/test/butensor.jl index fefcb211..b0e4d501 100644 --- a/test/butensor.jl +++ b/test/butensor.jl @@ -6,6 +6,21 @@ end using Bumper +@testset "@butensor preserves user line numbers (issue #280)" begin + # `@butensor` wraps the block in an inner `@tensor`; make sure it does not strip the user's + # line numbers, and does not leak TensorOperations-internal ones. + pkgsrc = dirname(pathof(TensorOperations)) + lnns = LineNumberNode[] + collect_lnns(x) = x isa LineNumberNode ? push!(lnns, x) : + x isa Expr && foreach(collect_lnns, x.args) + collect_lnns(@macroexpand @butensor begin + T[a, b] := X[a, c] * Y[c, b] + Z[a, b] := T[a, c] * W[c, b] + end) + @test !any(l -> startswith(String(l.file), pkgsrc), lnns) + @test count(l -> String(l.file) == @__FILE__, lnns) >= 2 +end + @testset "Bumper tests with eltype $T" for T in (Float32, ComplexF64) D1, D2, D3 = 30, 40, 20 d1, d2 = 2, 3 diff --git a/test/macro_kwargs.jl b/test/macro_kwargs.jl index db577e49..621bc059 100644 --- a/test/macro_kwargs.jl +++ b/test/macro_kwargs.jl @@ -92,6 +92,54 @@ end end end +# https://github.com/QuantumKitHub/TensorOperations.jl/issues/280: the generated code must keep +# the user's `LineNumberNode`s (so `@tensor` lines show up in code coverage) while dropping the +# parser's own internal ones (which would otherwise pollute the package's coverage). +@testset "line numbers (issue #280)" begin + collectlinenumbernodes(ex, acc = LineNumberNode[]) = + (ex isa LineNumberNode ? push!(acc, ex) : + ex isa Expr && foreach(e -> collectlinenumbernodes(e, acc), ex.args); acc) + pkgsrc = dirname(pathof(TensorOperations)) + pkglnns(lnns) = filter(l -> startswith(String(l.file), pkgsrc), lnns) + userlines(lnns) = sort!(unique!([l.line for l in lnns if String(l.file) == @__FILE__])) + + @testset "no internal LineNumberNodes leak into generated code" begin + # covers the scalar, dst-reuse and checkpoint `quote` paths in the parser + exprs = [ + @macroexpand(@tensor T[a, b] := A[a, c] * B[c, b]), + @macroexpand(@tensor R[a, b] := A[a, c] * B[c, d] * C[d, e] * E[e, f] * F[f, b]), + @macroexpand(@tensor s = X[a, b] * Y[a, b]), + @macroexpand(@tensoropt R[a, b] := A[a, c] * B[c, d] * C[d, e] * E[e, b]), + @macroexpand(@tensor allocator = alloc R[a, b] := A[a, c] * B[c, d] * C[d, b]), + @macroexpand(@tensor costcheck = warn R[a, b] := A[a, c] * B[c, d] * C[d, b]), + @macroexpand(@tensor contractcheck = true R[a, b] := A[a, c] * B[c, b]), + ] + for ex in exprs + @test isempty(pkglnns(collectlinenumbernodes(ex))) + end + end + + @testset "user LineNumberNodes are preserved per statement" begin + # multi-statement block, including a nested contraction whose intermediate is reused + block = @macroexpand @tensor begin + T[a, e] := A[a, c] * B[c, d] * C[d, e] + D[a, b] := T[a, e] * E[e, b] + s = D[a, b] * F[a, b] + end + lnns = collectlinenumbernodes(block) + @test isempty(pkglnns(lnns)) + @test length(userlines(lnns)) >= 3 # one distinct user line per statement + + optblock = @macroexpand @tensoropt begin + T[a, e] := A[a, c] * B[c, d] * C[d, e] + D[a, b] := T[a, e] * E[e, b] + end + optlnns = collectlinenumbernodes(optblock) + @test isempty(pkglnns(optlnns)) + @test length(userlines(optlnns)) >= 2 + end +end + @testset "opt" begin A = randn(5, 5, 5, 5) B = randn(5, 5, 5)