diff --git a/src/common/ChannelChatters.cpp b/src/common/ChannelChatters.cpp index c69144d4e..c3ae07235 100644 --- a/src/common/ChannelChatters.cpp +++ b/src/common/ChannelChatters.cpp @@ -11,7 +11,7 @@ ChannelChatters::ChannelChatters(Channel &channel) { } -AccessGuard ChannelChatters::accessChatters() const +SharedAccessGuard ChannelChatters::accessChatters() const { return this->chatters_.accessConst(); } diff --git a/src/common/ChannelChatters.hpp b/src/common/ChannelChatters.hpp index c873f8559..f45cac8a8 100644 --- a/src/common/ChannelChatters.hpp +++ b/src/common/ChannelChatters.hpp @@ -15,7 +15,7 @@ public: ChannelChatters(Channel &channel); virtual ~ChannelChatters() = default; // add vtable - AccessGuard accessChatters() const; + SharedAccessGuard accessChatters() const; void addRecentChatter(const QString &user); void addJoinedUser(const QString &user); diff --git a/src/common/UniqueAccess.hpp b/src/common/UniqueAccess.hpp index 1b9685c2e..71d2525a3 100644 --- a/src/common/UniqueAccess.hpp +++ b/src/common/UniqueAccess.hpp @@ -1,39 +1,33 @@ #pragma once #include +#include #include namespace chatterino { -template +template > class AccessGuard { public: - AccessGuard(T &element, std::mutex &mutex) + AccessGuard(T &element, std::shared_mutex &mutex) : element_(&element) - , mutex_(&mutex) + , lock_(mutex) { - this->mutex_->lock(); } - AccessGuard(AccessGuard &&other) + AccessGuard(AccessGuard &&other) : element_(other.element_) - , mutex_(other.mutex_) + , lock_(std::move(other.lock_)) { - other.isValid_ = false; } - AccessGuard &operator=(AccessGuard &&other) + AccessGuard &operator=(AccessGuard &&other) { - other.isValid_ = false; this->element_ = other.element_; - this->mutex_ = other.element_; - } + this->lock_ = std::move(other.lock_); - ~AccessGuard() - { - if (this->isValid_) - this->mutex_->unlock(); + return *this; } T *operator->() const @@ -48,10 +42,13 @@ public: private: T *element_{}; - std::mutex *mutex_{}; - bool isValid_{true}; + LockType lock_; }; +template +using SharedAccessGuard = + AccessGuard>; + template class UniqueAccess { @@ -90,14 +87,14 @@ public: template ::value>> - AccessGuard accessConst() const + SharedAccessGuard accessConst() const { - return AccessGuard(this->element_, this->mutex_); + return SharedAccessGuard(this->element_, this->mutex_); } private: mutable T element_; - mutable std::mutex mutex_; + mutable std::shared_mutex mutex_; }; } // namespace chatterino diff --git a/src/providers/twitch/TwitchAccount.cpp b/src/providers/twitch/TwitchAccount.cpp index ff6d23859..909eafcd8 100644 --- a/src/providers/twitch/TwitchAccount.cpp +++ b/src/providers/twitch/TwitchAccount.cpp @@ -176,12 +176,14 @@ void TwitchAccount::checkFollow(const QString targetUserID, [] {}); } -AccessGuard> TwitchAccount::accessBlocks() const +SharedAccessGuard> TwitchAccount::accessBlocks() + const { return this->ignores_.accessConst(); } -AccessGuard> TwitchAccount::accessBlockedUserIds() const +SharedAccessGuard> TwitchAccount::accessBlockedUserIds() + const { return this->ignoresUserIds_.accessConst(); } @@ -345,7 +347,7 @@ void TwitchAccount::loadUserstateEmotes(QStringList emoteSetKeys) }); } -AccessGuard +SharedAccessGuard TwitchAccount::accessEmotes() const { return this->emotes_.accessConst(); diff --git a/src/providers/twitch/TwitchAccount.hpp b/src/providers/twitch/TwitchAccount.hpp index cde25ddc6..2ae80d0ee 100644 --- a/src/providers/twitch/TwitchAccount.hpp +++ b/src/providers/twitch/TwitchAccount.hpp @@ -106,12 +106,12 @@ public: void checkFollow(const QString targetUserID, std::function onFinished); - AccessGuard> accessBlockedUserIds() const; - AccessGuard> accessBlocks() const; + SharedAccessGuard> accessBlockedUserIds() const; + SharedAccessGuard> accessBlocks() const; void loadEmotes(); void loadUserstateEmotes(QStringList emoteSetKeys); - AccessGuard accessEmotes() const; + SharedAccessGuard accessEmotes() const; // Automod actions void autoModAllow(const QString msgID); diff --git a/src/providers/twitch/TwitchChannel.cpp b/src/providers/twitch/TwitchChannel.cpp index bd4be9a5e..4bb312adb 100644 --- a/src/providers/twitch/TwitchChannel.cpp +++ b/src/providers/twitch/TwitchChannel.cpp @@ -454,8 +454,8 @@ void TwitchChannel::setRoomId(const QString &id) } } -AccessGuard TwitchChannel::accessRoomModes() - const +SharedAccessGuard + TwitchChannel::accessRoomModes() const { return this->roomModes_.accessConst(); } @@ -472,7 +472,7 @@ bool TwitchChannel::isLive() const return this->streamStatus_.access()->live; } -AccessGuard +SharedAccessGuard TwitchChannel::accessStreamStatus() const { return this->streamStatus_.accessConst(); diff --git a/src/providers/twitch/TwitchChannel.hpp b/src/providers/twitch/TwitchChannel.hpp index 365a07467..06f26567d 100644 --- a/src/providers/twitch/TwitchChannel.hpp +++ b/src/providers/twitch/TwitchChannel.hpp @@ -83,8 +83,8 @@ public: int chatterCount(); virtual bool isLive() const override; QString roomId() const; - AccessGuard accessRoomModes() const; - AccessGuard accessStreamStatus() const; + SharedAccessGuard accessRoomModes() const; + SharedAccessGuard accessStreamStatus() const; // Emotes const TwitchBadges &globalTwitchBadges() const; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 2fb6ddadf..22580fccf 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -36,6 +36,7 @@ set(chatterino_SOURCES set(test_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/main.cpp + ${CMAKE_CURRENT_LIST_DIR}/src/AccessGuard.cpp ${CMAKE_CURRENT_LIST_DIR}/src/NetworkCommon.cpp ${CMAKE_CURRENT_LIST_DIR}/src/NetworkRequest.cpp ${CMAKE_CURRENT_LIST_DIR}/src/UsernameSet.cpp diff --git a/tests/src/AccessGuard.cpp b/tests/src/AccessGuard.cpp new file mode 100644 index 000000000..a0d1c6d31 --- /dev/null +++ b/tests/src/AccessGuard.cpp @@ -0,0 +1,76 @@ +#include "common/UniqueAccess.hpp" + +#include +#include +#include + +#include +#include +#include + +using namespace chatterino; + +using namespace std::chrono_literals; + +TEST(AccessGuardLocker, NonConcurrentUsage) +{ + std::shared_mutex m; + int e = 0; + + { + AccessGuard guard(e, m); + *guard = 3; + } + EXPECT_EQ(e, 3); + + { + AccessGuard guard(e, m); + *guard = 4; + } + EXPECT_EQ(e, 4); + + { + SharedAccessGuard guard(e, m); + EXPECT_EQ(*guard, 4); + } + EXPECT_EQ(e, 4); +} + +TEST(AccessGuardLocker, ConcurrentUsage) +{ + // This test doesn't actually prove anything on normal use, rather it needs to be run with AddressSanitizer/ThreadSanitizer and not error out for this to give any confidence + std::shared_mutex m; + int e = 0; + + auto startTime = std::chrono::steady_clock::now(); + + auto w = [&e, &m] { + std::mt19937_64 eng{std::random_device{}()}; + std::uniform_int_distribution<> dist{1, 4}; + std::this_thread::sleep_for(std::chrono::milliseconds{dist(eng)}); + if (rand() % 2 == 0) + { + AccessGuard guard(e, m); + std::this_thread::sleep_for(std::chrono::milliseconds{dist(eng)}); + *guard += 1; + } + else + { + SharedAccessGuard guard(e, m); + std::this_thread::sleep_for(std::chrono::milliseconds{dist(eng)}); + int hehe = *guard; + } + }; + + std::vector threads; + + for (int i = 0; i < 500; ++i) + { + threads.emplace_back(w); + } + + for (auto &t : threads) + { + t.join(); + } +}