// Source view: gitiles blob e7a3401dc82fde1712ee89f7baac465cc9373173;
// line 1 last changed by Shawn Willden in f452774 (2017-11-09 15:59:39 -0700).
/*
 * Copyright (C) 2014 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "authorization_set.h"

#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <algorithm>
#include <istream>
#include <limits>
#include <new>
#include <ostream>
#include <sstream>
28
29namespace keystore {
30
31inline bool keyParamLess(const KeyParameter& a, const KeyParameter& b) {
32 if (a.tag != b.tag) return a.tag < b.tag;
33 int retval;
34 switch (typeFromTag(a.tag)) {
35 case TagType::INVALID:
36 case TagType::BOOL:
37 return false;
38 case TagType::ENUM:
39 case TagType::ENUM_REP:
40 case TagType::UINT:
41 case TagType::UINT_REP:
42 return a.f.integer < b.f.integer;
43 case TagType::ULONG:
44 case TagType::ULONG_REP:
45 return a.f.longInteger < b.f.longInteger;
46 case TagType::DATE:
47 return a.f.dateTime < b.f.dateTime;
48 case TagType::BIGNUM:
49 case TagType::BYTES:
50 // Handle the empty cases.
51 if (a.blob.size() == 0) return b.blob.size() != 0;
52 if (b.blob.size() == 0) return false;
53
54 retval = memcmp(&a.blob[0], &b.blob[0], std::min(a.blob.size(), b.blob.size()));
55 // if one is the prefix of the other the longer wins
56 if (retval == 0) return a.blob.size() < b.blob.size();
57 // Otherwise a is less if a is less.
58 else
59 return retval < 0;
60 }
61 return false;
62}
63
64inline bool keyParamEqual(const KeyParameter& a, const KeyParameter& b) {
65 if (a.tag != b.tag) return false;
66
67 switch (typeFromTag(a.tag)) {
68 case TagType::INVALID:
69 case TagType::BOOL:
70 return true;
71 case TagType::ENUM:
72 case TagType::ENUM_REP:
73 case TagType::UINT:
74 case TagType::UINT_REP:
75 return a.f.integer == b.f.integer;
76 case TagType::ULONG:
77 case TagType::ULONG_REP:
78 return a.f.longInteger == b.f.longInteger;
79 case TagType::DATE:
80 return a.f.dateTime == b.f.dateTime;
81 case TagType::BIGNUM:
82 case TagType::BYTES:
83 if (a.blob.size() != b.blob.size()) return false;
84 return a.blob.size() == 0 || memcmp(&a.blob[0], &b.blob[0], a.blob.size()) == 0;
85 }
86 return false;
87}
88
89void AuthorizationSet::Sort() {
90 std::sort(data_.begin(), data_.end(), keyParamLess);
91}
92
93void AuthorizationSet::Deduplicate() {
94 if (data_.empty()) return;
95
96 Sort();
97 std::vector<KeyParameter> result;
98
99 auto curr = data_.begin();
100 auto prev = curr++;
101 for (; curr != data_.end(); ++prev, ++curr) {
102 if (prev->tag == Tag::INVALID) continue;
103
104 if (!keyParamEqual(*prev, *curr)) {
105 result.emplace_back(std::move(*prev));
106 }
107 }
108 result.emplace_back(std::move(*prev));
109
110 std::swap(data_, result);
111}
112
113void AuthorizationSet::Union(const AuthorizationSet& other) {
114 data_.insert(data_.end(), other.data_.begin(), other.data_.end());
115 Deduplicate();
116}
117
118void AuthorizationSet::Subtract(const AuthorizationSet& other) {
119 Deduplicate();
120
121 auto i = other.begin();
122 while (i != other.end()) {
123 int pos = -1;
124 do {
125 pos = find(i->tag, pos);
126 if (pos != -1 && keyParamEqual(*i, data_[pos])) {
127 data_.erase(data_.begin() + pos);
128 break;
129 }
130 } while (pos != -1);
131 ++i;
132 }
133}
134
135int AuthorizationSet::find(Tag tag, int begin) const {
136 auto iter = data_.begin() + (1 + begin);
137
138 while (iter != data_.end() && iter->tag != tag) ++iter;
139
140 if (iter != data_.end()) return iter - data_.begin();
141 return -1;
142}
143
144bool AuthorizationSet::erase(int index) {
145 auto pos = data_.begin() + index;
146 if (pos != data_.end()) {
147 data_.erase(pos);
148 return true;
149 }
150 return false;
151}
152
153KeyParameter& AuthorizationSet::operator[](int at) {
154 return data_[at];
155}
156
157const KeyParameter& AuthorizationSet::operator[](int at) const {
158 return data_[at];
159}
160
161void AuthorizationSet::Clear() {
162 data_.clear();
163}
164
165size_t AuthorizationSet::GetTagCount(Tag tag) const {
166 size_t count = 0;
167 for (int pos = -1; (pos = find(tag, pos)) != -1;) ++count;
168 return count;
169}
170
171NullOr<const KeyParameter&> AuthorizationSet::GetEntry(Tag tag) const {
172 int pos = find(tag);
173 if (pos == -1) return {};
174 return data_[pos];
175}
176
/**
 * Persistent format is:
 * | 32 bit indirect_size         |
 * --------------------------------
 * | indirect_size bytes of data  | this is where the blob data is stored
 * --------------------------------
 * | 32 bit element_count         | number of entries
 * | 32 bit elements_size         | total bytes used by entries (entries have variable length)
 * --------------------------------
 * | elements_size bytes of data  | where the elements are stored
 *
 * All 32 bit fields are written as raw bytes, i.e. in host byte order.
 */

