[ruby/prism] Strip out old char unescaping

27ca207ab3
This commit is contained in:
Kevin Newton 2023-10-10 11:28:41 -04:00
parent dd3986876a
commit 3dba3ab47d
4 changed files with 52 additions and 78 deletions

View file

@ -6215,8 +6215,8 @@ escape_read(pm_parser_t *parser, pm_buffer_t *buffer, uint8_t flags) {
return; return;
} }
case 'x': { case 'x': {
uint8_t byte = peek(parser);
parser->current.end++; parser->current.end++;
uint8_t byte = peek(parser);
if (pm_char_is_hexadecimal_digit(byte)) { if (pm_char_is_hexadecimal_digit(byte)) {
uint8_t value = escape_hexadecimal_digit(byte); uint8_t value = escape_hexadecimal_digit(byte);
@ -6239,7 +6239,7 @@ escape_read(pm_parser_t *parser, pm_buffer_t *buffer, uint8_t flags) {
parser->current.end++; parser->current.end++;
if ( if (
(parser->current.end + 4 < parser->end) && (parser->current.end + 4 <= parser->end) &&
pm_char_is_hexadecimal_digit(parser->current.end[0]) && pm_char_is_hexadecimal_digit(parser->current.end[0]) &&
pm_char_is_hexadecimal_digit(parser->current.end[1]) && pm_char_is_hexadecimal_digit(parser->current.end[1]) &&
pm_char_is_hexadecimal_digit(parser->current.end[2]) && pm_char_is_hexadecimal_digit(parser->current.end[2]) &&
@ -6250,12 +6250,13 @@ escape_read(pm_parser_t *parser, pm_buffer_t *buffer, uint8_t flags) {
parser->current.end += 4; parser->current.end += 4;
} else if (peek(parser) == '{') { } else if (peek(parser) == '{') {
const uint8_t *unicode_codepoints_start = parser->current.end - 2; const uint8_t *unicode_codepoints_start = parser->current.end - 2;
parser->current.end++; parser->current.end++;
parser->current.end += pm_strspn_whitespace(parser->current.end, parser->end - parser->current.end);
const uint8_t *extra_codepoints_start = NULL; const uint8_t *extra_codepoints_start = NULL;
int codepoints_count = 0; int codepoints_count = 0;
parser->current.end += pm_strspn_whitespace(parser->current.end, parser->end - parser->current.end);
while ((parser->current.end < parser->end) && (*parser->current.end != '}')) { while ((parser->current.end < parser->end) && (*parser->current.end != '}')) {
const uint8_t *unicode_start = parser->current.end; const uint8_t *unicode_start = parser->current.end;
size_t hexadecimal_length = pm_strspn_hexadecimal_digit(parser->current.end, parser->end - parser->current.end); size_t hexadecimal_length = pm_strspn_hexadecimal_digit(parser->current.end, parser->end - parser->current.end);
@ -6303,7 +6304,7 @@ escape_read(pm_parser_t *parser, pm_buffer_t *buffer, uint8_t flags) {
switch (peeked) { switch (peeked) {
case '?': case '?':
parser->current.end++; parser->current.end++;
pm_buffer_append_u8(buffer, escape_byte(0x7f, flags | PM_ESCAPE_FLAG_CONTROL)); pm_buffer_append_u8(buffer, escape_byte(0x7f, flags));
return; return;
case '\\': case '\\':
if (flags & PM_ESCAPE_FLAG_CONTROL) { if (flags & PM_ESCAPE_FLAG_CONTROL) {
@ -6336,7 +6337,7 @@ escape_read(pm_parser_t *parser, pm_buffer_t *buffer, uint8_t flags) {
switch (peeked) { switch (peeked) {
case '?': case '?':
parser->current.end++; parser->current.end++;
pm_buffer_append_u8(buffer, escape_byte(0x7f, flags | PM_ESCAPE_FLAG_CONTROL)); pm_buffer_append_u8(buffer, escape_byte(0x7f, flags));
return; return;
case '\\': case '\\':
if (flags & PM_ESCAPE_FLAG_CONTROL) { if (flags & PM_ESCAPE_FLAG_CONTROL) {
@ -6366,28 +6367,24 @@ escape_read(pm_parser_t *parser, pm_buffer_t *buffer, uint8_t flags) {
parser->current.end++; parser->current.end++;
uint8_t peeked = peek(parser); uint8_t peeked = peek(parser);
switch (peeked) { if (peeked == '\\') {
case '?': if (flags & PM_ESCAPE_FLAG_META) {
parser->current.end++; pm_parser_err_current(parser, PM_ERR_ESCAPE_INVALID_META_REPEAT);
pm_buffer_append_u8(buffer, escape_byte(0x7f, flags | PM_ESCAPE_FLAG_META));
return;
case '\\':
if (flags & PM_ESCAPE_FLAG_META) {
pm_parser_err_current(parser, PM_ERR_ESCAPE_INVALID_META_REPEAT);
return;
}
parser->current.end++;
escape_read(parser, buffer, flags | PM_ESCAPE_FLAG_META);
return;
default:
if (!char_is_ascii_printable(peeked)) {
pm_parser_err_current(parser, PM_ERR_ESCAPE_INVALID_META);
return;
}
parser->current.end++;
pm_buffer_append_u8(buffer, escape_byte(peeked, flags | PM_ESCAPE_FLAG_META));
return; return;
}
parser->current.end++;
escape_read(parser, buffer, flags | PM_ESCAPE_FLAG_META);
return;
} }
if (!char_is_ascii_printable(peeked)) {
pm_parser_err_current(parser, PM_ERR_ESCAPE_INVALID_META);
return;
}
parser->current.end++;
pm_buffer_append_u8(buffer, escape_byte(peeked, flags | PM_ESCAPE_FLAG_META));
return;
} }
default: { default: {
if (parser->current.end < parser->end) { if (parser->current.end < parser->end) {
@ -7873,7 +7870,7 @@ parser_lex(pm_parser_t *parser) {
// and find the next breakpoint. // and find the next breakpoint.
if (*breakpoint == '\\') { if (*breakpoint == '\\') {
pm_unescape_type_t unescape_type = lex_mode->as.list.interpolation ? PM_UNESCAPE_ALL : PM_UNESCAPE_MINIMAL; pm_unescape_type_t unescape_type = lex_mode->as.list.interpolation ? PM_UNESCAPE_ALL : PM_UNESCAPE_MINIMAL;
size_t difference = pm_unescape_calculate_difference(parser, breakpoint, unescape_type, false); size_t difference = pm_unescape_calculate_difference(parser, breakpoint, unescape_type);
if (difference == 0) { if (difference == 0) {
// we're at the end of the file // we're at the end of the file
breakpoint = NULL; breakpoint = NULL;
@ -8010,7 +8007,7 @@ parser_lex(pm_parser_t *parser) {
// literally. In this case we'll skip past the next character // literally. In this case we'll skip past the next character
// and find the next breakpoint. // and find the next breakpoint.
if (*breakpoint == '\\') { if (*breakpoint == '\\') {
size_t difference = pm_unescape_calculate_difference(parser, breakpoint, PM_UNESCAPE_ALL, false); size_t difference = pm_unescape_calculate_difference(parser, breakpoint, PM_UNESCAPE_ALL);
if (difference == 0) { if (difference == 0) {
// we're at the end of the file // we're at the end of the file
breakpoint = NULL; breakpoint = NULL;
@ -8165,7 +8162,7 @@ parser_lex(pm_parser_t *parser) {
// literally. In this case we'll skip past the next character and // literally. In this case we'll skip past the next character and
// find the next breakpoint. // find the next breakpoint.
pm_unescape_type_t unescape_type = parser->lex_modes.current->as.string.interpolation ? PM_UNESCAPE_ALL : PM_UNESCAPE_MINIMAL; pm_unescape_type_t unescape_type = parser->lex_modes.current->as.string.interpolation ? PM_UNESCAPE_ALL : PM_UNESCAPE_MINIMAL;
size_t difference = pm_unescape_calculate_difference(parser, breakpoint, unescape_type, false); size_t difference = pm_unescape_calculate_difference(parser, breakpoint, unescape_type);
if (difference == 0) { if (difference == 0) {
// we're at the end of the file // we're at the end of the file
breakpoint = NULL; breakpoint = NULL;
@ -8341,7 +8338,7 @@ parser_lex(pm_parser_t *parser) {
breakpoint += eol_length; breakpoint += eol_length;
} else { } else {
pm_unescape_type_t unescape_type = (quote == PM_HEREDOC_QUOTE_SINGLE) ? PM_UNESCAPE_MINIMAL : PM_UNESCAPE_ALL; pm_unescape_type_t unescape_type = (quote == PM_HEREDOC_QUOTE_SINGLE) ? PM_UNESCAPE_MINIMAL : PM_UNESCAPE_ALL;
size_t difference = pm_unescape_calculate_difference(parser, breakpoint, unescape_type, false); size_t difference = pm_unescape_calculate_difference(parser, breakpoint, unescape_type);
if (difference == 0) { if (difference == 0) {
// we're at the end of the file // we're at the end of the file
breakpoint = NULL; breakpoint = NULL;

View file

@ -455,8 +455,8 @@ unescape(
// \c\M-x same as above // \c\M-x same as above
// \c? or \C-? delete, ASCII 7Fh (DEL) // \c? or \C-? delete, ASCII 7Fh (DEL)
// //
static void PRISM_EXPORTED_FUNCTION void
pm_unescape_manipulate_string_or_char_literal(pm_parser_t *parser, pm_string_t *string, pm_unescape_type_t unescape_type, bool expect_single_codepoint) { pm_unescape_manipulate_string(pm_parser_t *parser, pm_string_t *string, pm_unescape_type_t unescape_type) {
if (unescape_type == PM_UNESCAPE_NONE) { if (unescape_type == PM_UNESCAPE_NONE) {
// If we're not unescaping then we can reference the source directly. // If we're not unescaping then we can reference the source directly.
return; return;
@ -529,12 +529,7 @@ pm_unescape_manipulate_string_or_char_literal(pm_parser_t *parser, pm_string_t *
// handle all of the different unescapes. // handle all of the different unescapes.
assert(unescape_type == PM_UNESCAPE_ALL); assert(unescape_type == PM_UNESCAPE_ALL);
uint8_t flags = PM_UNESCAPE_FLAG_NONE; cursor = unescape(parser, dest, &dest_length, backslash, end, PM_UNESCAPE_FLAG_NONE, &parser->error_list);
if (expect_single_codepoint) {
flags |= PM_UNESCAPE_FLAG_EXPECT_SINGLE;
}
cursor = unescape(parser, dest, &dest_length, backslash, end, flags, &parser->error_list);
break; break;
} }
@ -562,21 +557,11 @@ pm_unescape_manipulate_string_or_char_literal(pm_parser_t *parser, pm_string_t *
pm_string_owned_init(string, allocated, dest_length + ((size_t) (end - cursor))); pm_string_owned_init(string, allocated, dest_length + ((size_t) (end - cursor)));
} }
PRISM_EXPORTED_FUNCTION void
pm_unescape_manipulate_string(pm_parser_t *parser, pm_string_t *string, pm_unescape_type_t unescape_type) {
pm_unescape_manipulate_string_or_char_literal(parser, string, unescape_type, false);
}
void
pm_unescape_manipulate_char_literal(pm_parser_t *parser, pm_string_t *string, pm_unescape_type_t unescape_type) {
pm_unescape_manipulate_string_or_char_literal(parser, string, unescape_type, true);
}
// This function is similar to pm_unescape_manipulate_string, except it doesn't // This function is similar to pm_unescape_manipulate_string, except it doesn't
// actually perform any string manipulations. Instead, it calculates how long // actually perform any string manipulations. Instead, it calculates how long
// the unescaped character is, and returns that value // the unescaped character is, and returns that value
size_t size_t
pm_unescape_calculate_difference(pm_parser_t *parser, const uint8_t *backslash, pm_unescape_type_t unescape_type, bool expect_single_codepoint) { pm_unescape_calculate_difference(pm_parser_t *parser, const uint8_t *backslash, pm_unescape_type_t unescape_type) {
assert(unescape_type != PM_UNESCAPE_NONE); assert(unescape_type != PM_UNESCAPE_NONE);
if (backslash + 1 >= parser->end) { if (backslash + 1 >= parser->end) {
@ -605,12 +590,7 @@ pm_unescape_calculate_difference(pm_parser_t *parser, const uint8_t *backslash,
// handle all of the different unescapes. // handle all of the different unescapes.
assert(unescape_type == PM_UNESCAPE_ALL); assert(unescape_type == PM_UNESCAPE_ALL);
uint8_t flags = PM_UNESCAPE_FLAG_NONE; const uint8_t *cursor = unescape(parser, NULL, 0, backslash, parser->end, PM_UNESCAPE_FLAG_NONE, NULL);
if (expect_single_codepoint) {
flags |= PM_UNESCAPE_FLAG_EXPECT_SINGLE;
}
const uint8_t *cursor = unescape(parser, NULL, 0, backslash, parser->end, flags, NULL);
assert(cursor > backslash); assert(cursor > backslash);
return (size_t) (cursor - backslash); return (size_t) (cursor - backslash);

View file

@ -35,7 +35,6 @@ typedef enum {
// Unescape the contents of the given token into the given string using the given unescape mode. // Unescape the contents of the given token into the given string using the given unescape mode.
PRISM_EXPORTED_FUNCTION void pm_unescape_manipulate_string(pm_parser_t *parser, pm_string_t *string, pm_unescape_type_t unescape_type); PRISM_EXPORTED_FUNCTION void pm_unescape_manipulate_string(pm_parser_t *parser, pm_string_t *string, pm_unescape_type_t unescape_type);
void pm_unescape_manipulate_char_literal(pm_parser_t *parser, pm_string_t *string, pm_unescape_type_t unescape_type);
// Accepts a source string and a type of unescaping and returns the unescaped version. // Accepts a source string and a type of unescaping and returns the unescaped version.
// The caller must pm_string_free(result); after calling this function. // The caller must pm_string_free(result); after calling this function.
@ -43,6 +42,6 @@ PRISM_EXPORTED_FUNCTION bool pm_unescape_string(const uint8_t *start, size_t len
// Returns the number of bytes that encompass the first escape sequence in the // Returns the number of bytes that encompass the first escape sequence in the
// given string. // given string.
size_t pm_unescape_calculate_difference(pm_parser_t *parser, const uint8_t *value, pm_unescape_type_t unescape_type, bool expect_single_codepoint); size_t pm_unescape_calculate_difference(pm_parser_t *parser, const uint8_t *value, pm_unescape_type_t unescape_type);
#endif #endif

View file

@ -9,22 +9,22 @@ module Prism
module Context module Context
class Base class Base
attr_reader :left, :right attr_reader :left, :right
def initialize(left, right) def initialize(left, right)
@left = left @left = left
@right = right @right = right
end end
def name def name
"#{left}#{right}".delete("\n") "#{left}#{right}".delete("\n")
end end
private private
def code(escape) def code(escape)
"#{left}\\#{escape}#{right}".b "#{left}\\#{escape}#{right}".b
end end
def ruby(escape) def ruby(escape)
previous, $VERBOSE = $VERBOSE, nil previous, $VERBOSE = $VERBOSE, nil
@ -36,37 +36,37 @@ module Prism
$VERBOSE = previous $VERBOSE = previous
end end
end end
def prism(escape) def prism(escape)
result = Prism.parse(code(escape)) result = Prism.parse(code(escape))
if result.success? if result.success?
yield result.value.statements.body.first yield result.value.statements.body.first
else else
:error :error
end end
end end
def `(command) def `(command)
command command
end end
end end
class List < Base class List < Base
def ruby_result(escape) = ruby(escape) { |value| value.first.to_s } def ruby_result(escape) = ruby(escape) { |value| value.first.to_s }
def prism_result(escape) = prism(escape) { |node| node.elements.first.unescaped } def prism_result(escape) = prism(escape) { |node| node.elements.first.unescaped }
end end
class Symbol < Base class Symbol < Base
def ruby_result(escape) = ruby(escape, &:to_s) def ruby_result(escape) = ruby(escape, &:to_s)
def prism_result(escape) = prism(escape, &:unescaped) def prism_result(escape) = prism(escape, &:unescaped)
end end
class String < Base class String < Base
def ruby_result(escape) = ruby(escape, &:itself) def ruby_result(escape) = ruby(escape, &:itself)
def prism_result(escape) = prism(escape, &:unescaped) def prism_result(escape) = prism(escape, &:unescaped)
end end
class RegExp < Base class RegExp < Base
def ruby_result(escape) = ruby(escape, &:source) def ruby_result(escape) = ruby(escape, &:source)
def prism_result(escape) = prism(escape, &:unescaped) def prism_result(escape) = prism(escape, &:unescaped)
@ -92,13 +92,13 @@ module Prism
escapes = [*ascii, *ascii8, *octal, *hex2, *hex4, *hex6, *ctrls] escapes = [*ascii, *ascii8, *octal, *hex2, *hex4, *hex6, *ctrls]
contexts = [ contexts = [
[Context::String.new("?", ""), [*ascii, *octal]], #, *hex2]], [Context::String.new("?", ""), escapes],
[Context::String.new("'", "'"), escapes], # [Context::String.new("'", "'"), escapes],
[Context::String.new("\"", "\""), escapes], # [Context::String.new("\"", "\""), escapes],
# [Context::String.new("%q[", "]"), escapes], # [Context::String.new("%q[", "]"), escapes],
[Context::String.new("%Q[", "]"), escapes], # [Context::String.new("%Q[", "]"), escapes],
[Context::String.new("%[", "]"), escapes], # [Context::String.new("%[", "]"), escapes],
[Context::String.new("`", "`"), escapes], # [Context::String.new("`", "`"), escapes],
# [Context::String.new("<<~H\n", "\nH"), escapes], # [Context::String.new("<<~H\n", "\nH"), escapes],
# [Context::String.new("<<~'H'\n", "\nH"), escapes], # [Context::String.new("<<~'H'\n", "\nH"), escapes],
# [Context::String.new("<<~\"H\"\n", "\nH"), escapes], # [Context::String.new("<<~\"H\"\n", "\nH"), escapes],
@ -109,16 +109,14 @@ module Prism
# [Context::List.new("%I[", "]"), escapes], # [Context::List.new("%I[", "]"), escapes],
# [Context::Symbol.new("%s[", "]"), escapes], # [Context::Symbol.new("%s[", "]"), escapes],
# [Context::Symbol.new(":'", "'"), escapes], # [Context::Symbol.new(":'", "'"), escapes],
[Context::Symbol.new(":\"", "\""), escapes], # [Context::Symbol.new(":\"", "\""), escapes],
# [Context::RegExp.new("/", "/"), escapes], # [Context::RegExp.new("/", "/"), escapes],
# [Context::RegExp.new("%r[", "]"), escapes] # [Context::RegExp.new("%r[", "]"), escapes]
] ]
known_failures = [["?", "\n"]]
contexts.each do |(context, escapes)| contexts.each do |(context, escapes)|
escapes.each do |escape| escapes.each do |escape|
next if known_failures.include?([context.name, escape]) next if context.name == "?" && escape == "\xFF".b # wat?
define_method(:"test_#{context.name}_#{escape.inspect}") do define_method(:"test_#{context.name}_#{escape.inspect}") do
assert_unescape(context, escape) assert_unescape(context, escape)