-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_pickle_parser.cpp
More file actions
127 lines (102 loc) · 3.18 KB
/
test_pickle_parser.cpp
File metadata and controls
127 lines (102 loc) · 3.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#include <cassert>
#include <cstdint>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>
#include "model_loader.h"
#include "pickle_parser.h"
namespace {
// Appends `value` to `out` as eight bytes, least-significant byte first.
void append_u64(std::vector<uint8_t> & out, uint64_t value) {
    for (unsigned shift = 0; shift < 64; shift += 8) {
        const uint64_t low_byte = (value >> shift) & 0xFFU;
        out.push_back(static_cast<uint8_t>(low_byte));
    }
}
// Appends the byte 'c' (pickle's GLOBAL opcode) followed by `module` and
// `name`, each terminated by a single '\n', to `out`.
void append_global(std::vector<uint8_t> & out, const std::string & module, const std::string & name) {
    // Writes `text` followed by a newline terminator.
    const auto append_line = [&out](const std::string & text) {
        out.insert(out.end(), text.begin(), text.end());
        out.push_back('\n');
    };

    out.push_back('c');
    append_line(module);
    append_line(name);
}
// Appends `value` to `out` as: the byte 0x8c (pickle's SHORT_BINUNICODE
// opcode), a one-byte length, then the raw bytes of `value`.
//
// Throws std::runtime_error when `value` is longer than 255 bytes, because the
// length has to fit in the single length byte. Nothing is written in that case.
void append_short_unicode(std::vector<uint8_t> & out, const std::string & value) {
    const std::size_t length = value.size();
    if (length >= 256) {
        throw std::runtime_error("value too large for SHORT_BINUNICODE");
    }

    out.push_back(static_cast<uint8_t>(0x8c));
    out.push_back(static_cast<uint8_t>(length));
    for (const char ch : value) {
        out.push_back(static_cast<uint8_t>(ch));
    }
}
// Builds a hand-assembled protocol-4 pickle byte stream describing a dict with
// a single entry "w" whose value is a `torch._utils._rebuild_tensor_v2(...)`
// call. Opcode names in the comments follow Python's pickletools.
std::vector<uint8_t> make_tensor_pickle() {
    std::vector<uint8_t> bytes;

    // Emits a single raw byte.
    const auto op = [&bytes](uint8_t byte) { bytes.push_back(byte); };
    // Emits BININT1 ('K') followed by its one-byte operand.
    const auto small_int = [&op](uint8_t value) {
        op('K');
        op(value);
    };
    // Emits MARK, two BININT1 values, TUPLE.
    const auto int_pair = [&op, &small_int](uint8_t first, uint8_t second) {
        op('(');
        small_int(first);
        small_int(second);
        op('t');
    };

    // PROTO 4.
    op(0x80);
    op(0x04);
    // FRAME with an 8-byte length field; the length is written as 0.
    op(0x95);
    append_u64(bytes, 0);

    // EMPTY_DICT, MEMOIZE, then MARK to open the key/value run for SETITEMS.
    op('}');
    op(0x94);
    op('(');

    // Key "w", then the callable that REDUCE will be applied to.
    append_short_unicode(bytes, "w");
    append_global(bytes, "torch._utils", "_rebuild_tensor_v2");

    // MARK for the argument tuple, MARK for the persistent-id tuple.
    op('(');
    op('(');
    append_short_unicode(bytes, "storage");
    append_global(bytes, "torch", "FloatStorage");
    append_short_unicode(bytes, "0");   // storage key checked by the test
    append_short_unicode(bytes, "cpu");
    small_int(4);                        // presumably the element count — TODO confirm
    op('t');                             // TUPLE
    op('Q');                             // BINPERSID

    small_int(0);                        // presumably the storage offset — TODO confirm
    int_pair(2, 2);                      // shape checked by the test: 2 x 2
    int_pair(2, 1);                      // presumably the strides — TODO confirm
    op(0x88);                            // NEWTRUE
    op('}');                             // EMPTY_DICT
    op('t');                             // TUPLE: closes the argument tuple
    op('R');                             // REDUCE
    op('u');                             // SETITEMS
    op('.');                             // STOP
    return bytes;
}
void test_parse_tensor_with_proto4_ops() {
const std::vector<uint8_t> payload = make_tensor_pickle();
const std::vector<gd::TensorDescriptor> tensors = gd::parse_tensor_descriptors(payload);
assert(tensors.size() == 1);
assert(tensors[0].name == "w");
assert(tensors[0].storage_key == "0");
assert(tensors[0].shape.size() == 2);
assert(tensors[0].shape[0] == 2);
assert(tensors[0].shape[1] == 2);
}
// Creates a scratch directory holding three model-like files (.pt, .ckpt,
// .pth) and one unrelated .txt file, then checks that
// gd::find_model_candidates reports exactly three candidates for it.
void test_find_model_candidates_extension_filter() {
    const std::filesystem::path tmp = std::filesystem::temp_directory_path() / "gd_candidates_test";

    // Start from a clean slate: a directory left behind by an earlier run
    // could contain extra files and skew the candidate count.
    std::filesystem::remove_all(tmp);
    std::filesystem::create_directories(tmp);

    // Each file receives a single newline so that it exists on disk.
    for (const char * file_name : {"a.pt", "b.ckpt", "c.pth", "ignore.txt"}) {
        std::ofstream(tmp / file_name).put('\n');
    }

    const std::vector<gd::fs::path> candidates = gd::find_model_candidates(tmp);

    // Clean up before asserting: a failing assert aborts the process, which
    // would otherwise leave the scratch directory behind.
    std::filesystem::remove_all(tmp);

    assert(candidates.size() == 3);
}
} // namespace
// Test entry point: runs every test case in order. The cases report failure
// via assert, so reaching the final message means they all passed.
int main() {
    using TestCase = void (*)();
    const TestCase test_cases[] = {
        test_parse_tensor_with_proto4_ops,
        test_find_model_candidates_extension_filter,
    };

    for (const TestCase run_case : test_cases) {
        run_case();
    }

    std::cout << "All tests passed\n";
    return 0;
}