ruby/yarp/templates/java/org/yarp/Loader.java.erb
2023-08-25 18:20:51 -04:00

293 lines
9.3 KiB
Text

package org.yarp;
import org.yarp.ParseResult;
import java.lang.Short;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
// GENERATED BY <%= File.basename(__FILE__) %>
// @formatter:off
public class Loader {

    /**
     * Deserializes a serialized YARP parse result.
     *
     * @param serialized the serialized bytes; fixed-width ints in it are read in native byte order
     * @param source the parsed source, whose bytes back source-slice strings and constants
     * @return the parse result holding the root node, comments, errors and warnings
     */
    public static ParseResult load(byte[] serialized, Nodes.Source source) {
        return new Loader(serialized, source).load();
    }

    /**
     * Lazily materialized view of the serialized constant pool.
     *
     * <p>Entry {@code i} (0-based) occupies 8 bytes at {@code bufferOffset + i * 8}: a 4-byte start
     * offset followed by a 4-byte length, describing a slice of the source bytes. A slice is copied
     * out on first access and cached for later lookups.
     */
    private static final class ConstantPool {
        private final byte[] source;
        private final int bufferOffset;
        private final byte[][] cache;

        ConstantPool(byte[] source, int bufferOffset, int length) {
            this.source = source;
            this.bufferOffset = bufferOffset;
            cache = new byte[length][];
        }

        /**
         * Returns the bytes of the constant with the given 1-based index.
         *
         * @param buffer the serialized buffer containing the pool entries
         * @param oneBasedIndex the constant's index as serialized (starting at 1)
         * @return the constant's bytes (shared cached array)
         */
        byte[] get(ByteBuffer buffer, int oneBasedIndex) {
            int index = oneBasedIndex - 1;
            byte[] constant = cache[index];
            if (constant == null) {
                // Absolute reads: these do not move the buffer's position, so they are safe
                // to perform while the node tree is being read sequentially.
                int offset = bufferOffset + index * 8;
                int start = buffer.getInt(offset);
                int length = buffer.getInt(offset + 4);
                constant = new byte[length];
                System.arraycopy(source, start, constant, 0, length);
                cache[index] = constant;
            }
            return constant;
        }
    }

    /** Serialized input; fixed-width ints are decoded in native byte order. */
    private final ByteBuffer buffer;
    /** Assigned in {@link #load()} once the pool's offset and length have been read. */
    private ConstantPool constantPool;
    private final Nodes.Source source;

    private Loader(byte[] serialized, Nodes.Source source) {
        this.buffer = ByteBuffer.wrap(serialized).order(ByteOrder.nativeOrder());
        this.source = source;
    }

    /**
     * Reads the whole serialized result from the start of the buffer.
     *
     * @return the parse result
     * @throws Error if the header does not match or the node tree does not end exactly where the
     *     constant pool begins
     */
    private ParseResult load() {
        // Header: the magic bytes "YARP" followed by three format version bytes.
        expect((byte) 'Y');
        expect((byte) 'A');
        expect((byte) 'R');
        expect((byte) 'P');
        expect((byte) 0);
        expect((byte) 8);
        expect((byte) 0);

        // This loads the name of the encoding. We don't actually do anything
        // with it just yet.
        int encodingLength = loadVarInt();
        byte[] encodingName = new byte[encodingLength];
        buffer.get(encodingName);

        ParseResult.Comment[] comments = loadComments();
        ParseResult.Error[] errors = loadSyntaxErrors();
        ParseResult.Warning[] warnings = loadWarnings();

        // The constant pool is located by absolute offset; its entries are decoded lazily
        // by ConstantPool.get while the node tree is read.
        int constantPoolBufferOffset = buffer.getInt();
        int constantPoolLength = loadVarInt();
        this.constantPool = new ConstantPool(source.bytes, constantPoolBufferOffset, constantPoolLength);

        Nodes.Node node = loadNode();

        // The serialized node tree must end exactly where the constant pool starts.
        int left = constantPoolBufferOffset - buffer.position();
        if (left != 0) {
            throw new Error("Expected to consume all bytes while deserializing but there were " + left + " bytes left");
        }

        // NOTE(review): presumably marks the nodes that begin a new line (one slot per line,
        // 1-based) — confirm against MarkNewlinesVisitor.
        boolean[] newlineMarked = new boolean[1 + source.getLineCount()];
        MarkNewlinesVisitor visitor = new MarkNewlinesVisitor(source, newlineMarked);
        node.accept(visitor);

        return new ParseResult(node, comments, errors, warnings);
    }

