diff --git a/prism/templates/lib/prism/node.rb.erb b/prism/templates/lib/prism/node.rb.erb index 6b5a285315..f869a841c5 100644 --- a/prism/templates/lib/prism/node.rb.erb +++ b/prism/templates/lib/prism/node.rb.erb @@ -219,10 +219,10 @@ module Prism def deconstruct_keys(keys) { <%= (node.fields.map { |field| "#{field.name}: #{field.name}" } + ["location: location"]).join(", ") %> } end - <%- node.fields.each do |field| -%> + <%- if field.comment.nil? -%> - # <%= "private " if field.is_a?(Prism::Template::FlagsField) %>attr_reader <%= field.name %>: <%= field.rbs_class %> + # <%= "protected " if field.is_a?(Prism::Template::FlagsField) %>attr_reader <%= field.name %>: <%= field.rbs_class %> <%- else -%> <%- field.each_comment_line do |line| -%> #<%= line %> @@ -248,9 +248,8 @@ module Prism end end <%- else -%> - attr_reader :<%= field.name -%><%= "\n private :#{field.name}" if field.is_a?(Prism::Template::FlagsField) %> + attr_reader :<%= field.name -%><%= "\n protected :#{field.name}" if field.is_a?(Prism::Template::FlagsField) %> <%- end -%> - <%- end -%> <%- node.fields.each do |field| -%> <%- case field -%> @@ -349,6 +348,22 @@ module Prism def self.type :<%= node.human %> end + + # Implements case-equality for the node. This is effectively == but without + # comparing the value of locations. Locations are checked only for presence. + def ===(other) + other.is_a?(<%= node.name %>)<%= " &&" if node.fields.any? %> + <%- node.fields.each_with_index do |field, index| -%> + <%- if field.is_a?(Prism::Template::LocationField) || field.is_a?(Prism::Template::OptionalLocationField) -%> + (<%= field.name %>.nil? == other.<%= field.name %>.nil?)<%= " &&" if index != node.fields.length - 1 %> + <%- elsif field.is_a?(Prism::Template::NodeListField) || field.is_a?(Prism::Template::ConstantListField) -%> + (<%= field.name %>.length == other.<%= field.name %>.length) && + <%= field.name %>.zip(other.<%= field.name %>).all? { |left, right| left === right }<%= " &&" if index != node.fields.length - 1 %> + <%- else -%> + (<%= field.name %> === other.<%= field.name %>)<%= " &&" if index != node.fields.length - 1 %> + <%- end -%> + <%- end -%> + end end <%- end -%> <%- flags.each_with_index do |flag, flag_index| -%> diff --git a/test/prism/ruby_api_test.rb b/test/prism/ruby_api_test.rb index 6418887147..bf493666d2 100644 --- a/test/prism/ruby_api_test.rb +++ b/test/prism/ruby_api_test.rb @@ -244,6 +244,21 @@ module Prism assert_equal 16, base[parse_expression("0x1")] end + def test_node_equality + assert_operator parse_expression("1"), :===, parse_expression("1") + assert_operator Prism.parse("1").value, :===, Prism.parse("1").value + + complex_source = "class Something; @var = something.else { _1 }; end" + assert_operator parse_expression(complex_source), :===, parse_expression(complex_source) + + refute_operator parse_expression("1"), :===, parse_expression("2") + refute_operator parse_expression("1"), :===, parse_expression("0x1") + + complex_source_1 = "class Something; @var = something.else { _1 }; end" + complex_source_2 = "class Something; @var = something.else { _2 }; end" + refute_operator parse_expression(complex_source_1), :===, parse_expression(complex_source_2) + end + private def parse_expression(source)