diff --git a/prism/prism.c b/prism/prism.c
index 17d6fb6299..2cbf664b5d 100644
--- a/prism/prism.c
+++ b/prism/prism.c
@@ -14570,6 +14570,50 @@ parse_call_operator_write(pm_parser_t *parser, pm_call_node_t *call_node, const
}
}
+// Potentially change a =~ with a regular expression with named captures into a
+// match write node.
+static pm_node_t *
+parse_regular_expression_named_captures(pm_parser_t *parser, const pm_string_t *content, pm_call_node_t *call) {
+ pm_string_list_t named_captures;
+ pm_string_list_init(&named_captures);
+
+ pm_node_t *result;
+ if (pm_regexp_named_capture_group_names(pm_string_source(content), pm_string_length(content), &named_captures, parser->encoding_changed, &parser->encoding) && (named_captures.length > 0)) {
+ pm_match_write_node_t *match = pm_match_write_node_create(parser, call);
+
+ for (size_t index = 0; index < named_captures.length; index++) {
+ pm_string_t *name = &named_captures.strings[index];
+ pm_constant_id_t local;
+
+ if (content->type == PM_STRING_SHARED) {
+ // If the unescaped string is a slice of the source,
+ // then we can copy the names directly. The pointers
+ // will line up.
+ local = pm_parser_local_add_location(parser, name->source, name->source + name->length);
+ } else {
+ // Otherwise, the name is a slice of the malloc-ed
+ // owned string, in which case we need to copy it
+ // out into a new string.
+ size_t length = pm_string_length(name);
+
+ void *memory = malloc(length);
+ memcpy(memory, pm_string_source(name), length);
+
+ local = pm_parser_local_add_owned(parser, (const uint8_t *) memory, length);
+ }
+
+ pm_constant_id_list_append(&match->locals, local);
+ }
+
+ result = (pm_node_t *) match;
+ } else {
+ result = (pm_node_t *) call;
+ }
+
+ pm_string_list_free(&named_captures);
+ return result;
+}
+
static inline pm_node_t *
parse_expression_infix(pm_parser_t *parser, pm_node_t *node, pm_binding_power_t previous_binding_power, pm_binding_power_t binding_power) {
pm_token_t token = parser->current;
@@ -14995,42 +15039,51 @@ parse_expression_infix(pm_parser_t *parser, pm_node_t *node, pm_binding_power_t
// If the receiver of this =~ is a regular expression node, then we
// need to introduce local variables for it based on its named
// capture groups.
- if (PM_NODE_TYPE_P(node, PM_REGULAR_EXPRESSION_NODE)) {
- pm_string_list_t named_captures;
- pm_string_list_init(&named_captures);
+ if (PM_NODE_TYPE_P(node, PM_INTERPOLATED_REGULAR_EXPRESSION_NODE)) {
+ // It's possible to have an interpolated regular expression node
+ // that only contains strings. This is because it can be split
+ // up by a heredoc. In this case we need to concat the unescaped
+ // strings together and then parse them as a regular expression.
+ pm_node_list_t *parts = &((pm_interpolated_regular_expression_node_t *) node)->parts;
- const pm_string_t *unescaped = &((pm_regular_expression_node_t *) node)->unescaped;
- if (pm_regexp_named_capture_group_names(pm_string_source(unescaped), pm_string_length(unescaped), &named_captures, parser->encoding_changed, &parser->encoding) && (named_captures.length > 0)) {
- pm_match_write_node_t *match = pm_match_write_node_create(parser, call);
+ bool interpolated = false;
+ size_t total_length = 0;
- for (size_t index = 0; index < named_captures.length; index++) {
- pm_string_t *name = &named_captures.strings[index];
- pm_constant_id_t local;
+ for (size_t index = 0; index < parts->size; index++) {
+ pm_node_t *part = parts->nodes[index];
- if (unescaped->type == PM_STRING_SHARED) {
- // If the unescaped string is a slice of the source,
- // then we can copy the names directly. The pointers
- // will line up.
- local = pm_parser_local_add_location(parser, name->source, name->source + name->length);
- } else {
- // Otherwise, the name is a slice of the malloc-ed
- // owned string, in which case we need to copy it
- // out into a new string.
- size_t length = pm_string_length(name);
-
- void *memory = malloc(length);
- memcpy(memory, pm_string_source(name), length);
-
- local = pm_parser_local_add_owned(parser, (const uint8_t *) memory, length);
- }
-
- pm_constant_id_list_append(&match->locals, local);
+ if (PM_NODE_TYPE_P(part, PM_STRING_NODE)) {
+ total_length += pm_string_length(&((pm_string_node_t *) part)->unescaped);
+ } else {
+ interpolated = true;
+ break;
}
-
- result = (pm_node_t *) match;
}
- pm_string_list_free(&named_captures);
+ if (!interpolated) {
+ void *memory = malloc(total_length);
+ if (!memory) abort();
+
+ uint8_t *cursor = memory;
+ for (size_t index = 0; index < parts->size; index++) {
+ pm_string_t *unescaped = &((pm_string_node_t *) parts->nodes[index])->unescaped;
+ size_t length = pm_string_length(unescaped);
+
+ memcpy(cursor, pm_string_source(unescaped), length);
+ cursor += length;
+ }
+
+ pm_string_t owned;
+ pm_string_owned_init(&owned, (uint8_t *) memory, total_length);
+
+ result = parse_regular_expression_named_captures(parser, &owned, call);
+ pm_string_free(&owned);
+ }
+ } else if (PM_NODE_TYPE_P(node, PM_REGULAR_EXPRESSION_NODE)) {
+ // If we have a regular expression node, then we can just parse
+ // the named captures directly off the unescaped string.
+ const pm_string_t *content = &((pm_regular_expression_node_t *) node)->unescaped;
+ result = parse_regular_expression_named_captures(parser, content, call);
}
return result;
diff --git a/prism/regexp.c b/prism/regexp.c
index c227c7b4c1..3462c846ce 100644
--- a/prism/regexp.c
+++ b/prism/regexp.c
@@ -188,6 +188,8 @@ pm_regexp_parse_range_quantifier(pm_regexp_parser_t *parser) {
// ;
static bool
pm_regexp_parse_quantifier(pm_regexp_parser_t *parser) {
+ if (pm_regexp_char_is_eof(parser)) return true;
+
switch (*parser->cursor) {
case '*':
case '+':
diff --git a/test/prism/fixtures/spanning_heredoc.txt b/test/prism/fixtures/spanning_heredoc.txt
index d88e0e4be1..c1b9ec72f4 100644
--- a/test/prism/fixtures/spanning_heredoc.txt
+++ b/test/prism/fixtures/spanning_heredoc.txt
@@ -49,3 +49,7 @@ pp <<-A, %I[p\
o
A
p]
+
+<)/ =~ ''
diff --git a/test/prism/snapshots/spanning_heredoc.txt b/test/prism/snapshots/spanning_heredoc.txt
index e568dc2572..f28aeb815a 100644
--- a/test/prism/snapshots/spanning_heredoc.txt
+++ b/test/prism/snapshots/spanning_heredoc.txt
@@ -1,8 +1,8 @@
-@ ProgramNode (location: (4,0)-(51,2))
-├── locals: []
+@ ProgramNode (location: (4,0)-(55,13))
+├── locals: [:a]
└── statements:
- @ StatementsNode (location: (4,0)-(51,2))
- └── body: (length: 8)
+ @ StatementsNode (location: (4,0)-(55,13))
+ └── body: (length: 10)
├── @ CallNode (location: (4,0)-(7,7))
│ ├── receiver: ∅
│ ├── call_operator_loc: ∅
@@ -270,41 +270,86 @@
│ ├── block: ∅
│ ├── flags: ∅
│ └── name: :pp
- └── @ CallNode (location: (48,0)-(51,2))
- ├── receiver: ∅
- ├── call_operator_loc: ∅
- ├── message_loc: (48,0)-(48,2) = "pp"
- ├── opening_loc: ∅
- ├── arguments:
- │ @ ArgumentsNode (location: (48,3)-(51,2))
- │ ├── arguments: (length: 2)
- │ │ ├── @ StringNode (location: (48,3)-(48,7))
- │ │ │ ├── flags: ∅
- │ │ │ ├── opening_loc: (48,3)-(48,7) = "<<-A"
- │ │ │ ├── content_loc: (49,0)-(50,0) = "o\n"
- │ │ │ ├── closing_loc: (50,0)-(51,0) = "A\n"
- │ │ │ └── unescaped: "o\n"
- │ │ └── @ ArrayNode (location: (48,9)-(51,2))
- │ │ ├── elements: (length: 1)
- │ │ │ └── @ InterpolatedSymbolNode (location: (48,12)-(48,14))
- │ │ │ ├── opening_loc: ∅
- │ │ │ ├── parts: (length: 2)
- │ │ │ │ ├── @ SymbolNode (location: (48,12)-(48,14))
- │ │ │ │ │ ├── opening_loc: ∅
- │ │ │ │ │ ├── value_loc: (48,12)-(48,14) = "p\\"
- │ │ │ │ │ ├── closing_loc: ∅
- │ │ │ │ │ └── unescaped: "p\n"
- │ │ │ │ └── @ StringNode (location: (48,12)-(48,14))
- │ │ │ │ ├── flags: ∅
- │ │ │ │ ├── opening_loc: ∅
- │ │ │ │ ├── content_loc: (48,12)-(48,14) = "p\\"
- │ │ │ │ ├── closing_loc: ∅
- │ │ │ │ └── unescaped: "p"
- │ │ │ └── closing_loc: ∅
- │ │ ├── opening_loc: (48,9)-(48,12) = "%I["
- │ │ └── closing_loc: (51,1)-(51,2) = "]"
- │ └── flags: ∅
- ├── closing_loc: ∅
- ├── block: ∅
- ├── flags: ∅
- └── name: :pp
+ ├── @ CallNode (location: (48,0)-(51,2))
+ │ ├── receiver: ∅
+ │ ├── call_operator_loc: ∅
+ │ ├── message_loc: (48,0)-(48,2) = "pp"
+ │ ├── opening_loc: ∅
+ │ ├── arguments:
+ │ │ @ ArgumentsNode (location: (48,3)-(51,2))
+ │ │ ├── arguments: (length: 2)
+ │ │ │ ├── @ StringNode (location: (48,3)-(48,7))
+ │ │ │ │ ├── flags: ∅
+ │ │ │ │ ├── opening_loc: (48,3)-(48,7) = "<<-A"
+ │ │ │ │ ├── content_loc: (49,0)-(50,0) = "o\n"
+ │ │ │ │ ├── closing_loc: (50,0)-(51,0) = "A\n"
+ │ │ │ │ └── unescaped: "o\n"
+ │ │ │ └── @ ArrayNode (location: (48,9)-(51,2))
+ │ │ │ ├── elements: (length: 1)
+ │ │ │ │ └── @ InterpolatedSymbolNode (location: (48,12)-(48,14))
+ │ │ │ │ ├── opening_loc: ∅
+ │ │ │ │ ├── parts: (length: 2)
+ │ │ │ │ │ ├── @ SymbolNode (location: (48,12)-(48,14))
+ │ │ │ │ │ │ ├── opening_loc: ∅
+ │ │ │ │ │ │ ├── value_loc: (48,12)-(48,14) = "p\\"
+ │ │ │ │ │ │ ├── closing_loc: ∅
+ │ │ │ │ │ │ └── unescaped: "p\n"
+ │ │ │ │ │ └── @ StringNode (location: (48,12)-(48,14))
+ │ │ │ │ │ ├── flags: ∅
+ │ │ │ │ │ ├── opening_loc: ∅
+ │ │ │ │ │ ├── content_loc: (48,12)-(48,14) = "p\\"
+ │ │ │ │ │ ├── closing_loc: ∅
+ │ │ │ │ │ └── unescaped: "p"
+ │ │ │ │ └── closing_loc: ∅
+ │ │ │ ├── opening_loc: (48,9)-(48,12) = "%I["
+ │ │ │ └── closing_loc: (51,1)-(51,2) = "]"
+ │ │ └── flags: ∅
+ │ ├── closing_loc: ∅
+ │ ├── block: ∅
+ │ ├── flags: ∅
+ │ └── name: :pp
+ ├── @ StringNode (location: (53,0)-(53,3))
+ │ ├── flags: ∅
+ │ ├── opening_loc: (53,0)-(53,3) = "<)"
+ │ │ │ ├── closing_loc: ∅
+ │ │ │ └── unescaped: "(?)"
+ │ │ ├── closing_loc: (55,6)-(55,7) = "/"
+ │ │ └── flags: ∅
+ │ ├── call_operator_loc: ∅
+ │ ├── message_loc: (55,8)-(55,10) = "=~"
+ │ ├── opening_loc: ∅
+ │ ├── arguments:
+ │ │ @ ArgumentsNode (location: (55,11)-(55,13))
+ │ │ ├── arguments: (length: 1)
+ │ │ │ └── @ StringNode (location: (55,11)-(55,13))
+ │ │ │ ├── flags: ∅
+ │ │ │ ├── opening_loc: (55,11)-(55,12) = "'"
+ │ │ │ ├── content_loc: (55,12)-(55,12) = ""
+ │ │ │ ├── closing_loc: (55,12)-(55,13) = "'"
+ │ │ │ └── unescaped: ""
+ │ │ └── flags: ∅
+ │ ├── closing_loc: ∅
+ │ ├── block: ∅
+ │ ├── flags: ∅
+ │ └── name: :=~
+ └── locals: [:a]