Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions NAM/wavenet/slimmable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,45 +291,45 @@ bool is_full_size(const std::vector<wavenet::LayerArrayParams>& params, const st

} // anonymous namespace

#ifdef _LIBCPP_VERSION
#if NAM_HAS_ATOMIC_SHARED_PTR
void SlimmableWavenet::_pending_clear_release()
{
std::atomic_store_explicit(&_pending_staged, std::shared_ptr<StagedSlimModel>{}, std::memory_order_release);
_pending_staged.store({}, std::memory_order_release);
}

std::shared_ptr<SlimmableWavenet::StagedSlimModel> SlimmableWavenet::_pending_load_acquire() const
{
return std::atomic_load_explicit(&_pending_staged, std::memory_order_acquire);
return _pending_staged.load(std::memory_order_acquire);
}

void SlimmableWavenet::_pending_store_release(std::shared_ptr<StagedSlimModel> p)
{
std::atomic_store_explicit(&_pending_staged, std::move(p), std::memory_order_release);
_pending_staged.store(std::move(p), std::memory_order_release);
}

std::shared_ptr<SlimmableWavenet::StagedSlimModel> SlimmableWavenet::_pending_exchange_take_acq_rel()
{
return std::atomic_exchange_explicit(&_pending_staged, std::shared_ptr<StagedSlimModel>{}, std::memory_order_acq_rel);
return _pending_staged.exchange({}, std::memory_order_acq_rel);
}
#else
void SlimmableWavenet::_pending_clear_release()
{
_pending_staged.store({}, std::memory_order_release);
std::atomic_store_explicit(&_pending_staged, std::shared_ptr<StagedSlimModel>{}, std::memory_order_release);
}

std::shared_ptr<SlimmableWavenet::StagedSlimModel> SlimmableWavenet::_pending_load_acquire() const
{
return _pending_staged.load(std::memory_order_acquire);
return std::atomic_load_explicit(&_pending_staged, std::memory_order_acquire);
}

void SlimmableWavenet::_pending_store_release(std::shared_ptr<StagedSlimModel> p)
{
_pending_staged.store(std::move(p), std::memory_order_release);
std::atomic_store_explicit(&_pending_staged, std::move(p), std::memory_order_release);
}

std::shared_ptr<SlimmableWavenet::StagedSlimModel> SlimmableWavenet::_pending_exchange_take_acq_rel()
{
return _pending_staged.exchange({}, std::memory_order_acq_rel);
return std::atomic_exchange_explicit(&_pending_staged, std::shared_ptr<StagedSlimModel>{}, std::memory_order_acq_rel);
}
#endif

Expand Down
20 changes: 15 additions & 5 deletions NAM/wavenet/slimmable.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,19 @@
#include <memory>
#include <vector>

#ifdef _LIBCPP_VERSION
// libc++: std::atomic<std::shared_ptr<T>> is not viable; staging uses deprecated atomic_* free functions.
// std::atomic<std::shared_ptr<T>> requires C++20 library support (libstdc++ >= GCC 12).
// Where it is unavailable -- libc++ (any version so far) and older libstdc++ such as the
// one shipped with GCC 9 -- fall back to the deprecated std::atomic_* free-function
// overloads for shared_ptr, which provide the same acquire/release semantics. Keyed on the
// C++20 feature-test macro rather than on a specific standard library so the right path is
// chosen for every compiler.
#if defined(__cpp_lib_atomic_shared_ptr) && __cpp_lib_atomic_shared_ptr >= 201711L
#define NAM_HAS_ATOMIC_SHARED_PTR 1
#else
#define NAM_HAS_ATOMIC_SHARED_PTR 0
#endif

#if NAM_HAS_ATOMIC_SHARED_PTR
#include <atomic>
#endif

Expand Down Expand Up @@ -71,11 +81,11 @@ class SlimmableWavenet : public DSP, public SlimmableModel
std::shared_ptr<DSP> model;
std::vector<int> channels;
};
#ifdef _LIBCPP_VERSION
#if NAM_HAS_ATOMIC_SHARED_PTR
std::atomic<std::shared_ptr<StagedSlimModel>> _pending_staged;
#else
/// Staged model; synchronized via deprecated std::atomic_* overloads for shared_ptr only.
std::shared_ptr<StagedSlimModel> _pending_staged;
#else
std::atomic<std::shared_ptr<StagedSlimModel>> _pending_staged;
#endif

std::vector<int> _current_channels;
Expand Down
Loading