diff --git a/prism/prism.c b/prism/prism.c
index 17d6fb6299..2cbf664b5d 100644
--- a/prism/prism.c
+++ b/prism/prism.c
@@ -14570,6 +14570,59 @@ parse_call_operator_write(pm_parser_t *parser, pm_call_node_t *call_node, const
     }
 }
 
+// Potentially change a =~ with a regular expression with named captures into a
+// match write node.
+//
+// The content string is scanned for named capture groups. When at least one is
+// found, the call is wrapped in a match write node; each captured name is added
+// through pm_parser_local_add_location or pm_parser_local_add_owned and the
+// resulting constant id is appended to the node's locals. Otherwise the call
+// itself is returned. The content string is never freed here.
+static pm_node_t *
+parse_regular_expression_named_captures(pm_parser_t *parser, const pm_string_t *content, pm_call_node_t *call) {
+    pm_string_list_t named_captures;
+    pm_string_list_init(&named_captures);
+
+    pm_node_t *result;
+    if (pm_regexp_named_capture_group_names(pm_string_source(content), pm_string_length(content), &named_captures, parser->encoding_changed, &parser->encoding) && (named_captures.length > 0)) {
+        pm_match_write_node_t *match = pm_match_write_node_create(parser, call);
+
+        for (size_t index = 0; index < named_captures.length; index++) {
+            pm_string_t *name = &named_captures.strings[index];
+            pm_constant_id_t local;
+
+            if (content->type == PM_STRING_SHARED) {
+                // If the unescaped string is a slice of the source,
+                // then we can copy the names directly. The pointers
+                // will line up.
+                local = pm_parser_local_add_location(parser, name->source, name->source + name->length);
+            } else {
+                // Otherwise, the name is a slice of the malloc-ed
+                // owned string, in which case we need to copy it
+                // out into a new string.
+                size_t length = pm_string_length(name);
+
+                // Ask for at least one byte so that a NULL result can
+                // only mean the allocation failed, then abort instead
+                // of copying through a null pointer.
+                void *memory = malloc(length > 0 ? length : 1);
+                if (!memory) abort();
+
+                memcpy(memory, pm_string_source(name), length);
+                local = pm_parser_local_add_owned(parser, (const uint8_t *) memory, length);
+            }
+
+            pm_constant_id_list_append(&match->locals, local);
+        }
+
+        result = (pm_node_t *) match;
+    } else {
+        result = (pm_node_t *) call;
+    }
+
+    pm_string_list_free(&named_captures);
+    return result;
+}
+
 static inline pm_node_t *
 parse_expression_infix(pm_parser_t *parser, pm_node_t *node, pm_binding_power_t previous_binding_power, pm_binding_power_t binding_power) {
     pm_token_t token = parser->current;
@@ -14995,42 +15048,51 @@ parse_expression_infix(pm_parser_t *parser, pm_node_t *node, pm_binding_power_t
             // If the receiver of this =~ is a regular expression node, then we
             // need to introduce local variables for it based on its named
             // capture groups.
-            if (PM_NODE_TYPE_P(node, PM_REGULAR_EXPRESSION_NODE)) {
-                pm_string_list_t named_captures;
-                pm_string_list_init(&named_captures);
+            if (PM_NODE_TYPE_P(node, PM_INTERPOLATED_REGULAR_EXPRESSION_NODE)) {
+                // It's possible to have an interpolated regular expression node
+                // that only contains strings. This is because it can be split
+                // up by a heredoc. In this case we need to concat the unescaped
+                // strings together and then parse them as a regular expression.
+ pm_node_list_t *parts = &((pm_interpolated_regular_expression_node_t *) node)->parts; - const pm_string_t *unescaped = &((pm_regular_expression_node_t *) node)->unescaped; - if (pm_regexp_named_capture_group_names(pm_string_source(unescaped), pm_string_length(unescaped), &named_captures, parser->encoding_changed, &parser->encoding) && (named_captures.length > 0)) { - pm_match_write_node_t *match = pm_match_write_node_create(parser, call); + bool interpolated = false; + size_t total_length = 0; - for (size_t index = 0; index < named_captures.length; index++) { - pm_string_t *name = &named_captures.strings[index]; - pm_constant_id_t local; + for (size_t index = 0; index < parts->size; index++) { + pm_node_t *part = parts->nodes[index]; - if (unescaped->type == PM_STRING_SHARED) { - // If the unescaped string is a slice of the source, - // then we can copy the names directly. The pointers - // will line up. - local = pm_parser_local_add_location(parser, name->source, name->source + name->length); - } else { - // Otherwise, the name is a slice of the malloc-ed - // owned string, in which case we need to copy it - // out into a new string. 
- size_t length = pm_string_length(name); - - void *memory = malloc(length); - memcpy(memory, pm_string_source(name), length); - - local = pm_parser_local_add_owned(parser, (const uint8_t *) memory, length); - } - - pm_constant_id_list_append(&match->locals, local); + if (PM_NODE_TYPE_P(part, PM_STRING_NODE)) { + total_length += pm_string_length(&((pm_string_node_t *) part)->unescaped); + } else { + interpolated = true; + break; } - - result = (pm_node_t *) match; } - pm_string_list_free(&named_captures); + if (!interpolated) { + void *memory = malloc(total_length); + if (!memory) abort(); + + uint8_t *cursor = memory; + for (size_t index = 0; index < parts->size; index++) { + pm_string_t *unescaped = &((pm_string_node_t *) parts->nodes[index])->unescaped; + size_t length = pm_string_length(unescaped); + + memcpy(cursor, pm_string_source(unescaped), length); + cursor += length; + } + + pm_string_t owned; + pm_string_owned_init(&owned, (uint8_t *) memory, total_length); + + result = parse_regular_expression_named_captures(parser, &owned, call); + pm_string_free(&owned); + } + } else if (PM_NODE_TYPE_P(node, PM_REGULAR_EXPRESSION_NODE)) { + // If we have a regular expression node, then we can just parse + // the named captures directly off the unescaped string. 
+ const pm_string_t *content = &((pm_regular_expression_node_t *) node)->unescaped; + result = parse_regular_expression_named_captures(parser, content, call); } return result; diff --git a/prism/regexp.c b/prism/regexp.c index c227c7b4c1..3462c846ce 100644 --- a/prism/regexp.c +++ b/prism/regexp.c @@ -188,6 +188,8 @@ pm_regexp_parse_range_quantifier(pm_regexp_parser_t *parser) { // ; static bool pm_regexp_parse_quantifier(pm_regexp_parser_t *parser) { + if (pm_regexp_char_is_eof(parser)) return true; + switch (*parser->cursor) { case '*': case '+': diff --git a/test/prism/fixtures/spanning_heredoc.txt b/test/prism/fixtures/spanning_heredoc.txt index d88e0e4be1..c1b9ec72f4 100644 --- a/test/prism/fixtures/spanning_heredoc.txt +++ b/test/prism/fixtures/spanning_heredoc.txt @@ -49,3 +49,7 @@ pp <<-A, %I[p\ o A p] + +<)/ =~ '' diff --git a/test/prism/snapshots/spanning_heredoc.txt b/test/prism/snapshots/spanning_heredoc.txt index e568dc2572..f28aeb815a 100644 --- a/test/prism/snapshots/spanning_heredoc.txt +++ b/test/prism/snapshots/spanning_heredoc.txt @@ -1,8 +1,8 @@ -@ ProgramNode (location: (4,0)-(51,2)) -├── locals: [] +@ ProgramNode (location: (4,0)-(55,13)) +├── locals: [:a] └── statements: - @ StatementsNode (location: (4,0)-(51,2)) - └── body: (length: 8) + @ StatementsNode (location: (4,0)-(55,13)) + └── body: (length: 10) ├── @ CallNode (location: (4,0)-(7,7)) │ ├── receiver: ∅ │ ├── call_operator_loc: ∅ @@ -270,41 +270,86 @@ │ ├── block: ∅ │ ├── flags: ∅ │ └── name: :pp - └── @ CallNode (location: (48,0)-(51,2)) - ├── receiver: ∅ - ├── call_operator_loc: ∅ - ├── message_loc: (48,0)-(48,2) = "pp" - ├── opening_loc: ∅ - ├── arguments: - │ @ ArgumentsNode (location: (48,3)-(51,2)) - │ ├── arguments: (length: 2) - │ │ ├── @ StringNode (location: (48,3)-(48,7)) - │ │ │ ├── flags: ∅ - │ │ │ ├── opening_loc: (48,3)-(48,7) = "<<-A" - │ │ │ ├── content_loc: (49,0)-(50,0) = "o\n" - │ │ │ ├── closing_loc: (50,0)-(51,0) = "A\n" - │ │ │ └── unescaped: "o\n" - │ │ └── @ 
ArrayNode (location: (48,9)-(51,2)) - │ │ ├── elements: (length: 1) - │ │ │ └── @ InterpolatedSymbolNode (location: (48,12)-(48,14)) - │ │ │ ├── opening_loc: ∅ - │ │ │ ├── parts: (length: 2) - │ │ │ │ ├── @ SymbolNode (location: (48,12)-(48,14)) - │ │ │ │ │ ├── opening_loc: ∅ - │ │ │ │ │ ├── value_loc: (48,12)-(48,14) = "p\\" - │ │ │ │ │ ├── closing_loc: ∅ - │ │ │ │ │ └── unescaped: "p\n" - │ │ │ │ └── @ StringNode (location: (48,12)-(48,14)) - │ │ │ │ ├── flags: ∅ - │ │ │ │ ├── opening_loc: ∅ - │ │ │ │ ├── content_loc: (48,12)-(48,14) = "p\\" - │ │ │ │ ├── closing_loc: ∅ - │ │ │ │ └── unescaped: "p" - │ │ │ └── closing_loc: ∅ - │ │ ├── opening_loc: (48,9)-(48,12) = "%I[" - │ │ └── closing_loc: (51,1)-(51,2) = "]" - │ └── flags: ∅ - ├── closing_loc: ∅ - ├── block: ∅ - ├── flags: ∅ - └── name: :pp + ├── @ CallNode (location: (48,0)-(51,2)) + │ ├── receiver: ∅ + │ ├── call_operator_loc: ∅ + │ ├── message_loc: (48,0)-(48,2) = "pp" + │ ├── opening_loc: ∅ + │ ├── arguments: + │ │ @ ArgumentsNode (location: (48,3)-(51,2)) + │ │ ├── arguments: (length: 2) + │ │ │ ├── @ StringNode (location: (48,3)-(48,7)) + │ │ │ │ ├── flags: ∅ + │ │ │ │ ├── opening_loc: (48,3)-(48,7) = "<<-A" + │ │ │ │ ├── content_loc: (49,0)-(50,0) = "o\n" + │ │ │ │ ├── closing_loc: (50,0)-(51,0) = "A\n" + │ │ │ │ └── unescaped: "o\n" + │ │ │ └── @ ArrayNode (location: (48,9)-(51,2)) + │ │ │ ├── elements: (length: 1) + │ │ │ │ └── @ InterpolatedSymbolNode (location: (48,12)-(48,14)) + │ │ │ │ ├── opening_loc: ∅ + │ │ │ │ ├── parts: (length: 2) + │ │ │ │ │ ├── @ SymbolNode (location: (48,12)-(48,14)) + │ │ │ │ │ │ ├── opening_loc: ∅ + │ │ │ │ │ │ ├── value_loc: (48,12)-(48,14) = "p\\" + │ │ │ │ │ │ ├── closing_loc: ∅ + │ │ │ │ │ │ └── unescaped: "p\n" + │ │ │ │ │ └── @ StringNode (location: (48,12)-(48,14)) + │ │ │ │ │ ├── flags: ∅ + │ │ │ │ │ ├── opening_loc: ∅ + │ │ │ │ │ ├── content_loc: (48,12)-(48,14) = "p\\" + │ │ │ │ │ ├── closing_loc: ∅ + │ │ │ │ │ └── unescaped: "p" + │ │ │ │ └── closing_loc: ∅ 
+ │ │ │ ├── opening_loc: (48,9)-(48,12) = "%I[" + │ │ │ └── closing_loc: (51,1)-(51,2) = "]" + │ │ └── flags: ∅ + │ ├── closing_loc: ∅ + │ ├── block: ∅ + │ ├── flags: ∅ + │ └── name: :pp + ├── @ StringNode (location: (53,0)-(53,3)) + │ ├── flags: ∅ + │ ├── opening_loc: (53,0)-(53,3) = "<)" + │ │ ├── closing_loc: (55,6)-(55,7) = "/" + │ │ └── flags: ∅ + │ ├── call_operator_loc: ∅ + │ ├── message_loc: (55,8)-(55,10) = "=~" + │ ├── opening_loc: ∅ + │ ├── arguments: + │ │ @ ArgumentsNode (location: (55,11)-(55,13)) + │ │ ├── arguments: (length: 1) + │ │ │ └── @ StringNode (location: (55,11)-(55,13)) + │ │ │ ├── flags: ∅ + │ │ │ ├── opening_loc: (55,11)-(55,12) = "'" + │ │ │ ├── content_loc: (55,12)-(55,12) = "" + │ │ │ ├── closing_loc: (55,12)-(55,13) = "'" + │ │ │ └── unescaped: "" + │ │ └── flags: ∅ + │ ├── closing_loc: ∅ + │ ├── block: ∅ + │ ├── flags: ∅ + │ └── name: :=~ + └── locals: [:a]