    /** Reads a varint length followed by that many bytes copied from the serialized buffer. */
    private byte[] loadEmbeddedString() {
        int length = loadVarInt();
        byte[] string = new byte[length];
        buffer.get(string);
        return string;
    }

    /**
     * Reads a string field. A leading tag byte selects the representation: {@code 1} means the
     * string is a slice of the source (start offset and length varints follow), {@code 2} means its
     * bytes are embedded in the serialized buffer.
     *
     * @throws Error if the tag byte is neither 1 nor 2
     */
    private byte[] loadString() {
        byte kind = buffer.get();
        switch (kind) {
            case 1:
                int start = loadVarInt();
                int length = loadVarInt();
                byte[] string = new byte[length];
                System.arraycopy(source.bytes, start, string, 0, length);
                return string;
            case 2:
                return loadEmbeddedString();
            default:
                // Report the tag actually switched on; reading the buffer again here would
                // consume and report the following byte instead.
                throw new Error("Expected 1 or 2 but was " + kind);
        }
    }

    /** Reads a varint count followed by that many (type byte, location) comment records. */
    private ParseResult.Comment[] loadComments() {
        int count = loadVarInt();
        ParseResult.Comment[] comments = new ParseResult.Comment[count];
        for (int i = 0; i < count; i++) {
            // Read the type byte unsigned, as loadNode does for node types.
            ParseResult.CommentType type = ParseResult.CommentType.VALUES[buffer.get() & 0xFF];
            Nodes.Location location = loadLocation();
            ParseResult.Comment comment = new ParseResult.Comment(type, location);
            comments[i] = comment;
        }
        return comments;
    }

    /** Reads a varint count followed by that many (embedded message, location) error records. */
    private ParseResult.Error[] loadSyntaxErrors() {
        int count = loadVarInt();
        ParseResult.Error[] errors = new ParseResult.Error[count];
        // error messages only contain ASCII characters
        for (int i = 0; i < count; i++) {
            byte[] bytes = loadEmbeddedString();
            String message = new String(bytes, StandardCharsets.US_ASCII);
            Nodes.Location location = loadLocation();
            ParseResult.Error error = new ParseResult.Error(message, location);
            errors[i] = error;
        }
        return errors;
    }

    /** Reads a varint count followed by that many (embedded message, location) warning records. */
    private ParseResult.Warning[] loadWarnings() {
        int count = loadVarInt();
        ParseResult.Warning[] warnings = new ParseResult.Warning[count];
        // warning messages only contain ASCII characters
        for (int i = 0; i < count; i++) {
            byte[] bytes = loadEmbeddedString();
            String message = new String(bytes, StandardCharsets.US_ASCII);
            Nodes.Location location = loadLocation();
            ParseResult.Warning warning = new ParseResult.Warning(message, location);
            warnings[i] = warning;
        }
        return warnings;
    }

    /**
     * Reads an optional node. A zero byte (never a valid node type, since types start at 1)
     * marks absence and is consumed; otherwise the byte is left in place for {@link #loadNode()}.
     *
     * @return the node, or {@code null} if absent
     */
    private Nodes.Node loadOptionalNode() {
        if (buffer.get(buffer.position()) != 0) {
            return loadNode();
        } else {
            buffer.position(buffer.position() + 1); // continue after the 0 byte
            return null;
        }
    }

    /** Reads a varint count followed by that many locations; returns a shared empty array for 0. */
    private Nodes.Location[] loadLocations() {
        int length = loadVarInt();
        if (length == 0) {
            return Nodes.Location.EMPTY_ARRAY;
        }
        Nodes.Location[] locations = new Nodes.Location[length];
        for (int i = 0; i < length; i++) {
            locations[i] = loadLocation();
        }
        return locations;
    }

    /** Reads a 1-based constant pool index and returns that constant's bytes. */
    private byte[] loadConstant() {
        return constantPool.get(buffer, loadVarInt());
    }

    /** Reads a varint count followed by that many constant references; shared empty array for 0. */
    private byte[][] loadConstants() {
        int length = loadVarInt();
        if (length == 0) {
            return Nodes.EMPTY_BYTE_ARRAY_ARRAY;
        }
        byte[][] constants = new byte[length][];
        for (int i = 0; i < length; i++) {
            constants[i] = loadConstant();
        }
        return constants;
    }