/**
 * Persistent format of blobs and bignums:
 * | 32 bit tag             |
 * | 32 bit blob_length     |
 * | 32 bit indirect_offset | offset of the blob bytes within the indirect data
 */

/**
 * The two sinks written during serialization. `indirect` receives blob/bignum payload bytes.
 * `elements` receives the per-parameter records: the tag plus the value, or, for blobs, the
 * tag plus the length and the offset into `indirect`.
 */
struct OutStreams {
    std::ostream& indirect;  // blob payload bytes
    std::ostream& elements;  // per-parameter records
};

201OutStreams& serializeParamValue(OutStreams& out, const hidl_vec<uint8_t>& blob) {
202 uint32_t buffer;
203
204 // write blob_length
205 auto blob_length = blob.size();
206 if (blob_length > std::numeric_limits<uint32_t>::max()) {
207 out.elements.setstate(std::ios_base::badbit);
208 return out;
209 }
210 buffer = blob_length;
211 out.elements.write(reinterpret_cast<const char*>(&buffer), sizeof(uint32_t));
212
213 // write indirect_offset
214 auto offset = out.indirect.tellp();
215 if (offset < 0 || offset > std::numeric_limits<uint32_t>::max() ||
216 uint32_t(offset) + uint32_t(blob_length) < uint32_t(offset)) { // overflow check
217 out.elements.setstate(std::ios_base::badbit);
218 return out;
219 }
220 buffer = offset;
221 out.elements.write(reinterpret_cast<const char*>(&buffer), sizeof(uint32_t));
222
223 // write blob to indirect stream
224 if (blob_length) out.indirect.write(reinterpret_cast<const char*>(&blob[0]), blob_length);
225
226 return out;
227}
228
229template <typename T>
230OutStreams& serializeParamValue(OutStreams& out, const T& value) {
231 out.elements.write(reinterpret_cast<const char*>(&value), sizeof(T));
232 return out;
233}
234
235OutStreams& serialize(TAG_INVALID_t&&, OutStreams& out, const KeyParameter&) {
236 // skip invalid entries.
237 return out;
238}
239template <typename T>
240OutStreams& serialize(T ttag, OutStreams& out, const KeyParameter& param) {
241 out.elements.write(reinterpret_cast<const char*>(&param.tag), sizeof(int32_t));
242 return serializeParamValue(out, accessTagValue(ttag, param));
243}
244
245template <typename... T>
246struct choose_serializer;
247template <typename... Tags>
248struct choose_serializer<MetaList<Tags...>> {
249 static OutStreams& serialize(OutStreams& out, const KeyParameter& param) {
250 return choose_serializer<Tags...>::serialize(out, param);
251 }
252};
253template <>
254struct choose_serializer<> {
255 static OutStreams& serialize(OutStreams& out, const KeyParameter&) { return out; }
256};
257template <TagType tag_type, Tag tag, typename... Tail>
258struct choose_serializer<TypedTag<tag_type, tag>, Tail...> {
259 static OutStreams& serialize(OutStreams& out, const KeyParameter& param) {
260 if (param.tag == tag) {
261 return keystore::serialize(TypedTag<tag_type, tag>(), out, param);
262 } else {
263 return choose_serializer<Tail...>::serialize(out, param);
264 }
265 }
266};
267
268OutStreams& serialize(OutStreams& out, const KeyParameter& param) {
269 return choose_serializer<all_tags_t>::serialize(out, param);
270}
271
272std::ostream& serialize(std::ostream& out, const std::vector<KeyParameter>& params) {
273 std::stringstream indirect;
274 std::stringstream elements;
275 OutStreams streams = {indirect, elements};
276 for (const auto& param : params) {
277 serialize(streams, param);
278 }
279 if (indirect.bad() || elements.bad()) {
280 out.setstate(std::ios_base::badbit);
281 return out;
282 }
283 auto pos = indirect.tellp();
284 if (pos < 0 || pos > std::numeric_limits<uint32_t>::max()) {
285 out.setstate(std::ios_base::badbit);
286 return out;
287 }
288 uint32_t indirect_size = pos;
289 pos = elements.tellp();
290 if (pos < 0 || pos > std::numeric_limits<uint32_t>::max()) {
291 out.setstate(std::ios_base::badbit);
292 return out;
293 }
294 uint32_t elements_size = pos;
295 uint32_t element_count = params.size();
296
297 out.write(reinterpret_cast<const char*>(&indirect_size), sizeof(uint32_t));
298
299 pos = out.tellp();
300 if (indirect_size) out << indirect.rdbuf();
301 assert(out.tellp() - pos == indirect_size);
302
303 out.write(reinterpret_cast<const char*>(&element_count), sizeof(uint32_t));
304 out.write(reinterpret_cast<const char*>(&elements_size), sizeof(uint32_t));
305
306 pos = out.tellp();
307 if (elements_size) out << elements.rdbuf();
308 assert(out.tellp() - pos == elements_size);
309
310 return out;
311}
312
/**
 * The two sources read during deserialization. `indirect` holds the blob/bignum payload bytes,
 * addressed by offset. `elements` holds the per-parameter records, read sequentially.
 */
