summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/prism.rb2
-rw-r--r--lib/prism/ffi.rb73
-rw-r--r--lib/prism/parse_result.rb24
-rw-r--r--prism/extension.c180
-rw-r--r--prism/templates/lib/prism/serialize.rb.erb445
5 files changed, 378 insertions, 346 deletions
diff --git a/lib/prism.rb b/lib/prism.rb
index 8024577fa3..6cae171f5e 100644
--- a/lib/prism.rb
+++ b/lib/prism.rb
@@ -63,7 +63,7 @@ module Prism
#
# Load the serialized AST using the source as a reference into a tree.
def self.load(source, serialized, freeze = false)
- Serialize.load(source, serialized, freeze)
+ Serialize.load_parse(source, serialized, freeze)
end
end
diff --git a/lib/prism/ffi.rb b/lib/prism/ffi.rb
index de11953a26..eda61b3ead 100644
--- a/lib/prism/ffi.rb
+++ b/lib/prism/ffi.rb
@@ -15,7 +15,8 @@ module Prism
# must align with the build shared library from make/rake.
libprism_in_build = File.expand_path("../../build/libprism.#{RbConfig::CONFIG["SOEXT"]}", __dir__)
libprism_in_libdir = "#{RbConfig::CONFIG["libdir"]}/prism/libprism.#{RbConfig::CONFIG["SOEXT"]}"
- if File.exist? libprism_in_build
+
+ if File.exist?(libprism_in_build)
INCLUDE_DIR = File.expand_path("../../include", __dir__)
ffi_lib libprism_in_build
else
@@ -363,86 +364,28 @@ module Prism
end
def lex_common(string, code, options) # :nodoc:
- serialized =
- LibRubyParser::PrismBuffer.with do |buffer|
- LibRubyParser.pm_serialize_lex(buffer.pointer, string.pointer, string.length, dump_options(options))
- buffer.read
- end
-
- freeze = options.fetch(:freeze, false)
- source = Source.for(code)
- result = Serialize.load_tokens(source, serialized, freeze)
-
- if freeze
- source.source.freeze
- source.offsets.freeze
- source.freeze
+ LibRubyParser::PrismBuffer.with do |buffer|
+ LibRubyParser.pm_serialize_lex(buffer.pointer, string.pointer, string.length, dump_options(options))
+ Serialize.load_lex(code, buffer.read, options.fetch(:freeze, false))
end
-
- result
end
def parse_common(string, code, options) # :nodoc:
serialized = dump_common(string, options)
- Prism.load(code, serialized, options.fetch(:freeze, false))
+ Serialize.load_parse(code, serialized, options.fetch(:freeze, false))
end
def parse_comments_common(string, code, options) # :nodoc:
LibRubyParser::PrismBuffer.with do |buffer|
LibRubyParser.pm_serialize_parse_comments(buffer.pointer, string.pointer, string.length, dump_options(options))
-
- source = Source.for(code)
- loader = Serialize::Loader.new(source, buffer.read)
-
- loader.load_header
- loader.load_encoding
- loader.load_start_line
-
- if (freeze = options.fetch(:freeze, false))
- source.source.freeze
- source.offsets.freeze
- source.freeze
- end
-
- loader.load_comments(freeze)
+ Serialize.load_parse_comments(code, buffer.read, options.fetch(:freeze, false))
end
end
def parse_lex_common(string, code, options) # :nodoc:
LibRubyParser::PrismBuffer.with do |buffer|
LibRubyParser.pm_serialize_parse_lex(buffer.pointer, string.pointer, string.length, dump_options(options))
-
- source = Source.for(code)
- loader = Serialize::Loader.new(source, buffer.read)
- freeze = options.fetch(:freeze, false)
-
- tokens = loader.load_tokens(false)
- node, comments, magic_comments, data_loc, errors, warnings = loader.load_nodes(freeze)
-
- tokens.each do |token,|
- token.value.force_encoding(loader.encoding)
-
- if freeze
- token.value.freeze
- token.location.freeze
- token.freeze
- end
- end
-
- value = [node, tokens]
- result = ParseLexResult.new(value, comments, magic_comments, data_loc, errors, warnings, source)
-
- if freeze
- source.source.freeze
- source.offsets.freeze
- source.freeze
- tokens.each(&:freeze)
- tokens.freeze
- value.freeze
- result.freeze
- end
-
- result
+ Serialize.load_parse_lex(code, buffer.read, options.fetch(:freeze, false))
end
end
diff --git a/lib/prism/parse_result.rb b/lib/prism/parse_result.rb
index 7aee20c9de..e76ea7e17e 100644
--- a/lib/prism/parse_result.rb
+++ b/lib/prism/parse_result.rb
@@ -48,6 +48,16 @@ module Prism
@offsets = offsets # set after parsing is done
end
+ # Overwrite @start_line with the given value; the Serialize.load_* methods call this with the start line read from the serialized data.
+ def replace_start_line(start_line)
+ @start_line = start_line
+ end
+
+ # Replace the contents of offsets by calling #replace on the existing @offsets object, which mutates that object in place instead of reassigning the instance variable.
+ def replace_offsets(offsets)
+ @offsets.replace(offsets)
+ end
+
# Returns the encoding of the source code, which is set by parameters to the
# parser or by the encoding magic comment.
def encoding
@@ -132,6 +142,13 @@ module Prism
code_units_offset(byte_offset, encoding) - code_units_offset(line_start(byte_offset), encoding)
end
+ # Freeze the objects returned by #source and #offsets, then freeze this object itself.
+ def deep_freeze
+ source.freeze
+ offsets.freeze
+ freeze
+ end
+
private
# Binary search through the offsets to find the line number for the given
@@ -854,5 +871,12 @@ module Prism
location
super
end
+
+ # Freeze the objects returned by #value and #location, then freeze this object itself.
+ def deep_freeze
+ value.freeze
+ location.freeze
+ freeze
+ end
end
end
diff --git a/prism/extension.c b/prism/extension.c
index 4503cea6bd..e8f678d341 100644
--- a/prism/extension.c
+++ b/prism/extension.c
@@ -388,28 +388,54 @@ dump_file(int argc, VALUE *argv, VALUE self) {
/******************************************************************************/
/**
- * Extract the comments out of the parser into an array.
+ * The same as rb_class_new_instance, but accepts an additional boolean to
+ * indicate whether or not the resulting class instance should be frozen.
+ *
+ * Only the newly created instance is passed to rb_obj_freeze; the values in
+ * argv are left as they are. The (possibly frozen) instance is returned.
*/
-static VALUE
-parser_comments(pm_parser_t *parser, VALUE source, bool freeze) {
- VALUE comments = rb_ary_new_capa(parser->comment_list.size);
+static inline VALUE
+rb_class_new_instance_freeze(int argc, const VALUE *argv, VALUE klass, bool freeze) {
+ VALUE value = rb_class_new_instance(argc, argv, klass);
+ if (freeze) rb_obj_freeze(value);
+ return value;
+}
- for (pm_comment_t *comment = (pm_comment_t *) parser->comment_list.head; comment != NULL; comment = (pm_comment_t *) comment->node.next) {
- VALUE location_argv[] = {
- source,
- LONG2FIX(comment->location.start - parser->start),
- LONG2FIX(comment->location.end - comment->location.start)
- };
+/**
+ * Create a new Location instance from the given parser and bounds.
+ *
+ * The Location is constructed from three arguments: the source object, the
+ * offset of start relative to parser->start, and the length. The instance is
+ * frozen when freeze is true.
+ */
+static inline VALUE
+parser_location(const pm_parser_t *parser, VALUE source, bool freeze, const uint8_t *start, size_t length) {
+ VALUE argv[] = { source, LONG2FIX(start - parser->start), LONG2FIX(length) };
+ return rb_class_new_instance_freeze(3, argv, rb_cPrismLocation, freeze);
+}
- VALUE location = rb_class_new_instance(3, location_argv, rb_cPrismLocation);
- if (freeze) rb_obj_freeze(location);
+/**
+ * Create a new Location instance from the given parser and location.
+ *
+ * The length passed to parser_location is loc.end - loc.start. Note that loc
+ * is expanded three times, so it must be an expression without side effects.
+ * Every macro argument is parenthesized so that non-trivial expressions
+ * (for example a dereferenced pointer) expand with the intended precedence.
+ */
+#define PARSER_LOCATION_LOC(parser, source, freeze, loc) \
+ parser_location((parser), (source), (freeze), (loc).start, (size_t) ((loc).end - (loc).start))
- VALUE comment_argv[] = { location };
- VALUE type = (comment->type == PM_COMMENT_EMBDOC) ? rb_cPrismEmbDocComment : rb_cPrismInlineComment;
+/**
+ * Build a new Comment instance from the given parser and comment.
+ *
+ * The instance is created from rb_cPrismEmbDocComment when the comment type
+ * is PM_COMMENT_EMBDOC and from rb_cPrismInlineComment otherwise. Its only
+ * constructor argument is the Location of the comment. Both the Location and
+ * the comment instance are frozen when freeze is true.
+ */
+static inline VALUE
+parser_comment(const pm_parser_t *parser, VALUE source, bool freeze, const pm_comment_t *comment) {
+ VALUE argv[] = { PARSER_LOCATION_LOC(parser, source, freeze, comment->location) };
+ VALUE type = (comment->type == PM_COMMENT_EMBDOC) ? rb_cPrismEmbDocComment : rb_cPrismInlineComment;
+ return rb_class_new_instance_freeze(1, argv, type, freeze);
+}
- VALUE value = rb_class_new_instance(1, comment_argv, type);
- if (freeze) rb_obj_freeze(value);
+/**
+ * Extract the comments out of the parser into an array.
+ */
+static VALUE
+parser_comments(const pm_parser_t *parser, VALUE source, bool freeze) {
+ VALUE comments = rb_ary_new_capa(parser->comment_list.size);
+ for (
+ const pm_comment_t *comment = (const pm_comment_t *) parser->comment_list.head;
+ comment != NULL;
+ comment = (const pm_comment_t *) comment->node.next
+ ) {
+ VALUE value = parser_comment(parser, source, freeze, comment);
rb_ary_push(comments, value);
}
@@ -418,35 +444,29 @@ parser_comments(pm_parser_t *parser, VALUE source, bool freeze) {
}
/**
+ * Build a new MagicComment instance from the given parser and magic comment.
+ *
+ * Two Location instances are created, one from the key start/length and one
+ * from the value start/length, and passed (in that order) to the
+ * MagicComment constructor. The locations and the MagicComment instance are
+ * frozen when freeze is true.
+ */
+static inline VALUE
+parser_magic_comment(const pm_parser_t *parser, VALUE source, bool freeze, const pm_magic_comment_t *magic_comment) {
+ VALUE key_loc = parser_location(parser, source, freeze, magic_comment->key_start, magic_comment->key_length);
+ VALUE value_loc = parser_location(parser, source, freeze, magic_comment->value_start, magic_comment->value_length);
+ VALUE argv[] = { key_loc, value_loc };
+ return rb_class_new_instance_freeze(2, argv, rb_cPrismMagicComment, freeze);
+}
+
+/**
* Extract the magic comments out of the parser into an array.
*/
static VALUE
-parser_magic_comments(pm_parser_t *parser, VALUE source, bool freeze) {
+parser_magic_comments(const pm_parser_t *parser, VALUE source, bool freeze) {
VALUE magic_comments = rb_ary_new_capa(parser->magic_comment_list.size);
- for (pm_magic_comment_t *magic_comment = (pm_magic_comment_t *) parser->magic_comment_list.head; magic_comment != NULL; magic_comment = (pm_magic_comment_t *) magic_comment->node.next) {
- VALUE key_loc_argv[] = {
- source,
- LONG2FIX(magic_comment->key_start - parser->start),
- LONG2FIX(magic_comment->key_length)
- };
-
- VALUE key_loc = rb_class_new_instance(3, key_loc_argv, rb_cPrismLocation);
- if (freeze) rb_obj_freeze(key_loc);
-
- VALUE value_loc_argv[] = {
- source,
- LONG2FIX(magic_comment->value_start - parser->start),
- LONG2FIX(magic_comment->value_length)
- };
-
- VALUE value_loc = rb_class_new_instance(3, value_loc_argv, rb_cPrismLocation);
- if (freeze) rb_obj_freeze(value_loc);
-
- VALUE magic_comment_argv[] = { key_loc, value_loc };
- VALUE value = rb_class_new_instance(2, magic_comment_argv, rb_cPrismMagicComment);
- if (freeze) rb_obj_freeze(value);
-
+ for (
+ const pm_magic_comment_t *magic_comment = (const pm_magic_comment_t *) parser->magic_comment_list.head;
+ magic_comment != NULL;
+ magic_comment = (const pm_magic_comment_t *) magic_comment->node.next
+ ) {
+ VALUE value = parser_magic_comment(parser, source, freeze, magic_comment);
rb_ary_push(magic_comments, value);
}
@@ -463,16 +483,7 @@ parser_data_loc(const pm_parser_t *parser, VALUE source, bool freeze) {
if (parser->data_loc.end == NULL) {
return Qnil;
} else {
- VALUE argv[] = {
- source,
- LONG2FIX(parser->data_loc.start - parser->start),
- LONG2FIX(parser->data_loc.end - parser->data_loc.start)
- };
-
- VALUE location = rb_class_new_instance(3, argv, rb_cPrismLocation);
- if (freeze) rb_obj_freeze(location);
-
- return location;
+ return PARSER_LOCATION_LOC(parser, source, freeze, parser->data_loc);
}
}
@@ -480,19 +491,17 @@ parser_data_loc(const pm_parser_t *parser, VALUE source, bool freeze) {
* Extract the errors out of the parser into an array.
*/
static VALUE
-parser_errors(pm_parser_t *parser, rb_encoding *encoding, VALUE source, bool freeze) {
+parser_errors(const pm_parser_t *parser, rb_encoding *encoding, VALUE source, bool freeze) {
VALUE errors = rb_ary_new_capa(parser->error_list.size);
- pm_diagnostic_t *error;
-
- for (error = (pm_diagnostic_t *) parser->error_list.head; error != NULL; error = (pm_diagnostic_t *) error->node.next) {
- VALUE location_argv[] = {
- source,
- LONG2FIX(error->location.start - parser->start),
- LONG2FIX(error->location.end - error->location.start)
- };
- VALUE location = rb_class_new_instance(3, location_argv, rb_cPrismLocation);
- if (freeze) rb_obj_freeze(location);
+ for (
+ const pm_diagnostic_t *error = (const pm_diagnostic_t *) parser->error_list.head;
+ error != NULL;
+ error = (const pm_diagnostic_t *) error->node.next
+ ) {
+ VALUE type = ID2SYM(rb_intern(pm_diagnostic_id_human(error->diag_id)));
+ VALUE message = rb_obj_freeze(rb_enc_str_new_cstr(error->message, encoding));
+ VALUE location = PARSER_LOCATION_LOC(parser, source, freeze, error->location);
VALUE level = Qnil;
switch (error->level) {
@@ -509,15 +518,8 @@ parser_errors(pm_parser_t *parser, rb_encoding *encoding, VALUE source, bool fre
rb_raise(rb_eRuntimeError, "Unknown level: %" PRIu8, error->level);
}
- VALUE message = rb_enc_str_new_cstr(error->message, encoding);
- if (freeze) rb_obj_freeze(message);
-
- VALUE type = ID2SYM(rb_intern(pm_diagnostic_id_human(error->diag_id)));
- VALUE error_argv[] = { type, message, location, level };
-
- VALUE value = rb_class_new_instance(4, error_argv, rb_cPrismParseError);
- if (freeze) rb_obj_freeze(value);
-
+ VALUE argv[] = { type, message, location, level };
+ VALUE value = rb_class_new_instance_freeze(4, argv, rb_cPrismParseError, freeze);
rb_ary_push(errors, value);
}
@@ -529,19 +531,17 @@ parser_errors(pm_parser_t *parser, rb_encoding *encoding, VALUE source, bool fre
* Extract the warnings out of the parser into an array.
*/
static VALUE
-parser_warnings(pm_parser_t *parser, rb_encoding *encoding, VALUE source, bool freeze) {
+parser_warnings(const pm_parser_t *parser, rb_encoding *encoding, VALUE source, bool freeze) {
VALUE warnings = rb_ary_new_capa(parser->warning_list.size);
- pm_diagnostic_t *warning;
- for (warning = (pm_diagnostic_t *) parser->warning_list.head; warning != NULL; warning = (pm_diagnostic_t *) warning->node.next) {
- VALUE location_argv[] = {
- source,
- LONG2FIX(warning->location.start - parser->start),
- LONG2FIX(warning->location.end - warning->location.start)
- };
-
- VALUE location = rb_class_new_instance(3, location_argv, rb_cPrismLocation);
- if (freeze) rb_obj_freeze(location);
+ for (
+ const pm_diagnostic_t *warning = (const pm_diagnostic_t *) parser->warning_list.head;
+ warning != NULL;
+ warning = (const pm_diagnostic_t *) warning->node.next
+ ) {
+ VALUE type = ID2SYM(rb_intern(pm_diagnostic_id_human(warning->diag_id)));
+ VALUE message = rb_obj_freeze(rb_enc_str_new_cstr(warning->message, encoding));
+ VALUE location = PARSER_LOCATION_LOC(parser, source, freeze, warning->location);
VALUE level = Qnil;
switch (warning->level) {
@@ -555,15 +555,8 @@ parser_warnings(pm_parser_t *parser, rb_encoding *encoding, VALUE source, bool f
rb_raise(rb_eRuntimeError, "Unknown level: %" PRIu8, warning->level);
}
- VALUE message = rb_enc_str_new_cstr(warning->message, encoding);
- if (freeze) rb_obj_freeze(message);
-
- VALUE type = ID2SYM(rb_intern(pm_diagnostic_id_human(warning->diag_id)));
- VALUE warning_argv[] = { type, message, location, level };
-
- VALUE value = rb_class_new_instance(4, warning_argv, rb_cPrismParseWarning);
- if (freeze) rb_obj_freeze(value);
-
+ VALUE argv[] = { type, message, location, level };
+ VALUE value = rb_class_new_instance_freeze(4, argv, rb_cPrismParseWarning, freeze);
rb_ary_push(warnings, value);
}
@@ -575,7 +568,7 @@ parser_warnings(pm_parser_t *parser, rb_encoding *encoding, VALUE source, bool f
* Create a new parse result from the given parser, value, encoding, and source.
*/
static VALUE
-parse_result_create(VALUE class, pm_parser_t *parser, VALUE value, rb_encoding *encoding, VALUE source, bool freeze) {
+parse_result_create(VALUE class, const pm_parser_t *parser, VALUE value, rb_encoding *encoding, VALUE source, bool freeze) {
VALUE result_argv[] = {
value,
parser_comments(parser, source, freeze),
@@ -586,10 +579,7 @@ parse_result_create(VALUE class, pm_parser_t *parser, VALUE value, rb_encoding *
source
};
- VALUE result = rb_class_new_instance(7, result_argv, class);
- if (freeze) rb_obj_freeze(result);
-
- return result;
+ return rb_class_new_instance_freeze(7, result_argv, class, freeze);
}
/******************************************************************************/
diff --git a/prism/templates/lib/prism/serialize.rb.erb b/prism/templates/lib/prism/serialize.rb.erb
index e8ac12830a..52821e0f7d 100644
--- a/prism/templates/lib/prism/serialize.rb.erb
+++ b/prism/templates/lib/prism/serialize.rb.erb
@@ -16,15 +16,41 @@ module Prism
# strings.
PATCH_VERSION = 0
- # Deserialize the AST represented by the given string into a parse result.
- def self.load(input, serialized, freeze)
+ # Deserialize the dumped output from a request to parse or parse_file.
+ #
+ # The formatting of the source of this method is purposeful to illustrate
+ # the structure of the serialized data.
+ def self.load_parse(input, serialized, freeze)
input = input.dup
source = Source.for(input)
-
loader = Loader.new(source, serialized)
- result = loader.load_result(freeze)
- input.force_encoding(loader.encoding)
+ loader.load_header
+ encoding = loader.load_encoding
+ start_line = loader.load_varsint
+ offsets = loader.load_line_offsets(freeze)
+
+ source.replace_start_line(start_line)
+ source.replace_offsets(offsets)
+
+ comments = loader.load_comments(freeze)
+ magic_comments = loader.load_magic_comments(freeze)
+ data_loc = loader.load_optional_location_object(freeze)
+ errors = loader.load_errors(encoding, freeze)
+ warnings = loader.load_warnings(encoding, freeze)
+ cpool_base = loader.load_uint32
+ cpool_size = loader.load_varuint
+
+ constant_pool = ConstantPool.new(input, serialized, cpool_base, cpool_size)
+
+ node = loader.load_node(constant_pool, encoding, freeze)
+ loader.load_constant_pool(constant_pool)
+ raise unless loader.eof?
+
+ result = ParseResult.new(node, comments, magic_comments, data_loc, errors, warnings, source)
+ result.freeze if freeze
+
+ input.force_encoding(encoding)
# This is an extremely niche use-case where the file was marked as binary
# but it contained UTF-8-encoded characters. In that case we will actually
@@ -37,94 +63,231 @@ module Prism
if freeze
input.freeze
- source.source.freeze
- source.offsets.freeze
- source.freeze
+ source.deep_freeze
end
result
end
- # Deserialize the tokens represented by the given string into a parse
- # result.
- def self.load_tokens(source, serialized, freeze)
- Loader.new(source, serialized).load_tokens_result(freeze)
+ # Deserialize the dumped output from a request to lex or lex_file.
+ #
+ # The formatting of the source of this method is purposeful to illustrate
+ # the structure of the serialized data.
+ def self.load_lex(input, serialized, freeze)
+ source = Source.for(input)
+ loader = Loader.new(source, serialized)
+
+ tokens = loader.load_tokens # the token stream comes first in the serialized data; no header is read on this path
+ encoding = loader.load_encoding
+ start_line = loader.load_varsint
+ offsets = loader.load_line_offsets(freeze)
+
+ source.replace_start_line(start_line)
+ source.replace_offsets(offsets)
+
+ comments = loader.load_comments(freeze)
+ magic_comments = loader.load_magic_comments(freeze)
+ data_loc = loader.load_optional_location_object(freeze)
+ errors = loader.load_errors(encoding, freeze)
+ warnings = loader.load_warnings(encoding, freeze)
+ raise unless loader.eof? # NOTE(review): a bare raise with no active exception produces a RuntimeError without a descriptive message; consider stating that unconsumed bytes remain
+
+ result = LexResult.new(tokens, comments, magic_comments, data_loc, errors, warnings, source)
+
+ tokens.each do |token| # each entry is a [token, lex_state] pair, so token[0] is the Token itself
+ token[0].value.force_encoding(encoding) # token values were sliced before the encoding was read, so they are re-tagged here
+
+ if freeze
+ token[0].deep_freeze
+ token.freeze
+ end
+ end
+
+ if freeze
+ source.deep_freeze
+ tokens.freeze
+ result.freeze
+ end
+
+ result
+ end
- class Loader # :nodoc:
- if RUBY_ENGINE == "truffleruby"
- # StringIO is synchronized and that adds a high overhead on TruffleRuby.
- class FastStringIO # :nodoc:
- attr_accessor :pos
-
- def initialize(string)
- @string = string
- @pos = 0
- end
+ # Deserialize the dumped output from a request to parse_comments or
+ # parse_file_comments.
+ #
+ # The formatting of the source of this method is purposeful to illustrate
+ # the structure of the serialized data.
+ def self.load_parse_comments(input, serialized, freeze)
+ source = Source.for(input)
+ loader = Loader.new(source, serialized)
- def getbyte
- byte = @string.getbyte(@pos)
- @pos += 1
- byte
- end
+ loader.load_header
+ loader.load_encoding # return value is not used here; the call still advances the loader past the encoding name
+ start_line = loader.load_varsint
- def read(n)
- slice = @string.byteslice(@pos, n)
- @pos += n
- slice
- end
+ source.replace_start_line(start_line)
+
+ result = loader.load_comments(freeze) # the comments array is the return value of this method (not a ParseResult)
+ raise unless loader.eof? # NOTE(review): a bare raise with no active exception produces a RuntimeError without a descriptive message; consider stating that unconsumed bytes remain
+
+ source.deep_freeze if freeze
+ result
+ end
+
+ # Deserialize the dumped output from a request to parse_lex or
+ # parse_lex_file.
+ #
+ # The formatting of the source of this method is purposeful to illustrate
+ # the structure of the serialized data.
+ def self.load_parse_lex(input, serialized, freeze)
+ source = Source.for(input)
+ loader = Loader.new(source, serialized)
+
+ tokens = loader.load_tokens # the token stream is read first, before the header
+ loader.load_header
+ encoding = loader.load_encoding
+ start_line = loader.load_varsint
+ offsets = loader.load_line_offsets(freeze)
+
+ source.replace_start_line(start_line)
+ source.replace_offsets(offsets)
+
+ comments = loader.load_comments(freeze)
+ magic_comments = loader.load_magic_comments(freeze)
+ data_loc = loader.load_optional_location_object(freeze)
+ errors = loader.load_errors(encoding, freeze)
+ warnings = loader.load_warnings(encoding, freeze)
+ cpool_base = loader.load_uint32
+ cpool_size = loader.load_varuint
+
+ constant_pool = ConstantPool.new(input, serialized, cpool_base, cpool_size)
+
+ node = loader.load_node(constant_pool, encoding, freeze)
+ loader.load_constant_pool(constant_pool) # advances the loader past the constant pool bytes that follow the nodes
+ raise unless loader.eof? # NOTE(review): a bare raise with no active exception produces a RuntimeError without a descriptive message; consider stating that unconsumed bytes remain
+
+ value = [node, tokens] # the result value pairs the root node with the token list
+ result = ParseLexResult.new(value, comments, magic_comments, data_loc, errors, warnings, source)
+
+ tokens.each do |token| # each entry is a [token, lex_state] pair, so token[0] is the Token itself
+ token[0].value.force_encoding(encoding)
+
+ if freeze
+ token[0].deep_freeze
+ token.freeze
+ end
+ end
- def eof?
- @pos >= @string.bytesize
+ if freeze
+ source.deep_freeze
+ tokens.freeze
+ value.freeze
+ result.freeze
+ end
+
+ result
+ end
+
+ class ConstantPool # :nodoc:
+ attr_reader :size
+
+ def initialize(input, serialized, base, size)
+ @input = input
+ @serialized = serialized
+ @base = base
+ @size = size
+ @pool = Array.new(size, nil) # one slot per constant; slots are filled lazily by #get
+ end
+
+ def get(index, encoding) # returns the Symbol for the given pool index, decoding and caching it on first access
+ @pool[index] ||=
+ begin
+ offset = @base + index * 8 # entries are 8 bytes each: two 32-bit unsigned integers (start, then length) read in native byte order
+ start = @serialized.unpack1("L", offset: offset)
+ length = @serialized.unpack1("L", offset: offset + 4)
+
+ if start.nobits?(1 << 31) # bit 31 clear: the bytes are sliced out of @input
+ @input.byteslice(start, length).force_encoding(encoding).to_sym
+ else # bit 31 set: the bytes are sliced out of @serialized, with bit 31 masked off the start offset
+ @serialized.byteslice(start & ((1 << 31) - 1), length).force_encoding(encoding).to_sym
+ end
end
+ end
+ end
+
+ if RUBY_ENGINE == "truffleruby"
+ # StringIO is synchronized and that adds a high overhead on TruffleRuby.
+ class FastStringIO # :nodoc:
+ attr_accessor :pos
+
+ def initialize(string)
+ @string = string
+ @pos = 0
+ end
+
+ def getbyte
+ byte = @string.getbyte(@pos)
+ @pos += 1
+ byte
+ end
+
+ def read(n)
+ slice = @string.byteslice(@pos, n)
+ @pos += n
+ slice
+ end
+
+ def eof?
+ @pos >= @string.bytesize
end
- else
- FastStringIO = ::StringIO
end
- private_constant :FastStringIO
+ else
+ FastStringIO = ::StringIO # :nodoc:
+ end
- attr_reader :encoding, :input, :serialized, :io
- attr_reader :constant_pool_offset, :constant_pool, :source
- attr_reader :start_line
+ class Loader # :nodoc:
+ attr_reader :input, :io, :source
def initialize(source, serialized)
- @encoding = Encoding::UTF_8
-
@input = source.source.dup
raise unless serialized.encoding == Encoding::BINARY
- @serialized = serialized
@io = FastStringIO.new(serialized)
+ @source = source
+ define_load_node_lambdas if RUBY_ENGINE != "ruby"
+ end
- @constant_pool_offset = nil
- @constant_pool = nil
+ def eof? # reads one byte from io, then reports whether io is at end of input
+ io.getbyte # NOTE(review): this consumes a byte, so calling the predicate advances io; presumably it skips a single trailing byte written by the serializer — confirm
+ io.eof?
+ end
- @source = source
- define_load_node_lambdas unless RUBY_ENGINE == "ruby"
+ def load_constant_pool(constant_pool) # advances io past the constant pool table and its trailing bytes; no constants are decoded here
+ trailer = 0
+
+ constant_pool.size.times do |index| # index is not used; the loop only steps through one 8-byte entry per constant
+ start, length = io.read(8).unpack("L2") # each entry holds two 32-bit unsigned integers: start, then length
+ trailer += length if start.anybits?(1 << 31) # only entries with bit 31 set in start add their length to the trailer
+ end
+
+ io.read(trailer) # skip the accumulated trailer bytes that follow the table
+ end
def load_header
raise "Invalid serialization" if io.read(5) != "PRISM"
raise "Invalid serialization" if io.read(3).unpack("C3") != [MAJOR_VERSION, MINOR_VERSION, PATCH_VERSION]
- only_semantic_fields = io.getbyte
- unless only_semantic_fields == 0
- raise "Invalid serialization (location fields must be included but are not)"
- end
+ raise "Invalid serialization (location fields must be included but are not)" if io.getbyte != 0
end
def load_encoding
- @encoding = Encoding.find(io.read(load_varuint))
- @input = input.force_encoding(@encoding).freeze
- @encoding
- end
-
- def load_start_line
- source.instance_variable_set(:@start_line, load_varsint)
+ encoding = Encoding.find(io.read(load_varuint))
+ @input = input.force_encoding(encoding).freeze
+ encoding
end
def load_line_offsets(freeze)
offsets = Array.new(load_varuint) { load_varuint }
offsets.freeze if freeze
- source.instance_variable_set(:@offsets, offsets)
+ offsets
end
def load_comments(freeze)
@@ -187,13 +350,13 @@ module Prism
end
end
- def load_errors(freeze)
+ def load_errors(encoding, freeze)
errors =
Array.new(load_varuint) do
error =
ParseError.new(
DIAGNOSTIC_TYPES.fetch(load_varuint),
- load_embedded_string,
+ load_embedded_string(encoding),
load_location_object(freeze),
load_error_level
)
@@ -219,13 +382,13 @@ module Prism
end
end
- def load_warnings(freeze)
+ def load_warnings(encoding, freeze)
warnings =
Array.new(load_varuint) do
warning =
ParseWarning.new(
DIAGNOSTIC_TYPES.fetch(load_varuint),
- load_embedded_string,
+ load_embedded_string(encoding),
load_location_object(freeze),
load_warning_level
)
@@ -238,17 +401,7 @@ module Prism
warnings
end
- def load_metadata(freeze)
- [
- load_comments(freeze),
- load_magic_comments(freeze),
- load_optional_location_object(freeze),
- load_errors(freeze),
- load_warnings(freeze)
- ]
- end
-
- def load_tokens(freeze)
+ def load_tokens
tokens = []
while (type = TOKEN_TYPES.fetch(load_varuint))
@@ -257,74 +410,14 @@ module Prism
lex_state = load_varuint
location = Location.new(@source, start, length)
- location.freeze if freeze
-
- slice = location.slice
- slice.freeze if freeze
-
- token = Token.new(@source, type, slice, location)
- token.freeze if freeze
+ token = Token.new(@source, type, location.slice, location)
tokens << [token, lex_state]
end
- tokens.freeze if freeze
tokens
end
- def load_tokens_result(freeze)
- tokens = load_tokens(false)
- encoding = load_encoding
- load_start_line
- load_line_offsets(freeze)
- comments, magic_comments, data_loc, errors, warnings = load_metadata(freeze)
-
- tokens.each do |token,|
- token.value.force_encoding(encoding)
-
- if freeze
- token.value.freeze
- token.location.freeze
- token.freeze
- end
- end
-
- raise "Expected to consume all bytes while deserializing" unless @io.eof?
- result = LexResult.new(tokens, comments, magic_comments, data_loc, errors, warnings, @source)
-
- if freeze
- tokens.each(&:freeze)
- tokens.freeze
- result.freeze
- end
-
- result
- end
-
- def load_nodes(freeze)
- load_header
- load_encoding
- load_start_line
- load_line_offsets(freeze)
-
- comments, magic_comments, data_loc, errors, warnings = load_metadata(freeze)
-
- @constant_pool_offset = load_uint32
- @constant_pool = Array.new(load_varuint, nil)
-
- [load_node(freeze), comments, magic_comments, data_loc, errors, warnings]
- end
-
- def load_result(freeze)
- node, comments, magic_comments, data_loc, errors, warnings = load_nodes(freeze)
- result = ParseResult.new(node, comments, magic_comments, data_loc, errors, warnings, @source)
-
- result.freeze if freeze
- result
- end
-
- private
-
# variable-length integer using https://en.wikipedia.org/wiki/LEB128
# This is also what protobuf uses: https://protobuf.dev/programming-guides/encoding/#varints
def load_varuint
@@ -365,23 +458,23 @@ module Prism
io.read(4).unpack1("L")
end
- def load_optional_node(freeze)
+ def load_optional_node(constant_pool, encoding, freeze)
if io.getbyte != 0
io.pos -= 1
- load_node(freeze)
+ load_node(constant_pool, encoding, freeze)
end
end
- def load_embedded_string
+ def load_embedded_string(encoding)
io.read(load_varuint).force_encoding(encoding).freeze
end
- def load_string
+ def load_string(encoding)
case (type = io.getbyte)
when 1
input.byteslice(load_varuint, load_varuint).force_encoding(encoding).freeze
when 2
- load_embedded_string
+ load_embedded_string(encoding)
else
raise "Unknown serialized string type: #{type}"
end
@@ -406,38 +499,18 @@ module Prism
load_location_object(freeze) if io.getbyte != 0
end
- def load_constant(index)
- constant = constant_pool[index]
-
- unless constant
- offset = constant_pool_offset + index * 8
- start = @serialized.unpack1("L", offset: offset)
- length = @serialized.unpack1("L", offset: offset + 4)
-
- constant =
- if start.nobits?(1 << 31)
- input.byteslice(start, length).force_encoding(@encoding).to_sym
- else
- @serialized.byteslice(start & ((1 << 31) - 1), length).force_encoding(@encoding).to_sym
- end
-
- constant_pool[index] = constant
- end
-
- constant
- end
-
- def load_required_constant
- load_constant(load_varuint - 1)
+ def load_constant(constant_pool, encoding)
+ index = load_varuint
+ constant_pool.get(index - 1, encoding)
end
- def load_optional_constant
+ def load_optional_constant(constant_pool, encoding)
index = load_varuint
- load_constant(index - 1) if index != 0
+ constant_pool.get(index - 1, encoding) if index != 0
end
if RUBY_ENGINE == "ruby"
- def load_node(freeze)
+ def load_node(constant_pool, encoding, freeze)
type = io.getbyte
node_id = load_varuint
location = load_location(freeze)
@@ -449,13 +522,13 @@ module Prism
<%- end -%>
<%= node.name %>.new(<%= ["source", "node_id", "location", "load_varuint", *node.fields.map { |field|
case field
- when Prism::Template::NodeField then "load_node(freeze)"
- when Prism::Template::OptionalNodeField then "load_optional_node(freeze)"
- when Prism::Template::StringField then "load_string"
- when Prism::Template::NodeListField then "Array.new(load_varuint) { load_node(freeze) }.tap { |nodes| nodes.freeze if freeze }"
- when Prism::Template::ConstantField then "load_required_constant"
- when Prism::Template::OptionalConstantField then "load_optional_constant"
- when Prism::Template::ConstantListField then "Array.new(load_varuint) { load_required_constant }.tap { |constants| constants.freeze if freeze }"
+ when Prism::Template::NodeField then "load_node(constant_pool, encoding, freeze)"
+ when Prism::Template::OptionalNodeField then "load_optional_node(constant_pool, encoding, freeze)"
+ when Prism::Template::StringField then "load_string(encoding)"
+ when Prism::Template::NodeListField then "Array.new(load_varuint) { load_node(constant_pool, encoding, freeze) }.tap { |nodes| nodes.freeze if freeze }"
+ when Prism::Template::ConstantField then "load_constant(constant_pool, encoding)"
+ when Prism::Template::OptionalConstantField then "load_optional_constant(constant_pool, encoding)"
+ when Prism::Template::ConstantListField then "Array.new(load_varuint) { load_constant(constant_pool, encoding) }.tap { |constants| constants.freeze if freeze }"
when Prism::Template::LocationField then "load_location(freeze)"
when Prism::Template::OptionalLocationField then "load_optional_location(freeze)"
when Prism::Template::UInt8Field then "io.getbyte"
@@ -472,16 +545,15 @@ module Prism
value
end
else
- def load_node(freeze)
- type = io.getbyte
- @load_node_lambdas[type].call(freeze)
+ def load_node(constant_pool, encoding, freeze)
+ @load_node_lambdas[io.getbyte].call(constant_pool, encoding, freeze)
end
def define_load_node_lambdas
@load_node_lambdas = [
nil,
<%- nodes.each do |node| -%>
- -> (freeze) {
+ -> (constant_pool, encoding, freeze) {
node_id = load_varuint
location = load_location(freeze)
<%- if node.needs_serialized_length? -%>
@@ -489,13 +561,13 @@ module Prism
<%- end -%>
value = <%= node.name %>.new(<%= ["source", "node_id", "location", "load_varuint", *node.fields.map { |field|
case field
- when Prism::Template::NodeField then "load_node(freeze)"
- when Prism::Template::OptionalNodeField then "load_optional_node(freeze)"
- when Prism::Template::StringField then "load_string"
- when Prism::Template::NodeListField then "Array.new(load_varuint) { load_node(freeze) }"
- when Prism::Template::ConstantField then "load_required_constant"
- when Prism::Template::OptionalConstantField then "load_optional_constant"
- when Prism::Template::ConstantListField then "Array.new(load_varuint) { load_required_constant }"
+ when Prism::Template::NodeField then "load_node(constant_pool, encoding, freeze)"
+ when Prism::Template::OptionalNodeField then "load_optional_node(constant_pool, encoding, freeze)"
+ when Prism::Template::StringField then "load_string(encoding)"
+ when Prism::Template::NodeListField then "Array.new(load_varuint) { load_node(constant_pool, encoding, freeze) }"
+ when Prism::Template::ConstantField then "load_constant(constant_pool, encoding)"
+ when Prism::Template::OptionalConstantField then "load_optional_constant(constant_pool, encoding)"
+ when Prism::Template::ConstantListField then "Array.new(load_varuint) { load_constant(constant_pool, encoding) }"
when Prism::Template::LocationField then "load_location(freeze)"
when Prism::Template::OptionalLocationField then "load_optional_location(freeze)"
when Prism::Template::UInt8Field then "io.getbyte"
@@ -522,6 +594,9 @@ module Prism
<%- end -%>
]
- private_constant :TOKEN_TYPES
+ private_constant :MAJOR_VERSION, :MINOR_VERSION, :PATCH_VERSION
+ private_constant :ConstantPool, :FastStringIO, :Loader, :TOKEN_TYPES
end
+
+ private_constant :Serialize
end