多维比较算法

先说结论

这是一次失败的设计:性能不升反降(实测插入耗时约为普通比较的 120%)。

介绍

使用 set/map 时,需要为储存的元素提供一个比较大小的算法,而很多时候储存的元素是向量、矩阵。一般来说会这样写:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
/// Four-component vector of doubles with a lexicographic "less than", so it
/// can be used as a key in ordered containers such as std::set / std::map.
struct vec4 {
    double x, y, z, w;

    /// Lexicographic comparison: x is the most significant component,
    /// followed by y, z and finally w. Returns true only when *this orders
    /// strictly before r.
    bool operator<(const vec4& r) const {
        if (x < r.x) return true;
        if (x > r.x) return false;

        if (y < r.y) return true;
        if (y > r.y) return false;

        if (z < r.z) return true;
        if (z > r.z) return false;

        // All earlier components are equal, so w alone decides. It must be
        // compared against r.w (not r.x), otherwise the ordering is broken.
        return w < r.w;
    }
};

这种写法需要多次比较和分支,对 CPU 来说不够友好,特别是通常来说一次查找需要调用多次比较函数。
Boost 里面给出的方法是对每个元素依次做 hash 合并(hash_combine),而且不保证严格正确:

1
2
3
4
5
6
7
8
// Fragment of a range-hashing routine: the enclosing function signature and
// the declarations of `first` / `last` are not part of this excerpt.
// Start from a zero seed and pass it, together with each element of
// [first, last), to hash_combine.
size_t seed = 0;

for(; first != last; ++first)
{
// NOTE(review): hash_combine is defined elsewhere; presumably it updates
// `seed` in place (taken by reference) -- confirm against its declaration.
hash_combine(seed, *first);
}

// The seed after the last element is the result for the whole range.
return seed;

自己想了一个算法。大概就是把每一个维度抽象地视为一位数:vec4 就是一个 4 位数,vec5 就是一个 5 位数。比较大小时,把两个对象抽象成两个数字,逐位比较大小;当"小于"的情况比"大于"更早出现时,就是 A 小于 B。其实和最上面的方法很像,但是不立刻返回,而是储存所有位的比较结果,最后一次性比较并返回。主要优化的地方在于减少了分支(未测试)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
/// Lexicographic "less than" for D-dimensional vectors, computed without an
/// early return.
///
/// Every dimension contributes one bit to two masks: `less_mask` gets the bit
/// when A[i] < B[i], `greater_mask` gets it when A[i] > B[i]. Dimension 0 owns
/// the most significant bit, so the first differing dimension dominates the
/// final integer comparison of the two masks.
///
/// Assumes vec<D> exposes operator[](size_t) returning a comparable element --
/// TODO confirm against the definition of vec.
template<size_t D>
bool operator<(const vec<D>& A, const vec<D>& B)
{
    // One bit per dimension has to fit into a size_t.
    static_assert(D <= sizeof(size_t) * 8, "too many dimensions for the bit-mask comparison");

    size_t less_mask = 0, greater_mask = 0;
    for(size_t i = 0; i != D; ++i) // expected to be unrolled by the compiler
    {
        // Shift a size_t (not an int) and derive the position from D so the
        // code is valid for any dimension count, not only for 9 elements.
        const size_t bit = static_cast<size_t>(1) << (D - 1 - i);

        // Conditional expressions instead of if/return, intended to take
        // branch prediction out of the picture.
        less_mask = (A[i] < B[i]) ? (less_mask | bit) : less_mask;
        greater_mask = (A[i] > B[i]) ? (greater_mask | bit) : greater_mask;
    }

    return less_mask > greater_mask;
}

代码测试

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#include <iostream>
#include <random>
#include <set>
#include <vector>
#include <chrono>

// Number of random matrices generated and inserted by the benchmark.
constexpr size_t test_count = 1'000'000;

/// Matrix value type holding nine doubles in a flat array.
struct Mat3
{
    double v[9];

    /// Unchecked mutable access to element i.
    double& operator[](size_t i) { return v[i]; }

    /// Unchecked read-only access to element i.
    const double& operator[](size_t i) const { return v[i]; }

    /// True when at least one pair of corresponding elements does not compare
    /// equal.
    bool operator!=(const Mat3& r) const {
        for (size_t i = 0; i < 9; ++i) {
            if (!(v[i] == r.v[i])) {
                return true;
            }
        }
        return false;
    }
};

struct normal_compare
{
bool operator()(const Mat3& l, const Mat3& r) const{
#define normal_compare_CompareValue(l,r) {if((l) < (r)) return true; if((l) > (r)) return false;}
normal_compare_CompareValue(l[0], r[0]);
normal_compare_CompareValue(l[1], r[1]);
normal_compare_CompareValue(l[2], r[2]);
normal_compare_CompareValue(l[3], r[3]);
normal_compare_CompareValue(l[4], r[4]);
normal_compare_CompareValue(l[5], r[5]);
normal_compare_CompareValue(l[6], r[6]);
normal_compare_CompareValue(l[7], r[7]);
normal_compare_CompareValue(l[8], r[8]);

#undef normal_compare_CompareValue
return false;
}
};

