mutex(fix): prevents a potential race condition where threads could

skip waiting for the manager to give them the mutex

The solution was to remove write access from threads to the
manager_took_mutex variable
This commit is contained in:
PowerUser64 2024-11-19 13:48:43 -08:00
parent 77861709f8
commit d9b9d82c79

View file

@ -1,3 +1,8 @@
/**
* @brief A mutex-like thread syncronization mechanism
* @author Blake North <blake.north@digipen.edu>
*/
#include <cstddef> // size_t
#include <cstring> // memset
#include <iostream> // cout
@ -26,7 +31,7 @@ struct thread_data {
size_t id;
bool *wants_mutex;
const bool *has_mutex;
bool *manager_took_mutex;
const bool *manager_took_mutex;
bool *is_finished;
void *data;
};
@ -50,13 +55,10 @@ void *thread_task_increment(const struct thread_data &thread) {
// mutex
{
// wait for the thread manager to take the mutex if it hasn't yet
while (!*thread.manager_took_mutex)
;
// tell the thread manager we want the mutex
(*thread.wants_mutex) = true;
// block until we have the mutex
while (!*thread.has_mutex)
while (!*thread.has_mutex && *thread.manager_took_mutex)
;
// enter the mutex
@ -68,8 +70,6 @@ void *thread_task_increment(const struct thread_data &thread) {
// tell the thread manager we don't need the mutex
(*thread.wants_mutex) = false;
// wait until the thread manager takes the mutex from us
*thread.manager_took_mutex = false;
// mutex exit
}
loops++;
@ -151,7 +151,9 @@ void do_threading(struct thread_group threads) {
// take the mutex from the thread that has it
threads.has_mutex[tid_has] = false;
threads.manager_took_mutex[tid_has] = true;
// give the mutex to the thread that wants it
threads.manager_took_mutex[tid_wants] = false;
threads.has_mutex[tid_wants] = true;
}
@ -166,11 +168,6 @@ void do_threading(struct thread_group threads) {
}
}
// tell all threads we're done taking the mutex from whoever had it
for (size_t tid = 0; tid < threads.total_threads; ++tid) {
threads.manager_took_mutex[tid] = true;
}
// find how many threads are done with the mutex
finished_threads = 0;
for (size_t tid = 0; tid < threads.total_threads; ++tid) {