thread-pool reading notes

2022/1/13 reading notes

TODO

# FunctionTraits.h

#include <tuple>
#include <type_traits>

/* Primary template (declaration only) */
template <typename T>
struct function_traits;

/* Helper type */
template <typename R, typename... Args>
struct function_traits_helper
{
	static constexpr int param_count = sizeof...(Args);
	using return_type = R;

	template<std::size_t N>
	using args_type = std::tuple_element_t<N, std::tuple<Args...>>;
};

/* Specialization for function types */
template <typename R, typename... Args>
struct function_traits<R(Args...)> : public function_traits_helper<R, Args...> {};

/* Specialization for function references */
template <typename R, typename... Args>
struct function_traits<R(&)(Args...)> : public function_traits_helper<R, Args...> {};

/* Specialization for function pointers */
template <typename R, typename... Args>
struct function_traits<R(*)(Args...)> : public function_traits_helper<R, Args...> {};

/* Specializations for function objects: lambdas, std::function, std::bind results, custom functors, etc. */
template <typename ClassType, typename R, typename... Args>
struct function_traits<R(ClassType::*)(Args...) const> : public function_traits_helper<R, Args...>
{
	using class_type = ClassType;
};

template <typename ClassType, typename R, typename... Args>
struct function_traits<R(ClassType::*)(Args...)> : public function_traits_helper<R, Args...>
{
	using class_type = ClassType;
};

/* For function objects, recurse into their operator() */
template <typename T>
struct function_traits : public function_traits<decltype(&std::remove_reference_t<T>::operator())> {};


I was a bit puzzled here at first about what R(Args...) is, suspecting it was a function pointer. It is actually the specialization for plain function types: it extracts the parameter types of a function declared to take parameters of types Args... and to return a value of type R. The primary template at the top, template <typename T> struct function_traits;, looks like it does nothing, and by itself it really does nothing: it is just the declaration that the specializations hang off, much like the terminating template when unpacking a parameter pack.
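To convince myself, here is a small self-written check (not from the repo) of what the traits extract; the function add and the lambda are just made-up examples:

```cpp
#include <string>
#include <type_traits>
#include "FunctionTraits.h"

int add(int a, double b) { return a + static_cast<int>(b); }

int main()
{
	/* decltype(add) is the function type int(int, double), matched by the R(Args...) specialization */
	using traits = function_traits<decltype(add)>;
	static_assert(traits::param_count == 2, "two parameters");
	static_assert(std::is_same<traits::return_type, int>::value, "returns int");
	static_assert(std::is_same<traits::args_type<1>, double>::value, "second parameter is double");

	/* a closure type goes through the operator() fallback at the bottom of the header */
	auto lambda = [](const std::string &s) { return s.size(); };
	using ltraits = function_traits<decltype(lambda)>;
	static_assert(ltraits::param_count == 1, "one parameter");
	return 0;
}
```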

This manually drives the compiler to unpack the parameter pack at compile time, together with the flexible C++ variadic-parameter style. One question remained:

  1. What does decltype(&std::remove_reference_t<T>::operator()) do?

Update: &std::remove_reference_t<T>::operator() strips any reference from T and then takes the address of its operator(); decltype yields the resulting pointer-to-member-function type, which the R(ClassType::*)(Args...) and R(ClassType::*)(Args...) const specializations above decompose into the return and parameter types. If you look at std::remove_reference - cppreference.com you will find that remove_reference_t itself has no operator(); the operator() belongs to T, that is, to the callable object that was passed in.
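A tiny sketch of my own (not from the repo) that makes the mechanism visible: the expression takes the address of the closure type's operator(), and decltype reports the resulting pointer-to-member-function type, which is exactly what the member-function specializations pick apart:

```cpp
#include <type_traits>

int main()
{
	auto f = [](int x) { return x * 2; };
	/* strip the (possible) reference first, then form a pointer to the closure's operator() */
	using Closure = std::remove_reference_t<decltype(f)>;
	using OpType  = decltype(&Closure::operator());
	/* OpType is int (Closure::*)(int) const, matched by
	   function_traits<R(ClassType::*)(Args...) const> */
	static_assert(std::is_member_function_pointer<OpType>::value, "pointer to member function");
	return 0;
}
```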

Worth noting:

    /* Disable copying and assignment */
    ThreadPool(const ThreadPool&) = delete;
    ThreadPool& operator=(const ThreadPool&) = delete;

# Concurrent.h

#include "ThreadPool.h"
#include "FunctionTraits.h"

namespace Fate {
	
	//---------------------------------- Main concurrency class -----------------------------------
	class Concurrent {
	public:
		// map() on sequences
		template <typename T, typename MapFunctor, template <typename U, typename Alloc = std::allocator<U> > class InputSequence,
			typename R = typename std::decay<typename function_traits<MapFunctor>::return_type>::type>
		static std::future<void> map(InputSequence<T> &sequence, MapFunctor map)
		{
			auto func = [&sequence, map]() {
				std::shared_ptr<ThreadPool> _pool = ThreadPool::instance();

				std::vector<std::future<R> > results;
				for (auto it = sequence.begin(); it != sequence.end(); ++it)
				{
					results.emplace_back(_pool->enqueue(map, *it));
				}
				for (auto&& result : results)
				{
					result.get();
				}
				return;
			};
			
			return ThreadPool::run(func);
		}

		// mapped() for sequences
		template <typename T, typename MapFunctor, template <typename U, typename Alloc = std::allocator<U> > class InputSequence,
			typename R = typename std::decay<typename function_traits<MapFunctor>::return_type>::type>
		static std::future<InputSequence<R> > mapped(InputSequence<T> &sequence, MapFunctor map)
		{
			auto func = [&sequence, map]() ->InputSequence<R> {
				std::shared_ptr<ThreadPool> _pool = ThreadPool::instance();

				std::vector<std::future<R> > results;
				for (auto it = sequence.begin(); it != sequence.end(); ++it)
				{
					results.emplace_back(_pool->enqueue(map, *it));
				}
				InputSequence<R> ret;
				for (auto&& result : results)
				{
					ret.push_back(result.get());
				}
				return ret;
			};
			return ThreadPool::run(func);
		}

		// mappedReduced() for sequences.
		template <typename T, typename MapFunctor, typename ReduceFunctor, typename... Args,
			template <typename U, typename Alloc = std::allocator<U> > class InputSequence,
			typename R = typename std::decay<typename function_traits<MapFunctor>::return_type>::type,
			typename Arg0 = typename std::decay<typename function_traits<ReduceFunctor>::template args_type<0>>::type>
		static std::future<Arg0> mappedReduced(InputSequence<T> &sequence, MapFunctor map, ReduceFunctor reduce, Args&&... args)
		{
			Arg0 initialValue(std::forward<Args>(args)...);
			auto func = [&sequence, map, reduce, initialValue]() ->Arg0 {
				std::shared_ptr<ThreadPool> _pool = ThreadPool::instance();
			
				std::vector<std::future<R> > results;
				for (auto it = sequence.begin(); it != sequence.end(); ++it)
				{
					results.emplace_back(_pool->enqueue(map, *it));
				}
				Arg0 ret = initialValue;
				for (auto&& result : results)
				{
					reduce(ret, result.get());
				}
				return ret;
			};
			return ThreadPool::run(func);
		}
	};
}


The gist here is to use the template machinery to run the given functor on every element of the sequence, with each element dispatched as a task to the thread pool.
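As a usage sketch of my own (not part of the repo), this is how I understand the helpers are meant to be called, assuming the headers above build:

```cpp
#include <vector>
#include <iostream>
#include "Concurrent.h"

int main()
{
	std::vector<int> data{1, 2, 3, 4, 5};

	/* mapped(): run the functor on every element via the pool and collect the results */
	auto squares = Fate::Concurrent::mapped(data, [](int v) { return v * v; });

	/* mappedReduced(): map on the pool, then fold the results with the reducer;
	   the trailing 0 constructs the initial accumulator value */
	auto sum = Fate::Concurrent::mappedReduced(
		data,
		[](int v) { return v * v; },
		[](int &acc, int v) { acc += v; },
		0);

	std::cout << squares.get().back() << "\n"; /* expect 25 */
	std::cout << sum.get() << "\n";            /* expect 55 */
	return 0;
}
```

Note that data has to stay alive until the futures are consumed, since the lambdas inside map()/mapped()/mappedReduced() capture the sequence by reference.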

# ThreadPool.cpp / ThreadPool.h

#include "FunctionTraits.h"
#include <vector>
#include <deque>
#include <memory>
#include <thread>
#include <mutex>
#include <atomic>
#include <condition_variable>
#include <future>
#include <functional>

namespace Fate {

	/* Spin lock */
	/*class SpinLock {
	public:
		inline void lock() {
			while (m_flag.test_and_set()) {
				std::this_thread::yield();
			}
		}
		inline void unlock() { m_flag.clear(); }

	private:
		std::atomic_flag m_flag = ATOMIC_FLAG_INIT;
	};*/

	/* Thread pool */
	class ThreadPool {
	public:
		/* Destructor */
		~ThreadPool();

		/* Pool role: master or slave (standby) */
		enum class ThreadPoolType {
			MASTER = 0,
			SLAVE
		};

		/* Thread states */
		enum class ThreadState {
			BLOCKING = 0,
			READY,
			RUNNING
		};

	private:
		/* When threads <= 0, the thread count defaults to CPU cores + 1 */
		ThreadPool(int threads = 0, ThreadPoolType type = ThreadPoolType::MASTER);
		
		/* Disable copying and assignment */
		ThreadPool(const ThreadPool&) = delete;
		ThreadPool& operator=(const ThreadPool&) = delete;

		/* Create a new thread pool */
		static std::shared_ptr<ThreadPool> create(int threads = 0, ThreadPoolType type = ThreadPoolType::MASTER);

	private:
		/* Number of standby (slave) threads */
		static int s_slaveThreadNum;

		/* Synchronization variables for thread states */
		static int s_states;
		static std::mutex s_stateMutex;
		static std::condition_variable s_stateCondition;

		/* Synchronization variables for the task queue */
		static std::deque<std::function<void()> > s_tasks;
		static std::mutex s_taskMutex;
		static std::condition_variable s_taskCondition;

		/* Forced-stop flag */
		static bool s_stop;

	public:
		/* Shared (singleton) thread pool */
		static std::shared_ptr<ThreadPool> instance(int threads = 0);

		/* Whether the current thread belongs to the pool */
		static bool isWorkerThread();

		/* Get the current thread's state */
		static ThreadPool::ThreadState getCurrentThreadState();

		/* Set the current thread's state */
		static void setCurrentThreadState(ThreadPool::ThreadState state);

		/* Enqueue a task */
		template<class F, class... Args, typename R = typename std::decay<typename function_traits<F>::return_type>::type>
		std::future<R> enqueue(F&& f, Args&&... args);

		/* Run a task on a given pool */
		template<class F, class... Args, typename R = typename std::decay<typename function_traits<F>::return_type>::type>
		static std::future<R> runWithPool(std::shared_ptr<ThreadPool> pool, F&& f, Args&&... args);

		template<class F, class... Args, typename R = typename std::decay<typename function_traits<F>::return_type>::type>
		static std::future<R> run(F&& f, Args&&... args);

	private:
		/* Worker threads (the consumer side) */
		std::vector<std::thread> m_workers;
		/* Master/slave flag */
		ThreadPoolType m_type;
		/* Stop flag */
		bool m_stop;
		/* Standby (slave) thread pool */
		std::shared_ptr<ThreadPool> m_slaveThreadPool;
	};

	/* Enqueue a task */
	template<class F, class... Args, typename R>
	std::future<R> ThreadPool::enqueue(F&& f, Args&&... args)
	{
		auto task = std::make_shared<std::packaged_task<R()> >(
			std::bind(std::forward<F>(f), std::ref(std::forward<Args>(args))...)
			);

		std::future<R> ret = task->get_future();

		/* If the task is submitted from a pool thread, mark that thread BLOCKING and wake one standby thread to keep processing */
		std::shared_ptr<ThreadPool> needDeletePool;
		if (ThreadPool::isWorkerThread() && (ThreadPool::getCurrentThreadState() != ThreadState::BLOCKING))
		{
			ThreadPool::setCurrentThreadState(ThreadState::BLOCKING);
			std::unique_lock<std::mutex> lock(s_stateMutex);
			/* If standby threads are running out, create a new standby pool */
			if (s_slaveThreadNum <= 0)
			{
				if (m_slaveThreadPool)
				{
					/* Releasing the old pool is itself queued as a task in the pool; it is a BLOCKING task */
					needDeletePool = m_slaveThreadPool;
					++s_states;
					--s_slaveThreadNum;
					s_stateCondition.notify_one();
				}
				m_slaveThreadPool = ThreadPool::create(m_workers.size(), ThreadPoolType::SLAVE);
			}
			/* After the current thread turns BLOCKING, notify a standby thread to change state */
			++s_states;
			--s_slaveThreadNum;
			s_stateCondition.notify_one();
		}
		{
			/* Push the task into the task queue */
			std::unique_lock<std::mutex> lock(s_taskMutex);
			if (ThreadPool::isWorkerThread())
				s_tasks.emplace_front([task] { (*task)(); });
			else
				s_tasks.emplace_back([task] { (*task)(); });
			s_taskCondition.notify_one();
			/* Check whether the old pool needs to be released */
			if (needDeletePool)
			{
				auto deleteTask = [needDeletePool] {
					ThreadPool::setCurrentThreadState(ThreadState::BLOCKING);
					((std::shared_ptr<ThreadPool>)needDeletePool).reset();
				};
				s_tasks.emplace_back(deleteTask);
				s_taskCondition.notify_one();
			}
		}

		return ret;
	}

	/* Run a task on a given pool */
	template<class F, class... Args, typename R>
	static std::future<R> ThreadPool::runWithPool(std::shared_ptr<ThreadPool> pool, F&& f, Args&&... args)
	{
		std::shared_ptr<ThreadPool> ins = pool;
		if (!ins)
			ins = ThreadPool::instance();

		return ins->enqueue(std::forward<F>(f), std::forward<Args>(args)...);
	}

	template<class F, class... Args, typename R>
	static std::future<R> ThreadPool::run(F&& f, Args&&... args)
	{
		std::shared_ptr<ThreadPool> ins = ThreadPool::instance();
		return ThreadPool::runWithPool(ins, std::forward<F>(f), std::forward<Args>(args)...);
	}

	/* TLS data for threads inside the pool */
	class ThreadPoolTLS {
	public:
		/* Flag: the current thread belongs to a pool */
		thread_local static bool t_threadPoolFlag;
		/* Current thread state */
		thread_local static ThreadPool::ThreadState t_threadState;
	};
}


#include "ThreadPool.h"

namespace Fate {

	thread_local bool ThreadPoolTLS::t_threadPoolFlag = false;
	thread_local ThreadPool::ThreadState ThreadPoolTLS::t_threadState = ThreadPool::ThreadState::READY;
	int ThreadPool::s_slaveThreadNum = 0;
	int ThreadPool::s_states = 0;
	std::mutex ThreadPool::s_stateMutex;
	std::condition_variable ThreadPool::s_stateCondition;
	std::deque<std::function<void()> > ThreadPool::s_tasks;
	std::mutex ThreadPool::s_taskMutex;
	std::condition_variable ThreadPool::s_taskCondition;
	bool ThreadPool::s_stop = false;

	ThreadPool::ThreadPool(int threads, ThreadPoolType type):
		m_stop(false)
	{
		/* Default thread count: CPU cores + 1 */
		if (threads <= 0)
			threads = std::thread::hardware_concurrency() + 1;
		/* If this is a standby pool, bump the standby thread count */
		if (type == ThreadPoolType::SLAVE)
			s_slaveThreadNum += threads;
		m_type = type;

		/* Create the worker threads in a loop */
		for (int i = 0; i < threads; ++i)
		{
			auto worker = [this, type] {
				ThreadPoolTLS::t_threadPoolFlag = true;
				if (type == ThreadPoolType::MASTER)
					ThreadPoolTLS::t_threadState = ThreadState::RUNNING;

				while (true)
				{
					std::function<void()> task;
					/* If not a RUNNING thread, it is a standby thread */
					if (ThreadPoolTLS::t_threadState != ThreadState::RUNNING)
					{
						std::unique_lock<std::mutex> lock(s_stateMutex);
						/* If the current thread is BLOCKING, switch it to READY */
						if (ThreadPoolTLS::t_threadState == ThreadState::BLOCKING)
						{
							ThreadPoolTLS::t_threadState = ThreadState::READY;
							++s_slaveThreadNum;
						}
						s_stateCondition.wait(lock, [this] {
							return m_stop || s_states > 0;
						});
						if (m_stop && s_states <= 0)
						{
							--s_slaveThreadNum;
							break;
						}
						--s_states;
						ThreadPoolTLS::t_threadState = ThreadState::RUNNING;
					}
					/* RUNNING threads wait on the task queue */
					if (ThreadPoolTLS::t_threadState == ThreadState::RUNNING)
					{
						std::unique_lock<std::mutex> lock(s_taskMutex);
						s_taskCondition.wait(lock, [this] {
							return s_stop || !(s_tasks.empty());
						});
						if (s_stop && s_tasks.empty())
							break;				
						task = std::move(s_tasks.front());
						s_tasks.pop_front();
					}

					/* Execute the task */
					task();
				}
			};

			m_workers.emplace_back(worker);
		}
	}

	ThreadPool::~ThreadPool()
	{
		m_stop = true;
		s_stateCondition.notify_all();
		if (!ThreadPoolTLS::t_threadPoolFlag)
		{
			s_stop = true;
			s_taskCondition.notify_all();
		}
		for (std::thread& worker : m_workers)
			worker.join();
	}

	std::shared_ptr<ThreadPool> ThreadPool::create(int threads, ThreadPoolType type)
	{
		return std::shared_ptr<ThreadPool>(new ThreadPool(threads, type));
	}

	std::shared_ptr<ThreadPool> ThreadPool::instance(int threads)
	{
		static std::shared_ptr<ThreadPool> s_singleton;
		static std::mutex s_singletonMutex;

		if (!s_singleton)
		{
			std::unique_lock<std::mutex> lock(s_singletonMutex);
			if (!s_singleton)
			{
				s_singleton = std::shared_ptr<ThreadPool>(new ThreadPool(threads));
			}
		}
		return s_singleton;
	}

	bool ThreadPool::isWorkerThread()
	{
		return ThreadPoolTLS::t_threadPoolFlag;
	}

	ThreadPool::ThreadState ThreadPool::getCurrentThreadState()
	{
		return ThreadPoolTLS::t_threadState;
	}

	void ThreadPool::setCurrentThreadState(ThreadPool::ThreadState state)
	{
		ThreadPoolTLS::t_threadState = state;
	}
}

Only variables declared with the thread_local keyword have thread storage duration: they are created when the thread starts and destroyed when the thread ends, and every thread has its own independent instance of the variable. thread_local can be combined with the static or extern keyword, which affects the variable's linkage.
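A minimal standalone sketch (mine, unrelated to the pool code) showing that each thread really gets its own instance:

```cpp
#include <iostream>
#include <thread>

thread_local int t_counter = 0; /* one independent instance per thread */

void work(const char *name)
{
	for (int i = 0; i < 3; ++i)
		++t_counter;
	std::cout << name << " counter = " << t_counter << "\n"; /* prints 3 for every thread */
}

int main()
{
	std::thread a(work, "thread-a");
	a.join();
	std::thread b(work, "thread-b");
	b.join();
	work("main"); /* the main thread's copy also starts from 0 */
	return 0;
}
```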

For storage duration details, see Storage class specifiers - cppreference.