struct bit_compare
{
bool operator()(const Mat3& lv, const Mat3& rv) const{
size_t min = 0, max = 0;

#define bit_compare_CombineValue(l, r, i) { \
min = (l < r) ? min | (1<<(8-i)) : min; \
max = (l > r) ? max | (1<<(8-i)) : max;\
}

bit_compare_CombineValue(lv[0], rv[0], 0);
bit_compare_CombineValue(lv[1], rv[1], 1);
bit_compare_CombineValue(lv[2], rv[2], 2);
bit_compare_CombineValue(lv[3], rv[3], 3);
bit_compare_CombineValue(lv[4], rv[4], 4);
bit_compare_CombineValue(lv[5], rv[5], 5);
bit_compare_CombineValue(lv[6], rv[6], 6);
bit_compare_CombineValue(lv[7], rv[7], 7);
bit_compare_CombineValue(lv[8], rv[8], 8);
#undef bit_compare_CombineValue

return min > max;
}
};

/// Current time point of std::chrono::high_resolution_clock.
auto now() {
    using clock = std::chrono::high_resolution_clock;
    return clock::now();
}

/// Converts a nanosecond duration to a whole number of milliseconds
/// (duration_cast truncates toward zero) and returns the tick count.
auto get_use_time_ms(const std::chrono::nanoseconds& nanoseconds) {
    const auto as_ms = std::chrono::duration_cast<std::chrono::milliseconds>(nanoseconds);
    return as_ms.count();
}

int main()
{
std::cout << "start" << std::endl;
auto start_clock = now();

std::random_device r;
std::default_random_engine e1(r());
std::uniform_real_distribution<double> uniform_dist(-1000, 1000);

std::vector<Mat3> mat_vector(test_count);
for(auto& m : mat_vector){
for(size_t i=0; i<9; ++i){
m[i] = uniform_dist(e1);
}
}
auto init_cost = get_use_time_ms(now() - start_clock);
std::cout << "init " << test_count << " random Mat3*3 use " << init_cost << " ms" << std::endl;

std::set<Mat3, normal_compare> normal_set;
auto start_insert_with_normal = now();
for(auto& e: mat_vector){
normal_set.insert(e);
}
auto end_insert_with_normal = now();
auto normal_cost = get_use_time_ms(end_insert_with_normal - start_insert_with_normal);
std::cout << "insert with normal func use " << normal_cost << " ms" << std::endl;

std::set<Mat3, bit_compare> bitcompare_set;
auto start_insert_with_bitcompare = now();
for(auto& e: mat_vector){
bitcompare_set.insert(e);
}
auto end_insert_with_bitcompare = now();
auto bitcompare_cost = get_use_time_ms(end_insert_with_bitcompare - start_insert_with_bitcompare);
std::cout << "insert with bitcompare func use " << bitcompare_cost << " ms" << std::endl;

std::cout << "bitcompare use " << (double)bitcompare_cost / (double)normal_cost * 100 << "% of normal func" << std::endl;

std::cout << "start validate insert result" << std::endl;

if(normal_set.size() != bitcompare_set.size()){
std::cout << "normal_set size: " << normal_set.size() << std::endl;
std::cout << "bitcompare_set size: " << bitcompare_set.size() << std::endl;
std::cout << "vvvERRORvvv\nset size no equal\n^^^ERROR^^^" << std::endl;
goto END;
}
{
auto it_n = normal_set.begin();
auto it_b = bitcompare_set.begin();
for(;it_n != normal_set.end(); ++it_n,++it_b){
if((*it_n) != (*it_b)){
std::cout << "vvvERRORvvv\ndata no equal\n^^^ERROR^^^" << std::endl;
goto END;
}
}
}
std::cout << "validate insert result successed" << std::endl;
END:
{
std::cout << "total use " << get_use_time_ms(now() - start_clock) << " ms" << std::endl;
std::cout << "end" << std::endl;
}

return 0;
}

简单运行的结果

1
2
3
4
5
6
7
8
9
start
init 1000000 random Mat3*3 use 281 ms
insert with normal func use 780 ms
insert with bitcompare func use 933 ms
bitcompare use 119.615% of normal func
start validate insert result
validate insert result successed
total use 2114 ms
end

分析

在这个最多需要 18 次比较的场景中,普通的比较虽然有 if…return 这样会改变流水线的指令,但是可能因为直接结束函数只是一次跳转指令,不需要重排指令流水线,所以对执行效率的影响很低。另外,测试数据是均匀分布的随机数,绝大多数情况下比较第一个元素就能分出大小并直接返回。而位比较每次都要完整地进行 18 次 double 的比较和或操作,带来的性能消耗反而更大。

todo

学习查看汇编指令, 更加深入的分析