struct InStreams {
    std::istream& indirect;  // blob payload bytes
    std::istream& elements;  // per-parameter records
};

318InStreams& deserializeParamValue(InStreams& in, hidl_vec<uint8_t>* blob) {
319 uint32_t blob_length = 0;
320 uint32_t offset = 0;
321 in.elements.read(reinterpret_cast<char*>(&blob_length), sizeof(uint32_t));
322 blob->resize(blob_length);
323 in.elements.read(reinterpret_cast<char*>(&offset), sizeof(uint32_t));
324 in.indirect.seekg(offset);
325 in.indirect.read(reinterpret_cast<char*>(&(*blob)[0]), blob->size());
326 return in;
327}
328
329template <typename T>
330InStreams& deserializeParamValue(InStreams& in, T* value) {
331 in.elements.read(reinterpret_cast<char*>(value), sizeof(T));
332 return in;
333}
334
335InStreams& deserialize(TAG_INVALID_t&&, InStreams& in, KeyParameter*) {
336 // there should be no invalid KeyParamaters but if handle them as zero sized.
337 return in;
338}
339
340template <typename T>
341InStreams& deserialize(T&& ttag, InStreams& in, KeyParameter* param) {
342 return deserializeParamValue(in, &accessTagValue(ttag, *param));
343}
344
345template <typename... T>
346struct choose_deserializer;
347template <typename... Tags>
348struct choose_deserializer<MetaList<Tags...>> {
349 static InStreams& deserialize(InStreams& in, KeyParameter* param) {
350 return choose_deserializer<Tags...>::deserialize(in, param);
351 }
352};
353template <>
354struct choose_deserializer<> {
355 static InStreams& deserialize(InStreams& in, KeyParameter*) {
356 // encountered an unknown tag -> fail parsing
357 in.elements.setstate(std::ios_base::badbit);
358 return in;
359 }
360};
361template <TagType tag_type, Tag tag, typename... Tail>
362struct choose_deserializer<TypedTag<tag_type, tag>, Tail...> {
363 static InStreams& deserialize(InStreams& in, KeyParameter* param) {
364 if (param->tag == tag) {
365 return keystore::deserialize(TypedTag<tag_type, tag>(), in, param);
366 } else {
367 return choose_deserializer<Tail...>::deserialize(in, param);
368 }
369 }
370};
371
372InStreams& deserialize(InStreams& in, KeyParameter* param) {
373 in.elements.read(reinterpret_cast<char*>(&param->tag), sizeof(Tag));
374 return choose_deserializer<all_tags_t>::deserialize(in, param);
375}
376
377std::istream& deserialize(std::istream& in, std::vector<KeyParameter>* params) {
378 uint32_t indirect_size = 0;
379 in.read(reinterpret_cast<char*>(&indirect_size), sizeof(uint32_t));
380 std::string indirect_buffer(indirect_size, '\0');
381 if (indirect_buffer.size() != indirect_size) {
382 in.setstate(std::ios_base::badbit);
383 return in;
384 }
385 in.read(&indirect_buffer[0], indirect_buffer.size());
386
387 uint32_t element_count = 0;
388 in.read(reinterpret_cast<char*>(&element_count), sizeof(uint32_t));
389 uint32_t elements_size = 0;
390 in.read(reinterpret_cast<char*>(&elements_size), sizeof(uint32_t));
391
392 std::string elements_buffer(elements_size, '\0');
393 if (elements_buffer.size() != elements_size) {
394 in.setstate(std::ios_base::badbit);
395 return in;
396 }
397 in.read(&elements_buffer[0], elements_buffer.size());
398
399 if (in.bad()) return in;
400
401 // TODO write one-shot stream buffer to avoid copying here
402 std::stringstream indirect(indirect_buffer);
403 std::stringstream elements(elements_buffer);
404 InStreams streams = {indirect, elements};
405
406 params->resize(element_count);
407
408 for (uint32_t i = 0; i < element_count; ++i) {
409 deserialize(streams, &(*params)[i]);
410 }
411 return in;
412}
413void AuthorizationSet::Serialize(std::ostream* out) const {
414 serialize(*out, data_);
415}
416void AuthorizationSet::Deserialize(std::istream* in) {
417 deserialize(*in, &data_);
418}
419
420} // namespace keystore