Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,45 @@ end

args = parse_args(ARGS)

# Optional: spread test workers across all available GPUs, one worker per device
# (round-robin), by pinning each worker *process* to a device with ZE_AFFINITY_MASK.
# `device()` is task-local and Malt runs each test in a fresh task, so a `device!` in
# `init_worker_code` would not stick — pinning the process via the driver is the robust
# way to make every task on a worker use the same GPU.
#
# Enabled with ONEAPI_TEST_SPREAD_GPUS=1. When unset (the default) every worker stays on
# the first device, which oversubscribes a single tile — useful for surfacing
# contention/oversubscription bugs.
const spread_gpus = lowercase(get(ENV, "ONEAPI_TEST_SPREAD_GPUS", "")) in ("1", "true", "yes")
worker_env = Vector{Pair{String, String}}()
device_claim_code = :()
if spread_gpus
ndev = length(oneAPI.devices())
# shared, node-local directory used as an atomic round-robin counter (mkdir is atomic)
devdir = mktempdir(; prefix = "oneapi_test_gpus_")
push!(worker_env, "ONEAPI_TEST_DEVDIR" => devdir)
push!(worker_env, "ONEAPI_TEST_NDEV" => string(ndev))
@info "Spreading test workers across $ndev GPU(s) via ZE_AFFINITY_MASK (ONEAPI_TEST_SPREAD_GPUS=1)"
# NOTE: runs on the worker as the very first thing, before `using oneAPI` — so the
# Level Zero driver picks up ZE_AFFINITY_MASK at init and the process sees only its tile.
device_claim_code = quote
let dir = ENV["ONEAPI_TEST_DEVDIR"], ndev = parse(Int, ENV["ONEAPI_TEST_NDEV"])
i = 0
while true
try
mkdir(joinpath(dir, string(i)))
break
catch
i += 1
end
end
ENV["ZE_AFFINITY_MASK"] = string(i % ndev)
end
end
end

init_worker_code = quote
$device_claim_code
using oneAPI, Adapt

import GPUArrays
Expand Down Expand Up @@ -105,4 +143,4 @@ init_code = quote
..@grab_output, ..@on_device, ..sink
end

runtests(oneAPI, args; testsuite, init_code, init_worker_code)
runtests(oneAPI, args; testsuite, init_code, init_worker_code, env = worker_env)
Loading