diff --git a/src/array_string.rs b/src/array_string.rs
index 523c58f..30aea4a 100644
--- a/src/array_string.rs
+++ b/src/array_string.rs
@@ -171,7 +171,9 @@ impl ArrayString
pub fn try_push(&mut self, c: char) -> Result<(), CapacityError> {
let len = self.len();
unsafe {
- match encode_utf8(c, &mut self.raw_mut_bytes()[len..]) {
+ let ptr = self.xs.ptr_mut().add(len);
+ let remaining_cap = self.capacity() - len;
+ match encode_utf8(c, ptr, remaining_cap) {
Ok(n) => {
self.set_len(len + n);
Ok(())
@@ -346,11 +348,6 @@ impl ArrayString
pub fn as_str(&self) -> &str {
self
}
-
- /// Return a mutable slice of the whole string’s buffer
- unsafe fn raw_mut_bytes(&mut self) -> &mut [u8] {
- slice::from_raw_parts_mut(self.xs.ptr_mut(), self.capacity())
- }
}
impl Deref for ArrayString
diff --git a/src/char.rs b/src/char.rs
index 8191dfb..c9b00ca 100644
--- a/src/char.rs
+++ b/src/char.rs
@@ -10,6 +10,8 @@
//
// Original authors: alexchrichton, bluss
+use std::ptr;
+
// UTF-8 ranges and tags for encoding characters
const TAG_CONT: u8 = 0b1000_0000;
const TAG_TWO_B: u8 = 0b1100_0000;
@@ -22,33 +24,75 @@ const MAX_THREE_B: u32 = 0x10000;
/// Placeholder
pub struct EncodeUtf8Error;
+#[inline]
+unsafe fn write(ptr: *mut u8, index: usize, byte: u8) {
+ ptr::write(ptr.add(index), byte)
+}
+
/// Encode a char into buf using UTF-8.
///
/// On success, return the byte length of the encoding (1, 2, 3 or 4).
/// On error, return `EncodeUtf8Error` if the buffer was too short for the char.
+///
+/// Safety: `ptr` must be writable for `len` bytes.
#[inline]
-pub fn encode_utf8(ch: char, buf: &mut [u8]) -> Result
+pub unsafe fn encode_utf8(ch: char, ptr: *mut u8, len: usize) -> Result
{
let code = ch as u32;
- if code < MAX_ONE_B && buf.len() >= 1 {
- buf[0] = code as u8;
+ if code < MAX_ONE_B && len >= 1 {
+ write(ptr, 0, code as u8);
return Ok(1);
- } else if code < MAX_TWO_B && buf.len() >= 2 {
- buf[0] = (code >> 6 & 0x1F) as u8 | TAG_TWO_B;
- buf[1] = (code & 0x3F) as u8 | TAG_CONT;
+ } else if code < MAX_TWO_B && len >= 2 {
+ write(ptr, 0, (code >> 6 & 0x1F) as u8 | TAG_TWO_B);
+ write(ptr, 1, (code & 0x3F) as u8 | TAG_CONT);
return Ok(2);
- } else if code < MAX_THREE_B && buf.len() >= 3 {
- buf[0] = (code >> 12 & 0x0F) as u8 | TAG_THREE_B;
- buf[1] = (code >> 6 & 0x3F) as u8 | TAG_CONT;
- buf[2] = (code & 0x3F) as u8 | TAG_CONT;
+ } else if code < MAX_THREE_B && len >= 3 {
+ write(ptr, 0, (code >> 12 & 0x0F) as u8 | TAG_THREE_B);
+ write(ptr, 1, (code >> 6 & 0x3F) as u8 | TAG_CONT);
+ write(ptr, 2, (code & 0x3F) as u8 | TAG_CONT);
return Ok(3);
- } else if buf.len() >= 4 {
- buf[0] = (code >> 18 & 0x07) as u8 | TAG_FOUR_B;
- buf[1] = (code >> 12 & 0x3F) as u8 | TAG_CONT;
- buf[2] = (code >> 6 & 0x3F) as u8 | TAG_CONT;
- buf[3] = (code & 0x3F) as u8 | TAG_CONT;
+ } else if len >= 4 {
+ write(ptr, 0, (code >> 18 & 0x07) as u8 | TAG_FOUR_B);
+ write(ptr, 1, (code >> 12 & 0x3F) as u8 | TAG_CONT);
+ write(ptr, 2, (code >> 6 & 0x3F) as u8 | TAG_CONT);
+ write(ptr, 3, (code & 0x3F) as u8 | TAG_CONT);
return Ok(4);
};
Err(EncodeUtf8Error)
}
+
+#[test]
+fn test_encode_utf8() {
+ // Test that all codepoints are encoded correctly
+ let mut data = [0u8; 16];
+ for codepoint in 0..=(std::char::MAX as u32) {
+ if let Some(ch) = std::char::from_u32(codepoint) {
+ for elt in &mut data { *elt = 0; }
+ let ptr = data.as_mut_ptr();
+ let len = data.len();
+ unsafe {
+ let res = encode_utf8(ch, ptr, len).ok().unwrap();
+ assert_eq!(res, ch.len_utf8());
+ }
+ let string = std::str::from_utf8(&data).unwrap();
+ assert_eq!(string.chars().next(), Some(ch));
+ }
+ }
+}
+
+#[test]
+fn test_encode_utf8_oob() {
+ // test that we report oob if the buffer is too short
+ let mut data = [0u8; 16];
+ let chars = ['a', 'α', '�', '𐍈'];
+ for (len, &ch) in (1..=4).zip(&chars) {
+ assert_eq!(len, ch.len_utf8(), "Len of ch={}", ch);
+ let ptr = data.as_mut_ptr();
+ unsafe {
+ assert!(matches::matches!(encode_utf8(ch, ptr, len - 1), Err(_)));
+ assert!(matches::matches!(encode_utf8(ch, ptr, len), Ok(_)));
+ }
+ }
+}
+