FIX: Support uninitalized data in encode_utf8, and update try_push

We were using &mut [u8] in encode_utf8, but this is not right according
to the developing unsafe coding guidelines. We need to use raw pointers
to write to possibly uninit memory.

We use a raw pointer form for encode_utf8. It was first attempted to
encapsulate the trusted-to-be-valid raw pointer in a simple { *mut u8,
usize } struct, but the current way of passing ptr and len separately
was the only way to not regress performance.

This impl maintains the same performance in arraystring benches.

Add exhaustive-style tests for encode_utf8
This commit is contained in:
bluss
2019-10-09 10:50:38 +02:00
parent f665142854
commit 090a5c50cb
2 changed files with 62 additions and 21 deletions
+3 -6
View File
@@ -171,7 +171,9 @@ impl<A> ArrayString<A>
pub fn try_push(&mut self, c: char) -> Result<(), CapacityError<char>> { pub fn try_push(&mut self, c: char) -> Result<(), CapacityError<char>> {
let len = self.len(); let len = self.len();
unsafe { unsafe {
match encode_utf8(c, &mut self.raw_mut_bytes()[len..]) { let ptr = self.xs.ptr_mut().add(len);
let remaining_cap = self.capacity() - len;
match encode_utf8(c, ptr, remaining_cap) {
Ok(n) => { Ok(n) => {
self.set_len(len + n); self.set_len(len + n);
Ok(()) Ok(())
@@ -346,11 +348,6 @@ impl<A> ArrayString<A>
pub fn as_str(&self) -> &str { pub fn as_str(&self) -> &str {
self self
} }
/// Return a mutable slice of the whole strings buffer
unsafe fn raw_mut_bytes(&mut self) -> &mut [u8] {
slice::from_raw_parts_mut(self.xs.ptr_mut(), self.capacity())
}
} }
impl<A> Deref for ArrayString<A> impl<A> Deref for ArrayString<A>
+59 -15
View File
@@ -10,6 +10,8 @@
// //
// Original authors: alexchrichton, bluss // Original authors: alexchrichton, bluss
use std::ptr;
// UTF-8 ranges and tags for encoding characters // UTF-8 ranges and tags for encoding characters
const TAG_CONT: u8 = 0b1000_0000; const TAG_CONT: u8 = 0b1000_0000;
const TAG_TWO_B: u8 = 0b1100_0000; const TAG_TWO_B: u8 = 0b1100_0000;
@@ -22,33 +24,75 @@ const MAX_THREE_B: u32 = 0x10000;
/// Placeholder /// Placeholder
pub struct EncodeUtf8Error; pub struct EncodeUtf8Error;
#[inline]
unsafe fn write(ptr: *mut u8, index: usize, byte: u8) {
ptr::write(ptr.add(index), byte)
}
/// Encode a char into buf using UTF-8. /// Encode a char into buf using UTF-8.
/// ///
/// On success, return the byte length of the encoding (1, 2, 3 or 4).<br> /// On success, return the byte length of the encoding (1, 2, 3 or 4).<br>
/// On error, return `EncodeUtf8Error` if the buffer was too short for the char. /// On error, return `EncodeUtf8Error` if the buffer was too short for the char.
///
/// Safety: `ptr` must be writable for `len` bytes.
#[inline] #[inline]
pub fn encode_utf8(ch: char, buf: &mut [u8]) -> Result<usize, EncodeUtf8Error> pub unsafe fn encode_utf8(ch: char, ptr: *mut u8, len: usize) -> Result<usize, EncodeUtf8Error>
{ {
let code = ch as u32; let code = ch as u32;
if code < MAX_ONE_B && buf.len() >= 1 { if code < MAX_ONE_B && len >= 1 {
buf[0] = code as u8; write(ptr, 0, code as u8);
return Ok(1); return Ok(1);
} else if code < MAX_TWO_B && buf.len() >= 2 { } else if code < MAX_TWO_B && len >= 2 {
buf[0] = (code >> 6 & 0x1F) as u8 | TAG_TWO_B; write(ptr, 0, (code >> 6 & 0x1F) as u8 | TAG_TWO_B);
buf[1] = (code & 0x3F) as u8 | TAG_CONT; write(ptr, 1, (code & 0x3F) as u8 | TAG_CONT);
return Ok(2); return Ok(2);
} else if code < MAX_THREE_B && buf.len() >= 3 { } else if code < MAX_THREE_B && len >= 3 {
buf[0] = (code >> 12 & 0x0F) as u8 | TAG_THREE_B; write(ptr, 0, (code >> 12 & 0x0F) as u8 | TAG_THREE_B);
buf[1] = (code >> 6 & 0x3F) as u8 | TAG_CONT; write(ptr, 1, (code >> 6 & 0x3F) as u8 | TAG_CONT);
buf[2] = (code & 0x3F) as u8 | TAG_CONT; write(ptr, 2, (code & 0x3F) as u8 | TAG_CONT);
return Ok(3); return Ok(3);
} else if buf.len() >= 4 { } else if len >= 4 {
buf[0] = (code >> 18 & 0x07) as u8 | TAG_FOUR_B; write(ptr, 0, (code >> 18 & 0x07) as u8 | TAG_FOUR_B);
buf[1] = (code >> 12 & 0x3F) as u8 | TAG_CONT; write(ptr, 1, (code >> 12 & 0x3F) as u8 | TAG_CONT);
buf[2] = (code >> 6 & 0x3F) as u8 | TAG_CONT; write(ptr, 2, (code >> 6 & 0x3F) as u8 | TAG_CONT);
buf[3] = (code & 0x3F) as u8 | TAG_CONT; write(ptr, 3, (code & 0x3F) as u8 | TAG_CONT);
return Ok(4); return Ok(4);
}; };
Err(EncodeUtf8Error) Err(EncodeUtf8Error)
} }
#[test]
fn test_encode_utf8() {
// Test that all codepoints are encoded correctly
let mut data = [0u8; 16];
for codepoint in 0..=(std::char::MAX as u32) {
if let Some(ch) = std::char::from_u32(codepoint) {
for elt in &mut data { *elt = 0; }
let ptr = data.as_mut_ptr();
let len = data.len();
unsafe {
let res = encode_utf8(ch, ptr, len).ok().unwrap();
assert_eq!(res, ch.len_utf8());
}
let string = std::str::from_utf8(&data).unwrap();
assert_eq!(string.chars().next(), Some(ch));
}
}
}
#[test]
fn test_encode_utf8_oob() {
// test that we report oob if the buffer is too short
let mut data = [0u8; 16];
let chars = ['a', 'α', '', '𐍈'];
for (len, &ch) in (1..=4).zip(&chars) {
assert_eq!(len, ch.len_utf8(), "Len of ch={}", ch);
let ptr = data.as_mut_ptr();
unsafe {
assert!(matches::matches!(encode_utf8(ch, ptr, len - 1), Err(_)));
assert!(matches::matches!(encode_utf8(ch, ptr, len), Ok(_)));
}
}
}