mutex-cpp/main.cpp
PowerUser64 e9b4f8e389 mutex(fix): prevent race condition (part 2)
In addition to the previous commit, this commit makes each thread wait for
the thread manager to confirm that it has taken the mutex back before the
thread can ask for it again.
2024-11-19 15:12:42 -08:00

203 lines
5.9 KiB
C++

/**
* @brief A mutex-like thread synchronization mechanism
* @author Blake North <blake.north@digipen.edu>
*/
#include <atomic>   // atomic - flags shared between threads
#include <cstddef>  // size_t
#include <cstring>  // memset
#include <iostream> // cout
#include <thread>   // this_thread::yield
#include <vector>   // vector
// used for threads
#include <pthread.h> // pthread
#include <unistd.h>  // sleep() - testing
/// Maximum number of worker threads a thread_group can track.
constexpr size_t NUM_THREADS = 8;

struct thread_data;
/// A worker body: receives its per-thread view of the group and returns the
/// value that pthread_join() will report.
using ThreadTask = void *(*)(const thread_data &);
/// The start-routine signature pthread_create() expects.
using PthreadFun = void *(*)(void *arg);

/**
 * @brief Shared state for a group of threads coordinated by do_threading().
 *
 * Every flag array is indexed by thread id. Each flag is written by one side
 * (worker or manager) and polled by the other from a different thread, so the
 * flags are std::atomic: with plain bools those accesses are data races, and
 * an optimizing compiler may hoist a load out of a spin loop and never observe
 * the update.
 */
struct thread_group {
  /// set by a worker while it is waiting for, or using, the mutex
  std::atomic<bool> wants_mutex[NUM_THREADS]{};
  /// set by the manager on the one thread currently granted the mutex
  std::atomic<bool> has_mutex[NUM_THREADS]{};
  /// set by the manager when the thread does not hold the mutex and may ask
  /// for it; cleared when the mutex is granted to that thread
  std::atomic<bool> manager_took_mutex[NUM_THREADS]{};
  /// set by a worker once it will not ask for the mutex again
  std::atomic<bool> threads_finished[NUM_THREADS]{};
  /// body each thread runs; nullptr selects thread_task_increment
  ThreadTask task = nullptr;
  /// number of threads to spawn; must not exceed NUM_THREADS
  size_t total_threads = 0;
  /// opaque pointer forwarded to every thread as thread_data::data
  void *data = nullptr;
};

/**
 * @brief One thread's view of its own slots in a thread_group.
 *
 * The pointers refer into the owning thread_group, which must outlive the
 * thread.
 */
struct thread_data {
  size_t id;                                   ///< index of this thread
  std::atomic<bool> *wants_mutex;              ///< written by this thread
  const std::atomic<bool> *has_mutex;          ///< written by the manager
  const std::atomic<bool> *manager_took_mutex; ///< written by the manager
  std::atomic<bool> *is_finished;              ///< written by this thread
  void *data;                                  ///< thread_group::data
};

/**
 * @brief Example worker: enters the critical section max_loops times and
 *        increments the shared counter once per entry.
 *
 * Worker side of the handshake, once per iteration:
 *   1. wait until manager_took_mutex is set, i.e. the manager no longer
 *      considers this thread the holder, so a request is never made while
 *      this thread could still be recorded as holding the mutex;
 *   2. raise wants_mutex and wait for has_mutex;
 *   3. do the protected work, then lower wants_mutex so the manager reclaims
 *      the mutex.
 *
 * @param thread This thread's slots; thread.data must point to an int.
 * @return Always nullptr.
 */
void *thread_task_increment(const struct thread_data &thread) {
  // number of times this thread enters the critical section
  const size_t max_loops = 10;
  size_t loops = 0;
  std::cout << "Hello from thread " << thread.id << "! (before run loop)"
            << std::endl;
  // thread loop, with an exit condition
  while (!thread.is_finished->load()) {
    // non-mutex operations
    std::cout << "Hello from thread " << thread.id
              << "! (inside run loop, before mutex)" << std::endl;
    // mutex
    {
      // Wait until the manager has taken the mutex back from our previous
      // turn, then ask for it once. (Making the request the loop *body*
      // meant it was never sent, because manager_took_mutex starts out true.)
      while (!thread.manager_took_mutex->load())
        std::this_thread::yield();
      thread.wants_mutex->store(true);
      // block until we have the mutex
      while (!thread.has_mutex->load())
        std::this_thread::yield();
      // enter the mutex
      std::cout << "Hello from thread " << thread.id
                << "! (inside run loop, inside mutex)" << std::endl;
      *static_cast<int *>(thread.data) += 1;
      // tell the thread manager we don't need the mutex
      thread.wants_mutex->store(false);
      // mutex exit
    }
    ++loops;
    // stop after exactly max_loops entries (previously `max_loops < loops`
    // allowed one extra)
    if (loops >= max_loops)
      thread.is_finished->store(true);
  }
  std::cout << "Hello from thread " << thread.id << "! (done)" << std::endl;
  return nullptr;
}

/// Bundles a task with the per-thread data it runs on, so both can travel
/// through pthread_create()'s single void * argument.
struct thread_launch {
  ThreadTask task;
  const thread_data *data;
};

/// pthread start routine: unpacks a thread_launch and runs its task. Calling
/// the task through a reinterpret_cast'd function pointer instead would be
/// undefined behavior.
static void *run_thread_task(void *arg) {
  const thread_launch *launch = static_cast<const thread_launch *>(arg);
  return launch->task(*launch->data);
}