    /** Reads a varint count followed by that many nodes; returns a shared empty array for 0. */
    private Nodes.Node[] loadNodes() {
        int length = loadVarInt();
        if (length == 0) {
            return Nodes.Node.EMPTY_ARRAY;
        }
        Nodes.Node[] nodes = new Nodes.Node[length];
        for (int i = 0; i < length; i++) {
            nodes[i] = loadNode();
        }
        return nodes;
    }

    /** Reads a location encoded as two consecutive varints. */
    private Nodes.Location loadLocation() {
        // Java evaluates arguments left to right, so the varints are consumed in serialized order.
        return new Nodes.Location(loadVarInt(), loadVarInt());
    }

    /** Reads a presence byte; if non-zero a location follows, otherwise returns {@code null}. */
    private Nodes.Location loadOptionalLocation() {
        if (buffer.get() != 0) {
            return loadLocation();
        } else {
            return null;
        }
    }

    /**
     * Decodes a varint of up to 5 bytes: 7 payload bits per byte, least-significant group first,
     * with the high bit set on every byte except the last.
     */
    // From https://github.com/protocolbuffers/protobuf/blob/v23.1/java/core/src/main/java/com/google/protobuf/BinaryReader.java#L1507
    private int loadVarInt() {
        int x;
        if ((x = buffer.get()) >= 0) {
            return x;
        } else if ((x ^= (buffer.get() << 7)) < 0) {
            x ^= (~0 << 7);
        } else if ((x ^= (buffer.get() << 14)) >= 0) {
            x ^= (~0 << 7) ^ (~0 << 14);
        } else if ((x ^= (buffer.get() << 21)) < 0) {
            x ^= (~0 << 7) ^ (~0 << 14) ^ (~0 << 21);
        } else {
            x ^= buffer.get() << 28;
            x ^= (~0 << 7) ^ (~0 << 14) ^ (~0 << 21) ^ (~0 << 28);
        }
        return x;
    }

    /** Reads node flags as a varint; asserts (with -ea) that they fit in a non-negative short. */
    private short loadFlags() {
        int flags = loadVarInt();
        assert flags >= 0 && flags <= Short.MAX_VALUE;
        return (short) flags;
    }

    /**
     * Reads one node: an unsigned type byte, start offset and length varints, then the node's
     * type-specific fields.
     *
     * @throws Error if the type byte does not match any known node type
     */
    private Nodes.Node loadNode() {
        int type = buffer.get() & 0xFF;
        int startOffset = loadVarInt();
        int length = loadVarInt();
        // Node types are numbered from 1 (0 marks an absent optional node, see loadOptionalNode).
        // Constructor arguments are evaluated left to right, so each generated call below reads
        // the node's fields in serialized order.
        switch (type) {
            <%- nodes.each_with_index do |node, index| -%>
            case <%= index + 1 %>:
                <%-
                params = node.needs_serialized_length? ? ["buffer.getInt()"] : []
                params.concat node.params.map { |param|
                  case param
                  when NodeParam then "#{param.java_cast}loadNode()"
                  when OptionalNodeParam then "#{param.java_cast}loadOptionalNode()"
                  when StringParam then "loadString()"
                  when NodeListParam then "loadNodes()"
                  when LocationListParam then "loadLocations()"
                  when ConstantParam then "loadConstant()"
                  when ConstantListParam then "loadConstants()"
                  when LocationParam then "loadLocation()"
                  when OptionalLocationParam then "loadOptionalLocation()"
                  when UInt32Param then "loadVarInt()"
                  when FlagsParam then "loadFlags()"
                  else raise
                  end
                }
                params.concat ["startOffset", "length"]
                -%>
                return new Nodes.<%= node.name %>(<%= params.join(", ") -%>);
            <%- end -%>
            default:
                throw new Error("Unknown node type: " + type);
        }
    }

    /**
     * Consumes one byte and checks it equals {@code value}.
     *
     * @throws Error reporting the index of the mismatching byte
     */
    private void expect(byte value) {
        byte b = buffer.get();
        if (b != value) {
            // position() already points past the byte just read.
            throw new Error("Expected " + value + " but was " + b + " at position " + (buffer.position() - 1));
        }
    }
}
// @formatter:on