diff options
-rw-r--r-- | lib/prism.rb | 2 | ||||
-rw-r--r-- | lib/prism/ffi.rb | 73 | ||||
-rw-r--r-- | lib/prism/parse_result.rb | 24 | ||||
-rw-r--r-- | prism/extension.c | 180 | ||||
-rw-r--r-- | prism/templates/lib/prism/serialize.rb.erb | 445 |
5 files changed, 378 insertions, 346 deletions
diff --git a/lib/prism.rb b/lib/prism.rb index 8024577fa3..6cae171f5e 100644 --- a/lib/prism.rb +++ b/lib/prism.rb @@ -63,7 +63,7 @@ module Prism # # Load the serialized AST using the source as a reference into a tree. def self.load(source, serialized, freeze = false) - Serialize.load(source, serialized, freeze) + Serialize.load_parse(source, serialized, freeze) end end diff --git a/lib/prism/ffi.rb b/lib/prism/ffi.rb index de11953a26..eda61b3ead 100644 --- a/lib/prism/ffi.rb +++ b/lib/prism/ffi.rb @@ -15,7 +15,8 @@ module Prism # must align with the build shared library from make/rake. libprism_in_build = File.expand_path("../../build/libprism.#{RbConfig::CONFIG["SOEXT"]}", __dir__) libprism_in_libdir = "#{RbConfig::CONFIG["libdir"]}/prism/libprism.#{RbConfig::CONFIG["SOEXT"]}" - if File.exist? libprism_in_build + + if File.exist?(libprism_in_build) INCLUDE_DIR = File.expand_path("../../include", __dir__) ffi_lib libprism_in_build else @@ -363,86 +364,28 @@ module Prism end def lex_common(string, code, options) # :nodoc: - serialized = - LibRubyParser::PrismBuffer.with do |buffer| - LibRubyParser.pm_serialize_lex(buffer.pointer, string.pointer, string.length, dump_options(options)) - buffer.read - end - - freeze = options.fetch(:freeze, false) - source = Source.for(code) - result = Serialize.load_tokens(source, serialized, freeze) - - if freeze - source.source.freeze - source.offsets.freeze - source.freeze + LibRubyParser::PrismBuffer.with do |buffer| + LibRubyParser.pm_serialize_lex(buffer.pointer, string.pointer, string.length, dump_options(options)) + Serialize.load_lex(code, buffer.read, options.fetch(:freeze, false)) end - - result end def parse_common(string, code, options) # :nodoc: serialized = dump_common(string, options) - Prism.load(code, serialized, options.fetch(:freeze, false)) + Serialize.load_parse(code, serialized, options.fetch(:freeze, false)) end def parse_comments_common(string, code, options) # :nodoc: LibRubyParser::PrismBuffer.with do |buffer| LibRubyParser.pm_serialize_parse_comments(buffer.pointer, string.pointer, string.length, dump_options(options)) - - source = Source.for(code) - loader = Serialize::Loader.new(source, buffer.read) - - loader.load_header - loader.load_encoding - loader.load_start_line - - if (freeze = options.fetch(:freeze, false)) - source.source.freeze - source.offsets.freeze - source.freeze - end - - loader.load_comments(freeze) + Serialize.load_parse_comments(code, buffer.read, options.fetch(:freeze, false)) end end def parse_lex_common(string, code, options) # :nodoc: LibRubyParser::PrismBuffer.with do |buffer| LibRubyParser.pm_serialize_parse_lex(buffer.pointer, string.pointer, string.length, dump_options(options)) - - source = Source.for(code) - loader = Serialize::Loader.new(source, buffer.read) - freeze = options.fetch(:freeze, false) - - tokens = loader.load_tokens(false) - node, comments, magic_comments, data_loc, errors, warnings = loader.load_nodes(freeze) - - tokens.each do |token,| - token.value.force_encoding(loader.encoding) - - if freeze - token.value.freeze - token.location.freeze - token.freeze - end - end - - value = [node, tokens] - result = ParseLexResult.new(value, comments, magic_comments, data_loc, errors, warnings, source) - - if freeze - source.source.freeze - source.offsets.freeze - source.freeze - tokens.each(&:freeze) - tokens.freeze - value.freeze - result.freeze - end - - result + Serialize.load_parse_lex(code, buffer.read, options.fetch(:freeze, false)) end end diff --git a/lib/prism/parse_result.rb b/lib/prism/parse_result.rb index 7aee20c9de..e76ea7e17e 100644 --- a/lib/prism/parse_result.rb +++ b/lib/prism/parse_result.rb @@ -48,6 +48,16 @@ module Prism @offsets = offsets # set after parsing is done end + # Replace the value of start_line with the given value. + def replace_start_line(start_line) + @start_line = start_line + end + + # Replace the value of offsets with the given value. + def replace_offsets(offsets) + @offsets.replace(offsets) + end + # Returns the encoding of the source code, which is set by parameters to the # parser or by the encoding magic comment. def encoding @@ -132,6 +142,13 @@ module Prism code_units_offset(byte_offset, encoding) - code_units_offset(line_start(byte_offset), encoding) end + # Freeze this object and the objects it contains. + def deep_freeze + source.freeze + offsets.freeze + freeze + end + private # Binary search through the offsets to find the line number for the given @@ -854,5 +871,12 @@ module Prism location super end + + # Freeze this object and the objects it contains. + def deep_freeze + value.freeze + location.freeze + freeze + end end end diff --git a/prism/extension.c b/prism/extension.c index 4503cea6bd..e8f678d341 100644 --- a/prism/extension.c +++ b/prism/extension.c @@ -388,28 +388,54 @@ dump_file(int argc, VALUE *argv, VALUE self) { /******************************************************************************/ /** - * Extract the comments out of the parser into an array. + * The same as rb_class_new_instance, but accepts an additional boolean to + * indicate whether or not the resulting class instance should be frozen. */ -static VALUE -parser_comments(pm_parser_t *parser, VALUE source, bool freeze) { - VALUE comments = rb_ary_new_capa(parser->comment_list.size); +static inline VALUE +rb_class_new_instance_freeze(int argc, const VALUE *argv, VALUE klass, bool freeze) { + VALUE value = rb_class_new_instance(argc, argv, klass); + if (freeze) rb_obj_freeze(value); + return value; +} - for (pm_comment_t *comment = (pm_comment_t *) parser->comment_list.head; comment != NULL; comment = (pm_comment_t *) comment->node.next) { - VALUE location_argv[] = { - source, - LONG2FIX(comment->location.start - parser->start), - LONG2FIX(comment->location.end - comment->location.start) - }; +/** + * Create a new Location instance from the given parser and bounds. + */ +static inline VALUE +parser_location(const pm_parser_t *parser, VALUE source, bool freeze, const uint8_t *start, size_t length) { + VALUE argv[] = { source, LONG2FIX(start - parser->start), LONG2FIX(length) }; + return rb_class_new_instance_freeze(3, argv, rb_cPrismLocation, freeze); +} - VALUE location = rb_class_new_instance(3, location_argv, rb_cPrismLocation); - if (freeze) rb_obj_freeze(location); +/** + * Create a new Location instance from the given parser and location. + */ +#define PARSER_LOCATION_LOC(parser, source, freeze, loc) \ + parser_location(parser, source, freeze, loc.start, (size_t) (loc.end - loc.start)) - VALUE comment_argv[] = { location }; - VALUE type = (comment->type == PM_COMMENT_EMBDOC) ? rb_cPrismEmbDocComment : rb_cPrismInlineComment; +/** + * Build a new Comment instance from the given parser and comment. + */ +static inline VALUE +parser_comment(const pm_parser_t *parser, VALUE source, bool freeze, const pm_comment_t *comment) { + VALUE argv[] = { PARSER_LOCATION_LOC(parser, source, freeze, comment->location) }; + VALUE type = (comment->type == PM_COMMENT_EMBDOC) ? rb_cPrismEmbDocComment : rb_cPrismInlineComment; + return rb_class_new_instance_freeze(1, argv, type, freeze); +} - VALUE value = rb_class_new_instance(1, comment_argv, type); - if (freeze) rb_obj_freeze(value); +/** + * Extract the comments out of the parser into an array. + */ +static VALUE +parser_comments(const pm_parser_t *parser, VALUE source, bool freeze) { + VALUE comments = rb_ary_new_capa(parser->comment_list.size); + for ( + const pm_comment_t *comment = (const pm_comment_t *) parser->comment_list.head; + comment != NULL; + comment = (const pm_comment_t *) comment->node.next + ) { + VALUE value = parser_comment(parser, source, freeze, comment); rb_ary_push(comments, value); } @@ -418,35 +444,29 @@ parser_comments(pm_parser_t *parser, VALUE source, bool freeze) { } /** + * Build a new MagicComment instance from the given parser and magic comment. + */ +static inline VALUE +parser_magic_comment(const pm_parser_t *parser, VALUE source, bool freeze, const pm_magic_comment_t *magic_comment) { + VALUE key_loc = parser_location(parser, source, freeze, magic_comment->key_start, magic_comment->key_length); + VALUE value_loc = parser_location(parser, source, freeze, magic_comment->value_start, magic_comment->value_length); + VALUE argv[] = { key_loc, value_loc }; + return rb_class_new_instance_freeze(2, argv, rb_cPrismMagicComment, freeze); +} + +/** * Extract the magic comments out of the parser into an array. */ static VALUE -parser_magic_comments(pm_parser_t *parser, VALUE source, bool freeze) { +parser_magic_comments(const pm_parser_t *parser, VALUE source, bool freeze) { VALUE magic_comments = rb_ary_new_capa(parser->magic_comment_list.size); - for (pm_magic_comment_t *magic_comment = (pm_magic_comment_t *) parser->magic_comment_list.head; magic_comment != NULL; magic_comment = (pm_magic_comment_t *) magic_comment->node.next) { - VALUE key_loc_argv[] = { - source, - LONG2FIX(magic_comment->key_start - parser->start), - LONG2FIX(magic_comment->key_length) - }; - - VALUE key_loc = rb_class_new_instance(3, key_loc_argv, rb_cPrismLocation); - if (freeze) rb_obj_freeze(key_loc); - - VALUE value_loc_argv[] = { - source, - LONG2FIX(magic_comment->value_start - parser->start), - LONG2FIX(magic_comment->value_length) - }; - - VALUE value_loc = rb_class_new_instance(3, value_loc_argv, rb_cPrismLocation); - if (freeze) rb_obj_freeze(value_loc); - - VALUE magic_comment_argv[] = { key_loc, value_loc }; - VALUE value = rb_class_new_instance(2, magic_comment_argv, rb_cPrismMagicComment); - if (freeze) rb_obj_freeze(value); - + for ( + const pm_magic_comment_t *magic_comment = (const pm_magic_comment_t *) parser->magic_comment_list.head; + magic_comment != NULL; + magic_comment = (const pm_magic_comment_t *) magic_comment->node.next + ) { + VALUE value = parser_magic_comment(parser, source, freeze, magic_comment); rb_ary_push(magic_comments, value); } @@ -463,16 +483,7 @@ parser_data_loc(const pm_parser_t *parser, VALUE source, bool freeze) { if (parser->data_loc.end == NULL) { return Qnil; } else { - VALUE argv[] = { - source, - LONG2FIX(parser->data_loc.start - parser->start), - LONG2FIX(parser->data_loc.end - parser->data_loc.start) - }; - - VALUE location = rb_class_new_instance(3, argv, rb_cPrismLocation); - if (freeze) rb_obj_freeze(location); - - return location; + return PARSER_LOCATION_LOC(parser, source, freeze, parser->data_loc); } } @@ -480,19 +491,17 @@ parser_data_loc(const pm_parser_t *parser, VALUE source, bool freeze) { * Extract the errors out of the parser into an array. */ static VALUE -parser_errors(pm_parser_t *parser, rb_encoding *encoding, VALUE source, bool freeze) { +parser_errors(const pm_parser_t *parser, rb_encoding *encoding, VALUE source, bool freeze) { VALUE errors = rb_ary_new_capa(parser->error_list.size); - pm_diagnostic_t *error; - - for (error = (pm_diagnostic_t *) parser->error_list.head; error != NULL; error = (pm_diagnostic_t *) error->node.next) { - VALUE location_argv[] = { - source, - LONG2FIX(error->location.start - parser->start), - LONG2FIX(error->location.end - error->location.start) - }; - VALUE location = rb_class_new_instance(3, location_argv, rb_cPrismLocation); - if (freeze) rb_obj_freeze(location); + for ( + const pm_diagnostic_t *error = (const pm_diagnostic_t *) parser->error_list.head; + error != NULL; + error = (const pm_diagnostic_t *) error->node.next + ) { + VALUE type = ID2SYM(rb_intern(pm_diagnostic_id_human(error->diag_id))); + VALUE message = rb_obj_freeze(rb_enc_str_new_cstr(error->message, encoding)); + VALUE location = PARSER_LOCATION_LOC(parser, source, freeze, error->location); VALUE level = Qnil; switch (error->level) { @@ -509,15 +518,8 @@ parser_errors(pm_parser_t *parser, rb_encoding *encoding, VALUE source, bool fre rb_raise(rb_eRuntimeError, "Unknown level: %" PRIu8, error->level); } - VALUE message = rb_enc_str_new_cstr(error->message, encoding); - if (freeze) rb_obj_freeze(message); - - VALUE type = ID2SYM(rb_intern(pm_diagnostic_id_human(error->diag_id))); - VALUE error_argv[] = { type, message, location, level }; - - VALUE value = rb_class_new_instance(4, error_argv, rb_cPrismParseError); - if (freeze) rb_obj_freeze(value); - + VALUE argv[] = { type, message, location, level }; + VALUE value = rb_class_new_instance_freeze(4, argv, rb_cPrismParseError, freeze); rb_ary_push(errors, value); } @@ -529,19 +531,17 @@ parser_errors(pm_parser_t *parser, rb_encoding *encoding, VALUE source, bool fre * Extract the warnings out of the parser into an array. */ static VALUE -parser_warnings(pm_parser_t *parser, rb_encoding *encoding, VALUE source, bool freeze) { +parser_warnings(const pm_parser_t *parser, rb_encoding *encoding, VALUE source, bool freeze) { VALUE warnings = rb_ary_new_capa(parser->warning_list.size); - pm_diagnostic_t *warning; - for (warning = (pm_diagnostic_t *) parser->warning_list.head; warning != NULL; warning = (pm_diagnostic_t *) warning->node.next) { - VALUE location_argv[] = { - source, - LONG2FIX(warning->location.start - parser->start), - LONG2FIX(warning->location.end - warning->location.start) - }; - - VALUE location = rb_class_new_instance(3, location_argv, rb_cPrismLocation); - if (freeze) rb_obj_freeze(location); + for ( + const pm_diagnostic_t *warning = (const pm_diagnostic_t *) parser->warning_list.head; + warning != NULL; + warning = (const pm_diagnostic_t *) warning->node.next + ) { + VALUE type = ID2SYM(rb_intern(pm_diagnostic_id_human(warning->diag_id))); + VALUE message = rb_obj_freeze(rb_enc_str_new_cstr(warning->message, encoding)); + VALUE location = PARSER_LOCATION_LOC(parser, source, freeze, warning->location); VALUE level = Qnil; switch (warning->level) { @@ -555,15 +555,8 @@ parser_warnings(pm_parser_t *parser, rb_encoding *encoding, VALUE source, bool f rb_raise(rb_eRuntimeError, "Unknown level: %" PRIu8, warning->level); } - VALUE message = rb_enc_str_new_cstr(warning->message, encoding); - if (freeze) rb_obj_freeze(message); - - VALUE type = ID2SYM(rb_intern(pm_diagnostic_id_human(warning->diag_id))); - VALUE warning_argv[] = { type, message, location, level }; - - VALUE value = rb_class_new_instance(4, warning_argv, rb_cPrismParseWarning); - if (freeze) rb_obj_freeze(value); - + VALUE argv[] = { type, message, location, level }; + VALUE value = rb_class_new_instance_freeze(4, argv, rb_cPrismParseWarning, freeze); rb_ary_push(warnings, value); } @@ -575,7 +568,7 @@ parser_warnings(pm_parser_t *parser, rb_encoding *encoding, VALUE source, bool f * Create a new parse result from the given parser, value, encoding, and source. */ static VALUE -parse_result_create(VALUE class, pm_parser_t *parser, VALUE value, rb_encoding *encoding, VALUE source, bool freeze) { +parse_result_create(VALUE class, const pm_parser_t *parser, VALUE value, rb_encoding *encoding, VALUE source, bool freeze) { VALUE result_argv[] = { value, parser_comments(parser, source, freeze), @@ -586,10 +579,7 @@ parse_result_create(VALUE class, pm_parser_t *parser, VALUE value, rb_encoding * source }; - VALUE result = rb_class_new_instance(7, result_argv, class); - if (freeze) rb_obj_freeze(result); - - return result; + return rb_class_new_instance_freeze(7, result_argv, class, freeze); } /******************************************************************************/ diff --git a/prism/templates/lib/prism/serialize.rb.erb b/prism/templates/lib/prism/serialize.rb.erb index e8ac12830a..52821e0f7d 100644 --- a/prism/templates/lib/prism/serialize.rb.erb +++ b/prism/templates/lib/prism/serialize.rb.erb @@ -16,15 +16,41 @@ module Prism # strings. PATCH_VERSION = 0 - # Deserialize the AST represented by the given string into a parse result. - def self.load(input, serialized, freeze) + # Deserialize the dumped output from a request to parse or parse_file. + # + # The formatting of the source of this method is purposeful to illustrate + # the structure of the serialized data. + def self.load_parse(input, serialized, freeze) input = input.dup source = Source.for(input) - loader = Loader.new(source, serialized) - result = loader.load_result(freeze) - input.force_encoding(loader.encoding) + loader.load_header + encoding = loader.load_encoding + start_line = loader.load_varsint + offsets = loader.load_line_offsets(freeze) + + source.replace_start_line(start_line) + source.replace_offsets(offsets) + + comments = loader.load_comments(freeze) + magic_comments = loader.load_magic_comments(freeze) + data_loc = loader.load_optional_location_object(freeze) + errors = loader.load_errors(encoding, freeze) + warnings = loader.load_warnings(encoding, freeze) + cpool_base = loader.load_uint32 + cpool_size = loader.load_varuint + + constant_pool = ConstantPool.new(input, serialized, cpool_base, cpool_size) + + node = loader.load_node(constant_pool, encoding, freeze) + loader.load_constant_pool(constant_pool) + raise unless loader.eof? + + result = ParseResult.new(node, comments, magic_comments, data_loc, errors, warnings, source) + result.freeze if freeze + + input.force_encoding(encoding) # This is an extremely niche use-case where the file was marked as binary # but it contained UTF-8-encoded characters. In that case we will actually @@ -37,94 +63,231 @@ module Prism if freeze input.freeze - source.source.freeze - source.offsets.freeze - source.freeze + source.deep_freeze end result end - # Deserialize the tokens represented by the given string into a parse - # result. - def self.load_tokens(source, serialized, freeze) - Loader.new(source, serialized).load_tokens_result(freeze) + # Deserialize the dumped output from a request to lex or lex_file. + # + # The formatting of the source of this method is purposeful to illustrate + # the structure of the serialized data. + def self.load_lex(input, serialized, freeze) + source = Source.for(input) + loader = Loader.new(source, serialized) + + tokens = loader.load_tokens + encoding = loader.load_encoding + start_line = loader.load_varsint + offsets = loader.load_line_offsets(freeze) + + source.replace_start_line(start_line) + source.replace_offsets(offsets) + + comments = loader.load_comments(freeze) + magic_comments = loader.load_magic_comments(freeze) + data_loc = loader.load_optional_location_object(freeze) + errors = loader.load_errors(encoding, freeze) + warnings = loader.load_warnings(encoding, freeze) + raise unless loader.eof? + + result = LexResult.new(tokens, comments, magic_comments, data_loc, errors, warnings, source) + + tokens.each do |token| + token[0].value.force_encoding(encoding) + + if freeze + token[0].deep_freeze + token.freeze + end + end + + if freeze + source.deep_freeze + tokens.freeze + result.freeze + end + + result end - class Loader # :nodoc: - if RUBY_ENGINE == "truffleruby" - # StringIO is synchronized and that adds a high overhead on TruffleRuby. - class FastStringIO # :nodoc: - attr_accessor :pos - - def initialize(string) - @string = string - @pos = 0 - end + # Deserialize the dumped output from a request to parse_comments or + # parse_file_comments. + # + # The formatting of the source of this method is purposeful to illustrate + # the structure of the serialized data. + def self.load_parse_comments(input, serialized, freeze) + source = Source.for(input) + loader = Loader.new(source, serialized) - def getbyte - byte = @string.getbyte(@pos) - @pos += 1 - byte - end + loader.load_header + loader.load_encoding + start_line = loader.load_varsint - def read(n) - slice = @string.byteslice(@pos, n) - @pos += n - slice - end + source.replace_start_line(start_line) + + result = loader.load_comments(freeze) + raise unless loader.eof? + + source.deep_freeze if freeze + result + end + + # Deserialize the dumped output from a request to parse_lex or + # parse_lex_file. + # + # The formatting of the source of this method is purposeful to illustrate + # the structure of the serialized data. + def self.load_parse_lex(input, serialized, freeze) + source = Source.for(input) + loader = Loader.new(source, serialized) + + tokens = loader.load_tokens + loader.load_header + encoding = loader.load_encoding + start_line = loader.load_varsint + offsets = loader.load_line_offsets(freeze) + + source.replace_start_line(start_line) + source.replace_offsets(offsets) + + comments = loader.load_comments(freeze) + magic_comments = loader.load_magic_comments(freeze) + data_loc = loader.load_optional_location_object(freeze) + errors = loader.load_errors(encoding, freeze) + warnings = loader.load_warnings(encoding, freeze) + cpool_base = loader.load_uint32 + cpool_size = loader.load_varuint + + constant_pool = ConstantPool.new(input, serialized, cpool_base, cpool_size) + + node = loader.load_node(constant_pool, encoding, freeze) + loader.load_constant_pool(constant_pool) + raise unless loader.eof? + + value = [node, tokens] + result = ParseLexResult.new(value, comments, magic_comments, data_loc, errors, warnings, source) + + tokens.each do |token| + token[0].value.force_encoding(encoding) + + if freeze + token[0].deep_freeze + token.freeze + end + end - def eof? - @pos >= @string.bytesize + if freeze + source.deep_freeze + tokens.freeze + value.freeze + result.freeze + end + + result + end + + class ConstantPool # :nodoc: + attr_reader :size + + def initialize(input, serialized, base, size) + @input = input + @serialized = serialized + @base = base + @size = size + @pool = Array.new(size, nil) + end + + def get(index, encoding) + @pool[index] ||= + begin + offset = @base + index * 8 + start = @serialized.unpack1("L", offset: offset) + length = @serialized.unpack1("L", offset: offset + 4) + + if start.nobits?(1 << 31) + @input.byteslice(start, length).force_encoding(encoding).to_sym + else + @serialized.byteslice(start & ((1 << 31) - 1), length).force_encoding(encoding).to_sym + end end + end + end + + if RUBY_ENGINE == "truffleruby" + # StringIO is synchronized and that adds a high overhead on TruffleRuby. + class FastStringIO # :nodoc: + attr_accessor :pos + + def initialize(string) + @string = string + @pos = 0 + end + + def getbyte + byte = @string.getbyte(@pos) + @pos += 1 + byte + end + + def read(n) + slice = @string.byteslice(@pos, n) + @pos += n + slice + end + + def eof? + @pos >= @string.bytesize end - else - FastStringIO = ::StringIO end - private_constant :FastStringIO + else + FastStringIO = ::StringIO # :nodoc: + end - attr_reader :encoding, :input, :serialized, :io - attr_reader :constant_pool_offset, :constant_pool, :source - attr_reader :start_line + class Loader # :nodoc: + attr_reader :input, :io, :source def initialize(source, serialized) - @encoding = Encoding::UTF_8 - @input = source.source.dup raise unless serialized.encoding == Encoding::BINARY - @serialized = serialized @io = FastStringIO.new(serialized) + @source = source + define_load_node_lambdas if RUBY_ENGINE != "ruby" + end - @constant_pool_offset = nil - @constant_pool = nil + def eof? + io.getbyte + io.eof? + end - @source = source - define_load_node_lambdas unless RUBY_ENGINE == "ruby" + def load_constant_pool(constant_pool) + trailer = 0 + + constant_pool.size.times do |index| + start, length = io.read(8).unpack("L2") + trailer += length if start.anybits?(1 << 31) + end + + io.read(trailer) end def load_header raise "Invalid serialization" if io.read(5) != "PRISM" raise "Invalid serialization" if io.read(3).unpack("C3") != [MAJOR_VERSION, MINOR_VERSION, PATCH_VERSION] - only_semantic_fields = io.getbyte - unless only_semantic_fields == 0 - raise "Invalid serialization (location fields must be included but are not)" - end + raise "Invalid serialization (location fields must be included but are not)" if io.getbyte != 0 end def load_encoding - @encoding = Encoding.find(io.read(load_varuint)) - @input = input.force_encoding(@encoding).freeze - @encoding - end - - def load_start_line - source.instance_variable_set(:@start_line, load_varsint) + encoding = Encoding.find(io.read(load_varuint)) + @input = input.force_encoding(encoding).freeze + encoding end def load_line_offsets(freeze) offsets = Array.new(load_varuint) { load_varuint } offsets.freeze if freeze - source.instance_variable_set(:@offsets, offsets) + offsets end def load_comments(freeze) @@ -187,13 +350,13 @@ module Prism end end - def load_errors(freeze) + def load_errors(encoding, freeze) errors = Array.new(load_varuint) do error = ParseError.new( DIAGNOSTIC_TYPES.fetch(load_varuint), - load_embedded_string, + load_embedded_string(encoding), load_location_object(freeze), load_error_level ) @@ -219,13 +382,13 @@ module Prism end end - def load_warnings(freeze) + def load_warnings(encoding, freeze) warnings = Array.new(load_varuint) do warning = ParseWarning.new( DIAGNOSTIC_TYPES.fetch(load_varuint), - load_embedded_string, + load_embedded_string(encoding), load_location_object(freeze), load_warning_level ) @@ -238,17 +401,7 @@ module Prism warnings end - def load_metadata(freeze) - [ - load_comments(freeze), - load_magic_comments(freeze), - load_optional_location_object(freeze), - load_errors(freeze), - load_warnings(freeze) - ] - end - - def load_tokens(freeze) + def load_tokens tokens = [] while (type = TOKEN_TYPES.fetch(load_varuint)) @@ -257,74 +410,14 @@ module Prism lex_state = load_varuint location = Location.new(@source, start, length) - location.freeze if freeze - - slice = location.slice - slice.freeze if freeze - - token = Token.new(@source, type, slice, location) - token.freeze if freeze + token = Token.new(@source, type, location.slice, location) tokens << [token, lex_state] end - tokens.freeze if freeze tokens end - def load_tokens_result(freeze) - tokens = load_tokens(false) - encoding = load_encoding - load_start_line - load_line_offsets(freeze) - comments, magic_comments, data_loc, errors, warnings = load_metadata(freeze) - - tokens.each do |token,| - token.value.force_encoding(encoding) - - if freeze - token.value.freeze - token.location.freeze - token.freeze - end - end - - raise "Expected to consume all bytes while deserializing" unless @io.eof? - result = LexResult.new(tokens, comments, magic_comments, data_loc, errors, warnings, @source) - - if freeze - tokens.each(&:freeze) - tokens.freeze - result.freeze - end - - result - end - - def load_nodes(freeze) - load_header - load_encoding - load_start_line - load_line_offsets(freeze) - - comments, magic_comments, data_loc, errors, warnings = load_metadata(freeze) - - @constant_pool_offset = load_uint32 - @constant_pool = Array.new(load_varuint, nil) - - [load_node(freeze), comments, magic_comments, data_loc, errors, warnings] - end - - def load_result(freeze) - node, comments, magic_comments, data_loc, errors, warnings = load_nodes(freeze) - result = ParseResult.new(node, comments, magic_comments, data_loc, errors, warnings, @source) - - result.freeze if freeze - result - end - - private - # variable-length integer using https://en.wikipedia.org/wiki/LEB128 # This is also what protobuf uses: https://protobuf.dev/programming-guides/encoding/#varints def load_varuint @@ -365,23 +458,23 @@ module Prism io.read(4).unpack1("L") end - def load_optional_node(freeze) + def load_optional_node(constant_pool, encoding, freeze) if io.getbyte != 0 io.pos -= 1 - load_node(freeze) + load_node(constant_pool, encoding, freeze) end end - def load_embedded_string + def load_embedded_string(encoding) io.read(load_varuint).force_encoding(encoding).freeze end - def load_string + def load_string(encoding) case (type = io.getbyte) when 1 input.byteslice(load_varuint, load_varuint).force_encoding(encoding).freeze when 2 - load_embedded_string + load_embedded_string(encoding) else raise "Unknown serialized string type: #{type}" end @@ -406,38 +499,18 @@ module Prism load_location_object(freeze) if io.getbyte != 0 end - def load_constant(index) - constant = constant_pool[index] - - unless constant - offset = constant_pool_offset + index * 8 - start = @serialized.unpack1("L", offset: offset) - length = @serialized.unpack1("L", offset: offset + 4) - - constant = - if start.nobits?(1 << 31) - input.byteslice(start, length).force_encoding(@encoding).to_sym - else - @serialized.byteslice(start & ((1 << 31) - 1), length).force_encoding(@encoding).to_sym - end - - constant_pool[index] = constant - end - - constant - end - - def load_required_constant - load_constant(load_varuint - 1) + def load_constant(constant_pool, encoding) + index = load_varuint + constant_pool.get(index - 1, encoding) end - def load_optional_constant + def load_optional_constant(constant_pool, encoding) index = load_varuint - load_constant(index - 1) if index != 0 + constant_pool.get(index - 1, encoding) if index != 0 end if RUBY_ENGINE == "ruby" - def load_node(freeze) + def load_node(constant_pool, encoding, freeze) type = io.getbyte node_id = load_varuint location = load_location(freeze) @@ -449,13 +522,13 @@ module Prism <%- end -%> <%= node.name %>.new(<%= ["source", "node_id", "location", "load_varuint", *node.fields.map { |field| case field - when Prism::Template::NodeField then "load_node(freeze)" - when Prism::Template::OptionalNodeField then "load_optional_node(freeze)" - when Prism::Template::StringField then "load_string" - when Prism::Template::NodeListField then "Array.new(load_varuint) { load_node(freeze) }.tap { |nodes| nodes.freeze if freeze }" - when Prism::Template::ConstantField then "load_required_constant" - when Prism::Template::OptionalConstantField then "load_optional_constant" - when Prism::Template::ConstantListField then "Array.new(load_varuint) { load_required_constant }.tap { |constants| constants.freeze if freeze }" + when Prism::Template::NodeField then "load_node(constant_pool, encoding, freeze)" + when Prism::Template::OptionalNodeField then "load_optional_node(constant_pool, encoding, freeze)" + when Prism::Template::StringField then "load_string(encoding)" + when Prism::Template::NodeListField then "Array.new(load_varuint) { load_node(constant_pool, encoding, freeze) }.tap { |nodes| nodes.freeze if freeze }" + when Prism::Template::ConstantField then "load_constant(constant_pool, encoding)" + when Prism::Template::OptionalConstantField then "load_optional_constant(constant_pool, encoding)" + when Prism::Template::ConstantListField then "Array.new(load_varuint) { load_constant(constant_pool, encoding) }.tap { |constants| constants.freeze if freeze }" when Prism::Template::LocationField then "load_location(freeze)" when Prism::Template::OptionalLocationField then "load_optional_location(freeze)" when Prism::Template::UInt8Field then "io.getbyte" @@ -472,16 +545,15 @@ module Prism value end else - def load_node(freeze) - type = io.getbyte - @load_node_lambdas[type].call(freeze) + def load_node(constant_pool, encoding, freeze) + @load_node_lambdas[io.getbyte].call(constant_pool, encoding, freeze) end def define_load_node_lambdas @load_node_lambdas = [ nil, <%- nodes.each do |node| -%> - -> (freeze) { + -> (constant_pool, encoding, freeze) { node_id = load_varuint location = load_location(freeze) <%- if node.needs_serialized_length? -%> @@ -489,13 +561,13 @@ module Prism <%- end -%> value = <%= node.name %>.new(<%= ["source", "node_id", "location", "load_varuint", *node.fields.map { |field| case field - when Prism::Template::NodeField then "load_node(freeze)" - when Prism::Template::OptionalNodeField then "load_optional_node(freeze)" - when Prism::Template::StringField then "load_string" - when Prism::Template::NodeListField then "Array.new(load_varuint) { load_node(freeze) }" - when Prism::Template::ConstantField then "load_required_constant" - when Prism::Template::OptionalConstantField then "load_optional_constant" - when Prism::Template::ConstantListField then "Array.new(load_varuint) { load_required_constant }" + when Prism::Template::NodeField then "load_node(constant_pool, encoding, freeze)" + when Prism::Template::OptionalNodeField then "load_optional_node(constant_pool, encoding, freeze)" + when Prism::Template::StringField then "load_string(encoding)" + when Prism::Template::NodeListField then "Array.new(load_varuint) { load_node(constant_pool, encoding, freeze) }" + when Prism::Template::ConstantField then "load_constant(constant_pool, encoding)" + when Prism::Template::OptionalConstantField then "load_optional_constant(constant_pool, encoding)" + when Prism::Template::ConstantListField then "Array.new(load_varuint) { load_constant(constant_pool, encoding) }" when Prism::Template::LocationField then "load_location(freeze)" when Prism::Template::OptionalLocationField then "load_optional_location(freeze)" when Prism::Template::UInt8Field then "io.getbyte" @@ -522,6 +594,9 @@ module Prism <%- end -%> ] - private_constant :TOKEN_TYPES + private_constant :MAJOR_VERSION, :MINOR_VERSION, :PATCH_VERSION + private_constant :ConstantPool, :FastStringIO, :Loader, :TOKEN_TYPES end + + private_constant :Serialize end |