/**
 * @brief Spawns threads.total_threads workers running threads.task and acts
 *        as the mutex manager until every worker reports it is finished.
 *
 * Manager side of the handshake (at most one holder at any time):
 *   - reclaim: once the holder lowers wants_mutex, clear its has_mutex and
 *     then set its manager_took_mutex so it may ask again;
 *   - grant: while nobody holds the mutex, pick the next thread (round-robin)
 *     whose wants_mutex is set, clear its manager_took_mutex and then set its
 *     has_mutex.
 * The store order in both steps stops a worker from re-requesting the mutex
 * while it is still recorded as the holder.
 *
 * The manager loop only ends when every worker has set its finished flag, so
 * a task must eventually do so.
 *
 * @param threads Group to run. Taken by reference because the workers point
 *                into its flag arrays; it must stay alive until the threads
 *                are joined (which happens before this returns). If
 *                total_threads exceeds NUM_THREADS, an error is printed and
 *                nothing is run. If task is nullptr, thread_task_increment
 *                is used.
 */
void do_threading(struct thread_group &threads) {
  const size_t n = threads.total_threads;
  if (n > NUM_THREADS) {
    std::cerr << "do_threading: total_threads (" << n
              << ") exceeds NUM_THREADS (" << NUM_THREADS << ")" << std::endl;
    return;
  }
  const ThreadTask task =
      threads.task != nullptr ? threads.task : thread_task_increment;

  // reset the handshake: nobody wants or holds the mutex, nobody is finished,
  // and every thread is free to ask for it
  for (size_t tid = 0; tid < n; ++tid) {
    threads.wants_mutex[tid].store(false);
    threads.has_mutex[tid].store(false);
    threads.threads_finished[tid].store(false);
    threads.manager_took_mutex[tid].store(true);
  }

  // per-thread views and launch records; these vectors are never resized
  // after this point, so the pointers handed to the threads stay valid
  std::vector<thread_data> per_thread(n);
  std::vector<thread_launch> launches(n);
  std::vector<pthread_t> my_pthreads(n);
  std::vector<char> spawned(n, 0);
  for (size_t tid = 0; tid < n; ++tid) {
    thread_data &view = per_thread[tid];
    view.id = tid;
    view.wants_mutex = &threads.wants_mutex[tid];
    view.has_mutex = &threads.has_mutex[tid];
    view.manager_took_mutex = &threads.manager_took_mutex[tid];
    view.is_finished = &threads.threads_finished[tid];
    view.data = threads.data;
    launches[tid] = thread_launch{task, &view};
  }

  // spawn threads (none can enter the mutex until the manager grants it)
  const PthreadFun entry = run_thread_task;
  for (size_t tid = 0; tid < n; ++tid) {
    const int err =
        pthread_create(&my_pthreads[tid], nullptr, entry, &launches[tid]);
    if (err != 0) {
      std::cerr << "do_threading: failed to spawn thread " << tid
                << " (error " << err << ")" << std::endl;
      // count it as finished so the manager loop can still terminate
      threads.threads_finished[tid].store(true);
    } else {
      spawned[tid] = 1;
    }
  }
  std::cout << "Threads have been spawned." << std::endl;

  const size_t nobody = n;   // sentinel: no thread holds the mutex
  size_t holder = nobody;    // only the manager writes has_mutex
  size_t next_candidate = 0; // where the next round-robin search starts
  for (;;) {
    // reclaim the mutex as soon as its holder is done with it
    if (holder != nobody && !threads.wants_mutex[holder].load()) {
      threads.has_mutex[holder].store(false);
      threads.manager_took_mutex[holder].store(true);
      holder = nobody;
    }
    // hand a free mutex to the next waiting thread, round-robin
    if (holder == nobody) {
      for (size_t offset = 0; offset < n; ++offset) {
        const size_t tid = (next_candidate + offset) % n;
        if (threads.wants_mutex[tid].load()) {
          threads.manager_took_mutex[tid].store(false);
          threads.has_mutex[tid].store(true);
          holder = tid;
          next_candidate = (tid + 1) % n;
          break;
        }
      }
    }
    // stop once every thread has finished
    size_t finished_threads = 0;
    for (size_t tid = 0; tid < n; ++tid) {
      if (threads.threads_finished[tid].load())
        ++finished_threads;
    }
    if (finished_threads == n)
      break;
    std::this_thread::yield();
  }

  // join the threads that were actually started
  for (size_t tid = 0; tid < n; ++tid) {
    if (spawned[tid])
      pthread_join(my_pthreads[tid], nullptr);
  }
}
/// Streams `expr: value` for a debug print. The value is parenthesized so
/// arguments containing operators that bind looser than << (e.g. `a & b`,
/// `c ? x : y`) are printed as a whole instead of being split by <<.
#define DBG_PRINT(v) #v << ": " << (v)
/**
 * @brief Demo: runs NUM_THREADS copies of thread_task_increment through
 *        do_threading() on a shared counter and prints the counter before
 *        and after.
 */
int main(void) {
  struct thread_group mythreads;
  // the count
  int count = 0;
  std::cout << "Pre: " << DBG_PRINT(count) << std::endl;
  // every thread receives &count as its data pointer
  mythreads.data = &count;
  mythreads.total_threads = NUM_THREADS;
  mythreads.task = thread_task_increment;
  do_threading(mythreads);
  std::cout << "Post: " << DBG_PRINT(count) << std::endl;
  return 0;
}