diff --git a/benchmark/benchmark_prototype_rb.rb b/benchmark/benchmark_prototype_rb.rb new file mode 100644 index 0000000000..2d4bfa4892 --- /dev/null +++ b/benchmark/benchmark_prototype_rb.rb @@ -0,0 +1,36 @@ +# frozen_string_literal: true + +require "benchmark/ips" +require_relative "../lib/rbs" + +# Collect Ruby source files to parse +sources = Dir.glob(File.join(__dir__, "../lib/**/*.rb")).map do |path| + [path, File.read(path)] +end + +puts "Benchmarking prototype generation (#{sources.size} files, #{sources.sum { |_, s| s.size }} bytes total)" +puts + +Benchmark.ips do |x| + x.report("RB: RubyVM::AbstractSyntaxTree") do + ENV.delete("RBS_RUBY_PARSER") + sources.each do |_path, source| + parser = RBS::Prototype::RB.new + parser.parse(source) + parser.decls + end + end + + x.report("RB: Prism") do + ENV["RBS_RUBY_PARSER"] = "prism" + sources.each do |_path, source| + parser = RBS::Prototype::RB.new + parser.parse(source) + parser.decls + end + end + + x.compare! +ensure + ENV.delete("RBS_RUBY_PARSER") +end diff --git a/lib/rbs.rb b/lib/rbs.rb index d5388d849e..a6f7ec91d9 100644 --- a/lib/rbs.rb +++ b/lib/rbs.rb @@ -51,11 +51,16 @@ require "rbs/resolver/type_name_resolver" require "rbs/ast/comment" require "rbs/writer" -require "rbs/prototype/helpers" +require "rbs/prototype/comment_parser" +require "rbs/prototype/ruby_vm_helpers" +require "rbs/prototype/node_usage" require "rbs/prototype/rbi" +require "rbs/prototype/rbi/ruby_vm" +require "rbs/prototype/rbi/prism" require "rbs/prototype/rb" +require "rbs/prototype/rb/ruby_vm" +require "rbs/prototype/rb/prism" require "rbs/prototype/runtime" -require "rbs/prototype/node_usage" require "rbs/environment_walker" require "rbs/vendorer" require "rbs/validator" diff --git a/lib/rbs/prototype/comment_parser.rb b/lib/rbs/prototype/comment_parser.rb new file mode 100644 index 0000000000..97299f5faf --- /dev/null +++ b/lib/rbs/prototype/comment_parser.rb @@ -0,0 +1,68 @@ +# frozen_string_literal: true + +module RBS + module Prototype + module CommentParser + # Build a line-number-keyed hash of comments from a Prism::ParseResult or + # an array of Prism comment objects. + def build_comments_prism(comments, include_trailing:) + comments.each_with_object({}) do |comment, hash| #$ Hash[Integer, AST::Comment] + next unless comment.is_a?(Prism::InlineComment) + next if comment.trailing? && !include_trailing + + line = comment.location.start_line + body = "#{comment.location.slice}\n" + body = body[2..-1] or raise + body = "\n" if body.empty? + + comment = AST::Comment.new(string: body, location: nil) + if prev_comment = hash.delete(line - 1) + hash[line] = AST::Comment.new(string: prev_comment.string + comment.string, location: nil) + else + hash[line] = comment + end + end + end + + # Parse comments from a Ruby source string. Uses Prism on Ruby >= 3.3, + # falls back to Ripper on older Rubies. + if RUBY_VERSION >= "3.3" + def parse_comments(string, include_trailing:) + build_comments_prism( + Prism.parse_comments(string, version: "current"), # steep:ignore UnexpectedKeywordArgument + include_trailing: include_trailing + ) + end + else + require "ripper" + + def parse_comments(string, include_trailing:) + Ripper.lex(string).yield_self do |tokens| + code_lines = {} #: Hash[Integer, bool] + tokens.each.with_object({}) do |token, hash| #$ Hash[Integer, AST::Comment] + case token[1] + when :on_sp, :on_ignored_nl + # skip + when :on_comment + line = token[0][0] + next if code_lines[line] && !include_trailing + body = token[2][2..-1] or raise + + body = "\n" if body.empty? + + comment = AST::Comment.new(string: body, location: nil) + if prev_comment = hash.delete(line - 1) + hash[line] = AST::Comment.new(string: prev_comment.string + comment.string, location: nil) + else + hash[line] = comment + end + else + code_lines[token[0][0]] = true + end + end + end + end + end + end + end +end diff --git a/lib/rbs/prototype/node_usage.rb b/lib/rbs/prototype/node_usage.rb index 2c5f07e0ba..4cd166a5a5 100644 --- a/lib/rbs/prototype/node_usage.rb +++ b/lib/rbs/prototype/node_usage.rb @@ -3,7 +3,7 @@ module RBS module Prototype class NodeUsage - include Helpers + include RubyVMHelpers attr_reader :conditional_nodes diff --git a/lib/rbs/prototype/rb.rb b/lib/rbs/prototype/rb.rb index 8e3562db46..fe47cc6fdb 100644 --- a/lib/rbs/prototype/rb.rb +++ b/lib/rbs/prototype/rb.rb @@ -2,8 +2,14 @@ module RBS module Prototype - class RB - include Helpers + module RB + def self.new + if ENV['RBS_RUBY_PARSER'] == 'prism' + RB::Prism.new + else + RB::RubyVM.new + end + end class Context < Struct.new(:module_function, :singleton, :namespace, :in_def, keyword_init: true) # @implements Context @@ -39,778 +45,153 @@ def update(module_function: self.module_function, singleton: self.singleton, in_ end end - attr_reader :source_decls - attr_reader :toplevel_members + class Base + include CommentParser - def initialize - @source_decls = [] - end + attr_reader :source_decls + attr_reader :toplevel_members - def decls - # @type var decls: Array[AST::Declarations::t] - decls = [] - - # @type var top_decls: Array[AST::Declarations::t] - # @type var top_members: Array[AST::Members::t] - top_decls, top_members = _ = source_decls.partition {|decl| decl.is_a?(AST::Declarations::Base) } - - decls.push(*top_decls) - - unless top_members.empty? - top = AST::Declarations::Class.new( - name: TypeName.new(name: :Object, namespace: Namespace.empty), - super_class: nil, - members: top_members, - annotations: [], - comment: nil, - location: nil, - type_params: [] - ) - decls << top + def initialize + @source_decls = [] end - decls - end - - def parse(string) - # @type var comments: Hash[Integer, AST::Comment] - comments = parse_comments(string, include_trailing: false) - - process RubyVM::AbstractSyntaxTree.parse(string), decls: source_decls, comments: comments, context: Context.initial - end - - def process(node, decls:, comments:, context:) - case node.type - when :CLASS - class_name, super_class_node, *class_body = node.children - super_class_name = const_to_name(super_class_node, context: context) - super_class = - if super_class_name - AST::Declarations::Class::Super.new(name: super_class_name, args: [], location: nil) - else - # Give up detect super class e.g. `class Foo < Struct.new(:bar)` - nil - end - kls = AST::Declarations::Class.new( - name: const_to_name!(class_name), - super_class: super_class, - type_params: [], - members: [], - annotations: [], - location: nil, - comment: comments[node.first_lineno - 1] - ) - - decls.push kls - - new_ctx = context.enter_namespace(kls.name.to_namespace) - each_node class_body do |child| - process child, decls: kls.members, comments: comments, context: new_ctx - end - remove_unnecessary_accessibility_methods! kls.members - sort_members! kls.members - - when :MODULE - module_name, *module_body = node.children - - mod = AST::Declarations::Module.new( - name: const_to_name!(module_name), - type_params: [], - self_types: [], - members: [], - annotations: [], - location: nil, - comment: comments[node.first_lineno - 1] - ) - - decls.push mod - - new_ctx = context.enter_namespace(mod.name.to_namespace) - each_node module_body do |child| - process child, decls: mod.members, comments: comments, context: new_ctx - end - remove_unnecessary_accessibility_methods! mod.members - sort_members! mod.members - - when :SCLASS - this, body = node.children - - if this.type != :SELF - RBS.logger.warn "`class <<` syntax with not-self may be compiled to incorrect code: #{this}" - end - - accessibility = current_accessibility(decls) + def decls + # @type var decls: Array[AST::Declarations::t] + decls = [] - ctx = Context.initial.tap { |ctx| ctx.singleton = true } - process_children(body, decls: decls, comments: comments, context: ctx) + # @type var top_decls: Array[AST::Declarations::t] + # @type var top_members: Array[AST::Members::t] + top_decls, top_members = _ = source_decls.partition {|decl| decl.is_a?(AST::Declarations::Base) } - decls << accessibility + decls.push(*top_decls) - when :DEFN, :DEFS - # @type var kind: Context::method_kind - - if node.type == :DEFN - def_name, def_body = node.children - kind = context.method_kind - else - _, def_name, def_body = node.children - kind = :singleton - end - - types = [ - MethodType.new( - type_params: [], - type: function_type_from_body(def_body, def_name), - block: block_from_body(def_body), - location: nil - ) - ] - - member = AST::Members::MethodDefinition.new( - name: def_name, - location: nil, - annotations: [], - overloads: types.map {|type| AST::Members::MethodDefinition::Overload.new(annotations: [], method_type: type )}, - kind: kind, - comment: comments[node.first_lineno - 1], - overloading: false, - visibility: nil - ) - - decls.push member unless decls.include?(member) - - new_ctx = context.update(singleton: kind == :singleton, in_def: true) - each_node def_body.children do |child| - process child, decls: decls, comments: comments, context: new_ctx - end - - when :ALIAS - new_name, old_name = node.children.map { |c| literal_to_symbol(c) } - member = AST::Members::Alias.new( - new_name: new_name, - old_name: old_name, - kind: context.singleton ? :singleton : :instance, - annotations: [], - location: nil, - comment: comments[node.first_lineno - 1], - ) - decls.push member unless decls.include?(member) - - when :FCALL, :VCALL - # Inside method definition cannot reach here. - args = node.children[1]&.children || [] - - case node.children[0] - when :include - args.each do |arg| - if (name = const_to_name(arg, context: context)) - klass = context.singleton ? AST::Members::Extend : AST::Members::Include - decls << klass.new( - name: name, - args: [], - annotations: [], - location: nil, - comment: comments[node.first_lineno - 1] - ) - end - end - when :prepend - args.each do |arg| - if (name = const_to_name(arg, context: context)) - decls << AST::Members::Prepend.new( - name: name, - args: [], - annotations: [], - location: nil, - comment: comments[node.first_lineno - 1] - ) - end - end - when :extend - args.each do |arg| - if (name = const_to_name(arg, context: context)) - decls << AST::Members::Extend.new( - name: name, - args: [], - annotations: [], - location: nil, - comment: comments[node.first_lineno - 1] - ) - end - end - when :attr_reader - args.each do |arg| - if arg && (name = literal_to_symbol(arg)) - decls << AST::Members::AttrReader.new( - name: name, - ivar_name: nil, - type: Types::Bases::Any.new(location: nil), - kind: context.attribute_kind, - location: nil, - comment: comments[node.first_lineno - 1], - annotations: [] - ) - end - end - when :attr_accessor - args.each do |arg| - if arg && (name = literal_to_symbol(arg)) - decls << AST::Members::AttrAccessor.new( - name: name, - ivar_name: nil, - type: Types::Bases::Any.new(location: nil), - kind: context.attribute_kind, - location: nil, - comment: comments[node.first_lineno - 1], - annotations: [] - ) - end - end - when :attr_writer - args.each do |arg| - if arg && (name = literal_to_symbol(arg)) - decls << AST::Members::AttrWriter.new( - name: name, - ivar_name: nil, - type: Types::Bases::Any.new(location: nil), - kind: context.attribute_kind, - location: nil, - comment: comments[node.first_lineno - 1], - annotations: [] - ) - end - end - when :alias_method - if args[0] && args[1] && (new_name = literal_to_symbol(args[0])) && (old_name = literal_to_symbol(args[1])) - decls << AST::Members::Alias.new( - new_name: new_name, - old_name: old_name, - kind: context.singleton ? :singleton : :instance, - annotations: [], - location: nil, - comment: comments[node.first_lineno - 1], - ) - end - when :module_function - if args.empty? - context.module_function = true - else - module_func_context = context.update(module_function: true) - args.each do |arg| - if arg && (name = literal_to_symbol(arg)) - if (i, defn = find_def_index_by_name(decls, name)) - if defn.is_a?(AST::Members::MethodDefinition) - decls[i] = defn.update(kind: :singleton_instance) - end - end - elsif arg - process arg, decls: decls, comments: comments, context: module_func_context - end - end - end - when :public, :private - accessibility = __send__(node.children[0]) - if args.empty? - decls << accessibility - else - args.each do |arg| - if arg && (name = literal_to_symbol(arg)) - if (i, _ = find_def_index_by_name(decls, name)) - current = current_accessibility(decls, i) - if current != accessibility - decls.insert(i + 1, current) - decls.insert(i, accessibility) - end - end - end - end - - # For `private def foo` syntax - current = current_accessibility(decls) - decls << accessibility - process_children(node, decls: decls, comments: comments, context: context) - decls << current - end - else - process_children(node, decls: decls, comments: comments, context: context) - end - - when :ITER - # ignore - - when :CDECL - const_name = case - when node.children[0].is_a?(Symbol) - TypeName.new(name: node.children[0], namespace: Namespace.empty) - else - const_to_name!(node.children[0], context: context) - end - - value_node = node.children.last - type = if value_node.nil? || value_node.type == :SELF - # Give up type prediction when node is MASGN or SELF. - Types::Bases::Any.new(location: nil) - else - literal_to_type(value_node) - end - decls << AST::Declarations::Constant.new( - name: const_name, - type: type, - location: nil, - comment: comments[node.first_lineno - 1], - annotations: [] - ) - - when :IASGN - case [context.singleton, context.in_def] - when [true, true], [false, false] - member = AST::Members::ClassInstanceVariable.new( - name: node.children.first, - type: Types::Bases::Any.new(location: nil), + unless top_members.empty? + top = AST::Declarations::Class.new( + name: TypeName.new(name: :Object, namespace: Namespace.empty), + super_class: nil, + members: top_members, + annotations: [], + comment: nil, location: nil, - comment: comments[node.first_lineno - 1] + type_params: [] ) - when [false, true] - member = AST::Members::InstanceVariable.new( - name: node.children.first, - type: Types::Bases::Any.new(location: nil), - location: nil, - comment: comments[node.first_lineno - 1] - ) - when [true, false] - # The variable is for the singleton class of the class object. - # RBS does not have a way to represent it. So we ignore it. - else - raise 'unreachable' + decls << top end - decls.push member if member && !decls.include?(member) - - when :CVASGN - member = AST::Members::ClassVariable.new( - name: node.children.first, - type: Types::Bases::Any.new(location: nil), - location: nil, - comment: comments[node.first_lineno - 1] - ) - decls.push member unless decls.include?(member) - else - process_children(node, decls: decls, comments: comments, context: context) - end - end - - def process_children(node, decls:, comments:, context:) - each_child node do |child| - process child, decls: decls, comments: comments, context: context - end - end - - def const_to_name!(node, context: nil) - case node.type - when :CONST - TypeName.new(name: node.children[0], namespace: Namespace.empty) - when :COLON2 - if node.children[0] - namespace = const_to_name!(node.children[0], context: context).to_namespace - else - namespace = Namespace.empty - end - - TypeName.new(name: node.children[1], namespace: namespace) - when :COLON3 - TypeName.new(name: node.children[0], namespace: Namespace.root) - when :SELF - raise if context.nil? - - context.namespace.to_type_name - else - raise - end - end - - def const_to_name(node, context:) - if node - case node.type - when :SELF - context.namespace.to_type_name - when :CONST, :COLON2, :COLON3 - const_to_name!(node) rescue nil - end - end - end - - def literal_to_symbol(node) - case node.type - when :SYM - node.children[0] - when :LIT - node.children[0] if node.children[0].is_a?(Symbol) - when :STR - node.children[0].to_sym - end - end - - def function_type_from_body(node, def_name) - table_node, args_node, *_ = node.children - - pre_num, _pre_init, opt, _first_post, post_num, _post_init, rest, kw, kwrest, _block = args_from_node(args_node) - - return_type = if def_name == :initialize - Types::Bases::Void.new(location: nil) - else - function_return_type_from_body(node) - end - - fun = Types::Function.empty(return_type) - - table_node.take(pre_num).each do |name| - fun.required_positionals << Types::Function::Param.new(name: name, type: untyped) + decls end - while opt&.type == :OPT_ARG - lvasgn, opt = opt.children - name = lvasgn.children[0] - fun.optional_positionals << Types::Function::Param.new( - name: name, - type: param_type(lvasgn.children[1]) - ) - end + def types_to_union_type(types) + return untyped if types.empty? - if rest - rest_name = rest == :* ? nil : rest # `def f(...)` syntax has `*` name - fun = fun.update(rest_positionals: Types::Function::Param.new(name: rest_name, type: untyped)) - end - - table_node.drop(fun.required_positionals.size + fun.optional_positionals.size + (fun.rest_positionals ? 1 : 0)).take(post_num).each do |name| - fun.trailing_positionals << Types::Function::Param.new(name: name, type: untyped) - end - - while kw - lvasgn, kw = kw.children - name, value = lvasgn.children - - case value - when nil, :NODE_SPECIAL_REQUIRED_KEYWORD - fun.required_keywords[name] = Types::Function::Param.new(name: nil, type: untyped) - when RubyVM::AbstractSyntaxTree::Node - fun.optional_keywords[name] = Types::Function::Param.new(name: nil, type: param_type(value)) - else - raise "Unexpected keyword arg value: #{value}" + uniq = types.uniq + if uniq.size == 1 + return uniq.first || raise end - end - if kwrest && kwrest.children.any? - kwrest_name = kwrest.children[0] #: Symbol? - kwrest_name = nil if kwrest_name == :** # `def f(...)` syntax has `**` name - fun = fun.update(rest_keywords: Types::Function::Param.new(name: kwrest_name, type: untyped)) - end - - fun - end - - def function_return_type_from_body(node) - body = node.children[2] - body_type(body) - end - - def body_type(node) - return Types::Bases::Nil.new(location: nil) unless node - - case node.type - when :IF, :UNLESS - if_unless_type(node) - when :BLOCK - block_type(node) - else - literal_to_type(node) + Types::Union.new(types: uniq, location: nil) end - end - - def if_unless_type(node) - raise unless node.type == :IF || node.type == :UNLESS - _exp_node, true_node, false_node = node.children - types_to_union_type([body_type(true_node), body_type(false_node)]) - end - - def block_type(node) - raise unless node.type == :BLOCK - - return_stmts = any_node?(node) do |n| - n.type == :RETURN - end&.map do |return_node| - returned_value = return_node.children[0] - returned_value ? literal_to_type(returned_value) : Types::Bases::Nil.new(location: nil) - end || [] - last_node = node.children.last - last_evaluated = last_node ? literal_to_type(last_node) : Types::Bases::Nil.new(location: nil) - types_to_union_type([*return_stmts, last_evaluated]) - end + def range_element_type(types) + types = types.reject { |t| t == untyped } + return untyped if types.empty? - def literal_to_type(node) - case node.type - when :STR - lit = node.children[0] - if lit.ascii_only? - Types::Literal.new(literal: lit, location: nil) - else - BuiltinNames::String.instance_type - end - when :DSTR, :XSTR - BuiltinNames::String.instance_type - when :SYM - lit = node.children[0] - if lit.to_s.ascii_only? - Types::Literal.new(literal: lit, location: nil) - else - BuiltinNames::Symbol.instance_type - end - when :DSYM - BuiltinNames::Symbol.instance_type - when :DREGX, :REGX - BuiltinNames::Regexp.instance_type - when :TRUE - Types::Literal.new(literal: true, location: nil) - when :FALSE - Types::Literal.new(literal: false, location: nil) - when :NIL - Types::Bases::Nil.new(location: nil) - when :INTEGER - Types::Literal.new(literal: node.children[0], location: nil) - when :FLOAT - BuiltinNames::Float.instance_type - when :RATIONAL, :IMAGINARY - lit = node.children[0] - type_name = TypeName.new(name: lit.class.name.to_sym, namespace: Namespace.root) - Types::ClassInstance.new(name: type_name, args: [], location: nil) - when :LIT - lit = node.children[0] - case lit - when Symbol - if lit.to_s.ascii_only? - Types::Literal.new(literal: lit, location: nil) + types = types.map do |t| + if t.is_a?(Types::Literal) + type_name = TypeName.new(name: t.literal.class.name&.to_sym || raise, namespace: Namespace.root) + Types::ClassInstance.new(name: type_name, args: [], location: nil) else - BuiltinNames::Symbol.instance_type + t end - when Integer - Types::Literal.new(literal: lit, location: nil) - when String - # For Ruby <=3.3 which generates `LIT` node for string literals inside Hash literal. - # "a" => STR node - # { "a" => nil } => LIT node - Types::Literal.new(literal: lit, location: nil) - else - type_name = TypeName.new(name: lit.class.name.to_sym, namespace: Namespace.root) - Types::ClassInstance.new(name: type_name, args: [], location: nil) - end - when :ZLIST, :ZARRAY - BuiltinNames::Array.instance_type(untyped) - when :LIST, :ARRAY - elem_types = node.children.compact.map { |e| literal_to_type(e) } - t = types_to_union_type(elem_types) - BuiltinNames::Array.instance_type(t) - when :DOT2, :DOT3 - types = node.children.map { |c| literal_to_type(c) } - type = range_element_type(types) - BuiltinNames::Range.instance_type(type) - when :HASH - list = node.children[0] - if list - children = list.children - children.pop - else - children = [] #: Array[untyped] - end - - key_types = [] #: Array[Types::t] - value_types = [] #: Array[Types::t] - children.each_slice(2) do |k, v| - if k - key_types << literal_to_type(k) - value_types << literal_to_type(v) - else - key_types << untyped - value_types << untyped - end - end + end.uniq - if !key_types.empty? && key_types.all? { |t| t.is_a?(Types::Literal) } - fields = key_types.map {|t| - t.is_a?(Types::Literal) or raise - t.literal - }.zip(value_types).to_h #: Hash[Types::Literal::literal, Types::t] - Types::Record.new(fields: fields, location: nil) - else - key_type = types_to_union_type(key_types) - value_type = types_to_union_type(value_types) - BuiltinNames::Hash.instance_type(key_type, value_type) - end - when :SELF - Types::Bases::Self.new(location: nil) - when :CALL - receiver, method_name, * = node.children - case method_name - when :freeze, :tap, :itself, :dup, :clone, :taint, :untaint, :extend - literal_to_type(receiver) + if types.size == 1 + types.first or raise else untyped end - else - untyped end - end - - def types_to_union_type(types) - return untyped if types.empty? - uniq = types.uniq - if uniq.size == 1 - return uniq.first || raise + def untyped + @untyped ||= Types::Bases::Any.new(location: nil) end - Types::Union.new(types: uniq, location: nil) - end - - def range_element_type(types) - types = types.reject { |t| t == untyped } - return untyped if types.empty? - - types = types.map do |t| - if t.is_a?(Types::Literal) - type_name = TypeName.new(name: t.literal.class.name&.to_sym || raise, namespace: Namespace.root) - Types::ClassInstance.new(name: type_name, args: [], location: nil) - else - t - end - end.uniq + def private + @private ||= AST::Members::Private.new(location: nil) + end - if types.size == 1 - types.first or raise - else - untyped + def public + @public ||= AST::Members::Public.new(location: nil) end - end - def param_type(node, default: Types::Bases::Any.new(location: nil)) - case node.type - when :INTEGER - BuiltinNames::Integer.instance_type - when :FLOAT - BuiltinNames::Float.instance_type - when :RATIONAL - Types::ClassInstance.new(name: TypeName.parse("::Rational"), args: [], location: nil) - when :IMAGINARY - Types::ClassInstance.new(name: TypeName.parse("::Complex"), args: [], location: nil) - when :LIT - case node.children[0] - when Symbol - BuiltinNames::Symbol.instance_type - when Integer - BuiltinNames::Integer.instance_type - when Float - BuiltinNames::Float.instance_type + def current_accessibility(decls, index = decls.size) + slice = decls.slice(0, index) or raise + idx = slice.rindex { |decl| decl == private || decl == public } + if idx + _ = decls[idx] else - default + public end - when :SYM - BuiltinNames::Symbol.instance_type - when :STR, :DSTR - BuiltinNames::String.instance_type - when :NIL - # This type is technical non-sense, but may help practically. - Types::Optional.new( - type: Types::Bases::Any.new(location: nil), - location: nil - ) - when :TRUE, :FALSE - Types::Bases::Bool.new(location: nil) - when :ARRAY, :LIST - BuiltinNames::Array.instance_type(default) - when :HASH - BuiltinNames::Hash.instance_type(default, default) - else - default end - end - # backward compatible - alias node_type param_type - - def private - @private ||= AST::Members::Private.new(location: nil) - end + def remove_unnecessary_accessibility_methods!(decls) + # @type var current: decl + current = public + idx = 0 - def public - @public ||= AST::Members::Public.new(location: nil) - end - - def current_accessibility(decls, index = decls.size) - slice = decls.slice(0, index) or raise - idx = slice.rindex { |decl| decl == private || decl == public } - if idx - _ = decls[idx] - else - public - end - end - - def remove_unnecessary_accessibility_methods!(decls) - # @type var current: decl - current = public - idx = 0 + loop do + decl = decls[idx] or break + if current == decl + decls.delete_at(idx) + next + end - loop do - decl = decls[idx] or break - if current == decl - decls.delete_at(idx) - next - end + if 0 < idx && is_accessibility?(decls[idx - 1]) && is_accessibility?(decl) + decls.delete_at(idx - 1) + idx -= 1 + current = current_accessibility(decls, idx) + next + end - if 0 < idx && is_accessibility?(decls[idx - 1]) && is_accessibility?(decl) - decls.delete_at(idx - 1) - idx -= 1 - current = current_accessibility(decls, idx) - next + current = decl if is_accessibility?(decl) + idx += 1 end - current = decl if is_accessibility?(decl) - idx += 1 + decls.pop while decls.last && is_accessibility?(decls.last || raise) end - decls.pop while decls.last && is_accessibility?(decls.last || raise) - end + def is_accessibility?(decl) + decl == public || decl == private + end - def is_accessibility?(decl) - decl == public || decl == private - end + def find_def_index_by_name(decls, name) + index = decls.find_index do |decl| + case decl + when AST::Members::MethodDefinition, AST::Members::AttrReader + decl.name == name + when AST::Members::AttrWriter + :"#{decl.name}=" == name + end + end - def find_def_index_by_name(decls, name) - index = decls.find_index do |decl| - case decl - when AST::Members::MethodDefinition, AST::Members::AttrReader - decl.name == name - when AST::Members::AttrWriter - :"#{decl.name}=" == name + if index + [ + index, + _ = decls[index] + ] end end - if index - [ - index, - _ = decls[index] - ] + def sort_members!(decls) + i = 0 + orders = { + AST::Members::ClassVariable => -3, + AST::Members::ClassInstanceVariable => -2, + AST::Members::InstanceVariable => -1, + } #: Hash[Class, Integer] + decls.sort_by! { |decl| [orders.fetch(decl.class, 0), i += 1] } end end - - def sort_members!(decls) - i = 0 - orders = { - AST::Members::ClassVariable => -3, - AST::Members::ClassInstanceVariable => -2, - AST::Members::InstanceVariable => -1, - } #: Hash[Class, Integer] - decls.sort_by! { |decl| [orders.fetch(decl.class, 0), i += 1] } - end end end end diff --git a/lib/rbs/prototype/rb/prism.rb b/lib/rbs/prototype/rb/prism.rb new file mode 100644 index 0000000000..6778dbd3d2 --- /dev/null +++ b/lib/rbs/prototype/rb/prism.rb @@ -0,0 +1,819 @@ +# frozen_string_literal: true + +module RBS + module Prototype + module RB + class Prism < Base + class NodeUsage + attr_reader :conditional_nodes + + def initialize(node) + @conditional_nodes = Set[].compare_by_identity + calculate(node, conditional: false) + end + + def each_conditional_node(&block) + if block + conditional_nodes.each(&block) + else + conditional_nodes.each + end + end + + private + + def calculate(node, conditional:) + conditional_nodes << node if conditional + + case node.type + when :if_node, :unless_node + calculate(node.predicate, conditional: true) + + if node.type == :if_node + calculate_statements(node.statements, conditional: conditional) + case node.subsequent&.type + when :else_node + calculate_statements(node.subsequent.statements, conditional: conditional) + when :if_node + calculate(node.subsequent, conditional: conditional) + end + else + calculate_statements(node.statements, conditional: conditional) + calculate_statements(node.else_clause&.statements, conditional: conditional) + end + + when :and_node, :or_node + calculate(node.left, conditional: true) + calculate(node.right, conditional: conditional) + + when :call_node + if node.safe_navigation? && node.receiver + calculate(node.receiver, conditional: true) + node.arguments&.arguments&.each { |a| calculate(a, conditional: false) } + else + node.each_child_node { |c| calculate(c, conditional: false) } + end + + when :while_node, :until_node + calculate(node.predicate, conditional: true) + calculate_statements(node.statements, conditional: false) + + when :local_variable_or_write_node, :local_variable_and_write_node + conditional_nodes << node + calculate(node.value, conditional: conditional) + + when :local_variable_write_node, :instance_variable_write_node, :global_variable_write_node + calculate(node.value, conditional: conditional) + + when :multi_write_node + node.lefts.each { |t| calculate(t, conditional: conditional) } + calculate(node.value, conditional: conditional) + + when :constant_write_node + calculate(node.value, conditional: conditional) + + when :constant_path_write_node + calculate(node.target, conditional: false) + calculate(node.value, conditional: conditional) + + when :statements_node + if node.body.size > 0 + node.body[0...-1].each { |n| calculate(n, conditional: false) } + calculate(node.body.last, conditional: conditional) if node.body.last + end + + when :case_match_node + node.conditions.each do |cond| + calculate(cond.pattern, conditional: true) if cond.respond_to?(:pattern) + calculate_statements(cond.statements, conditional: conditional) if cond.respond_to?(:statements) + end + calculate_statements(node.else_clause&.statements, conditional: conditional) + + when :case_node + node.conditions.each do |when_node| + when_node.conditions.each { |c| calculate(c, conditional: true) } + calculate_statements(when_node.statements, conditional: conditional) + end + calculate_statements(node.else_clause&.statements, conditional: conditional) + + when :def_node + calculate(node.body, conditional: conditional) if node.body + + else + node.each_child_node { |c| calculate(c, conditional: false) } + end + end + + def calculate_statements(node, conditional:) + return unless node + calculate(node, conditional: conditional) + end + end + + # Pre-scan a method body in a single pass, collecting yields, + # block_given? calls, and return nodes so we don't walk the tree + # multiple times. + BodyInfo = Struct.new(:yields, :has_block_given, :returns, keyword_init: true) + + private_constant :NodeUsage, :BodyInfo + + def parse(string) + result = ::Prism.parse(string) + comments = build_comments_prism(result.comments, include_trailing: false) + process(result.value, decls: source_decls, comments: comments, context: Context.initial) + end + + def block_from_def(node, body_info = build_body_info(node.body)) + params = node.parameters + body = node.body + + block_param = params&.block + forwarding = params&.keyword_rest&.type == :forwarding_parameter_node + + yields = body_info.yields + has_block_given = body_info.has_block_given + + if !yields.empty? || block_param || forwarding + required = !has_block_given && !forwarding + + if required && block_param.is_a?(::Prism::BlockParameterNode) + block_name = block_param.name + if body && block_name + usage = NodeUsage.new(body) + if usage.each_conditional_node.any? { |n| n.type == :local_variable_read_node && n.name == block_name } + required = false + end + end + end + + if !yields.empty? + function = Types::Function.empty(untyped) + + yields.each do |yield_node| + args = yield_node.arguments&.arguments || [] + positionals, keywords = + if args.last && keyword_hash?(args.last) + [args[0...-1] || [], args.last] + else + [args, nil] + end + + if (diff = positionals.size - function.required_positionals.size) > 0 + diff.times do + function.required_positionals << Types::Function::Param.new(type: untyped, name: nil) + end + end + + if keywords + elements = + case keywords + when ::Prism::KeywordHashNode, ::Prism::HashNode then keywords.elements + else [] #: Array[Prism::AssocNode | Prism::AssocSplatNode] + end + + elements.each do |elem| #: Prism::node + if elem.is_a?(::Prism::AssocNode) && elem.key.is_a?(::Prism::SymbolNode) + function.required_keywords[elem.key.unescaped.to_sym] ||= + Types::Function::Param.new(type: untyped, name: nil) + end + end + end + end + else + function = Types::UntypedFunction.new(return_type: untyped) + end + + Types::Block.new(required: required, type: function, self_type: nil) + end + end + + private + + def process(node, decls:, comments:, context:) + case node.type + when :program_node + process(node.statements, decls: decls, comments: comments, context: context) + when :statements_node + node.body.each do |child| + process(child, decls: decls, comments: comments, context: context) + end + when :begin_node + if (statements = node.statements) + process(statements, decls: decls, comments: comments, context: context) + end + when :class_node + super_class_name = const_to_name(node.superclass, context: context) + super_class = + if super_class_name + AST::Declarations::Class::Super.new(name: super_class_name, args: [], location: nil) + end + + kls = AST::Declarations::Class.new( + name: const_to_name!(node.constant_path), + super_class: super_class, + type_params: [], + members: [], + annotations: [], + location: nil, + comment: comments[node.location.start_line - 1] + ) + + decls.push(kls) + + new_ctx = context.enter_namespace(kls.name.to_namespace) + if (body = node.body) + process(body, decls: kls.members, comments: comments, context: new_ctx) + end + + remove_unnecessary_accessibility_methods!(kls.members) + sort_members!(kls.members) + when :module_node + mod = AST::Declarations::Module.new( + name: const_to_name!(node.constant_path), + type_params: [], + self_types: [], + members: [], + annotations: [], + location: nil, + comment: comments[node.location.start_line - 1] + ) + + decls.push mod + + new_ctx = context.enter_namespace(mod.name.to_namespace) + if (body = node.body) + process(body, decls: mod.members, comments: comments, context: new_ctx) + end + + remove_unnecessary_accessibility_methods!(mod.members) + sort_members!(mod.members) + when :singleton_class_node + unless node.expression.is_a?(::Prism::SelfNode) + RBS.logger.warn "`class <<` syntax with not-self may be compiled to incorrect code: #{node.expression}" + end + + accessibility = current_accessibility(decls) + ctx = Context.initial.tap { |c| c.singleton = true } + if (body = node.body) + process(body, decls: decls, comments: comments, context: ctx) + end + + decls << accessibility + when :def_node + kind = node.receiver ? :singleton : context.method_kind #: AST::Members::MethodDefinition::kind + body_info = build_body_info(node.body) + + types = [ + MethodType.new( + type_params: [], + type: function_type(node, body_info), + block: block_from_def(node, body_info), + location: nil + ) + ] + + member = AST::Members::MethodDefinition.new( + name: node.name, + location: nil, + annotations: [], + overloads: types.map { |type| AST::Members::MethodDefinition::Overload.new(annotations: [], method_type: type) }, + kind: kind, + comment: comments[node.location.start_line - 1], + overloading: false, + visibility: nil + ) + + decls.push(member) unless decls.include?(member) + + new_ctx = context.update(singleton: kind == :singleton, in_def: true) + if (body = node.body) + process(body, decls: decls, comments: comments, context: new_ctx) + end + when :alias_method_node + new_name = symbol_value(node.new_name) + old_name = symbol_value(node.old_name) + + if new_name && old_name + member = AST::Members::Alias.new( + new_name: new_name, + old_name: old_name, + kind: context.singleton ? :singleton : :instance, + annotations: [], + location: nil, + comment: comments[node.location.start_line - 1], + ) + + decls.push(member) unless decls.include?(member) + end + when :call_node + return if node.block.is_a?(::Prism::BlockNode) + return if node.receiver + + args = node.arguments&.arguments || [] + + case node.name + when :include + args.each do |arg| + if (name = const_to_name(arg, context: context)) + klass = context.singleton ? AST::Members::Extend : AST::Members::Include + decls << klass.new( + name: name, + args: [], + annotations: [], + location: nil, + comment: comments[node.location.start_line - 1] + ) + end + end + when :prepend + args.each do |arg| + if (name = const_to_name(arg, context: context)) + decls << AST::Members::Prepend.new( + name: name, + args: [], + annotations: [], + location: nil, + comment: comments[node.location.start_line - 1] + ) + end + end + when :extend + args.each do |arg| + if (name = const_to_name(arg, context: context)) + decls << AST::Members::Extend.new( + name: name, + args: [], + annotations: [], + location: nil, + comment: comments[node.location.start_line - 1] + ) + end + end + when :attr_reader, :attr_accessor, :attr_writer + klass = + case node.name + when :attr_reader then AST::Members::AttrReader + when :attr_accessor then AST::Members::AttrAccessor + when :attr_writer then AST::Members::AttrWriter + end + + args.each do |arg| + if klass && (name = symbol_value(arg)) + decls << klass.new( + name: name, ivar_name: nil, + type: Types::Bases::Any.new(location: nil), + kind: context.attribute_kind, + location: nil, + comment: comments[node.location.start_line - 1], + annotations: [] + ) + end + end + when :alias_method + if args.size >= 2 && (new_name = symbol_value(args[0])) && (old_name = symbol_value(args[1])) + decls << AST::Members::Alias.new( + new_name: new_name, + old_name: old_name, + kind: context.singleton ? :singleton : :instance, + annotations: [], + location: nil, + comment: comments[node.location.start_line - 1], + ) + end + when :module_function + if args.empty? + context.module_function = true + else + module_func_context = context.update(module_function: true) + args.each do |arg| + if (name = symbol_value(arg)) + if (i, defn = find_def_index_by_name(decls, name)) + if defn.is_a?(AST::Members::MethodDefinition) + decls[i] = defn.update(kind: :singleton_instance) + end + end + else + process(arg, decls: decls, comments: comments, context: module_func_context) + end + end + end + when :public, :private + accessibility = __send__(node.name) + if args.empty? + decls << accessibility + else + args.each do |arg| + if (name = symbol_value(arg)) + if (i, _ = find_def_index_by_name(decls, name)) + current = current_accessibility(decls, i) + if current != accessibility + decls.insert(i + 1, current) + decls.insert(i, accessibility) + end + end + end + end + + current = current_accessibility(decls) + decls << accessibility + args.each do |arg| + process(arg, decls: decls, comments: comments, context: context) + end + decls << current + end + else + args.each do |arg| + process(arg, decls: decls, comments: comments, context: context) + end + end + when :constant_write_node + type = node.value.type == :self_node ? Types::Bases::Any.new(location: nil) : literal_to_type(node.value) + decls << AST::Declarations::Constant.new( + name: TypeName.new(name: node.name, namespace: Namespace.empty), + type: type, + location: nil, + comment: comments[node.location.start_line - 1], + annotations: [] + ) + when :constant_path_write_node + type = node.value.type == :self_node ? Types::Bases::Any.new(location: nil) : literal_to_type(node.value) + + decls << AST::Declarations::Constant.new( + name: const_to_name!(node.target, context: context), + type: type, + location: nil, + comment: comments[node.location.start_line - 1], + annotations: [] + ) + when :instance_variable_write_node, :instance_variable_or_write_node, + :instance_variable_and_write_node, :instance_variable_operator_write_node + case [context.singleton, context.in_def] + when [true, true], [false, false] + member = AST::Members::ClassInstanceVariable.new( + name: node.name, + type: Types::Bases::Any.new(location: nil), + location: nil, + comment: comments[node.location.start_line - 1] + ) + when [false, true] + member = AST::Members::InstanceVariable.new( + name: node.name, + type: Types::Bases::Any.new(location: nil), + location: nil, + comment: comments[node.location.start_line - 1] + ) + when [true, false] + # Singleton class ivar - RBS can't represent it + else + raise 'unreachable' + end + + decls.push(member) if member && !decls.include?(member) + when :class_variable_write_node, :class_variable_or_write_node, + :class_variable_and_write_node, :class_variable_operator_write_node + member = AST::Members::ClassVariable.new( + name: node.name, + type: Types::Bases::Any.new(location: nil), + location: nil, + comment: comments[node.location.start_line - 1] + ) + + decls.push(member) unless decls.include?(member) + when :multi_write_node + (node.lefts + node.rights).each do |target| + if target.is_a?(::Prism::ConstantTargetNode) + decls << AST::Declarations::Constant.new( + name: TypeName.new(name: target.name, namespace: Namespace.empty), + type: Types::Bases::Any.new(location: nil), + location: nil, + comment: comments[node.location.start_line - 1], + annotations: [] + ) + end + end + else + node.each_child_node do |child| + process(child, decls: decls, comments: comments, context: context) + end + end + end + + def const_to_name!(node, context: nil) + case node.type + when :constant_read_node + TypeName.new(name: node.name, namespace: Namespace.empty) + when :constant_path_node + if node.parent.nil? + TypeName.new(name: node.name || raise, namespace: Namespace.root) + else + namespace = const_to_name!(node.parent, context: context).to_namespace + TypeName.new(name: node.name || raise, namespace: namespace) + end + when :self_node + raise if context.nil? + context.namespace.to_type_name + else + raise "Unexpected node for const name: #{node.class}" + end + end + + def const_to_name(node, context:) + return nil unless node + + case node.type + when :self_node + context.namespace.to_type_name + when :constant_read_node, :constant_path_node + const_to_name!(node) rescue nil + end + end + + def symbol_value(node) + case node.type + when :symbol_node then node.unescaped.to_sym + when :string_node then node.unescaped.to_sym + end + end + + def build_body_info(body) + yields = [] #: Array[::Prism::YieldNode] + has_block_given = false + returns = [] #: Array[::Prism::ReturnNode] + + if body + queue = [body] + while (node = queue.shift) + if node.is_a?(::Prism::YieldNode) + yields << node + elsif node.is_a?(::Prism::ReturnNode) + returns << node + elsif node.is_a?(::Prism::CallNode) + if node.name == :block_given? && node.receiver.nil? && node.arguments.nil? + has_block_given = true + end + end + + node.each_child_node { |child| queue << child } + end + end + + BodyInfo.new(yields: yields, has_block_given: has_block_given, returns: returns.empty? ? nil : returns) + end + + def function_type(node, body_info) + params = node.parameters + return_type = + if node.name == :initialize + Types::Bases::Void.new(location: nil) + else + return_type_from_body(node.body, returns: body_info.returns) + end + + fun = Types::Function.empty(return_type) + return fun unless params + + if params.keyword_rest&.type == :forwarding_parameter_node + return fun.update( + rest_positionals: Types::Function::Param.new(name: nil, type: untyped), + rest_keywords: Types::Function::Param.new(name: nil, type: untyped) + ) + end + + params.requireds.each do |req| #: Prism::RequiredParameterNode | Prism::MultiTargetNode + name = req.is_a?(::Prism::RequiredParameterNode) ? req.name : nil + fun.required_positionals << Types::Function::Param.new(name: name, type: untyped) + end + + params.optionals.each do |opt| + fun.optional_positionals << Types::Function::Param.new( + name: opt.name, type: param_type(opt.value) + ) + end + + if (rest = params.rest).is_a?(::Prism::RestParameterNode) + fun = fun.update(rest_positionals: Types::Function::Param.new(name: rest.name, type: untyped)) + end + + params.posts.each do |post| #: Prism::RequiredParameterNode | Prism::MultiTargetNode + name = post.is_a?(::Prism::RequiredParameterNode) ? post.name : nil + fun.trailing_positionals << Types::Function::Param.new(name: name, type: untyped) + end + + params.keywords.each do |kw| + case kw + when ::Prism::RequiredKeywordParameterNode + fun.required_keywords[kw.name] = Types::Function::Param.new(name: nil, type: untyped) + when ::Prism::OptionalKeywordParameterNode + fun.optional_keywords[kw.name] = Types::Function::Param.new(name: nil, type: param_type(kw.value)) + end + end + + if (keyword_rest = params.keyword_rest).is_a?(::Prism::KeywordRestParameterNode) + fun = fun.update(rest_keywords: Types::Function::Param.new(name: keyword_rest.name, type: untyped)) + end + + fun + end + + def keyword_hash?(node) + case node.type + when :keyword_hash_node, :hash_node + node.elements.all? { |e| e.type == :assoc_node && e.key.type == :symbol_node } + else + false + end + end + + def return_type_from_body(body, returns:) + return Types::Bases::Nil.new(location: nil) unless body + + if body.type == :statements_node && body.body.size == 1 + return return_type_from_body(body.body.first, returns: returns) + end + + case body.type + when :if_node, :unless_node + if_unless_type(body) + when :statements_node + statements_type(body, returns: returns) + else + literal_to_type(body) + end + end + + def if_unless_type(node) + case node.type + when :if_node + true_type = return_type_from_body(node.statements, returns: nil) + false_type = + case (subsequent = node.subsequent)&.type + when :else_node + return_type_from_body(subsequent.statements, returns: nil) + when :if_node + if_unless_type(subsequent) + else + Types::Bases::Nil.new(location: nil) + end + + types_to_union_type([true_type, false_type]) + when :unless_node + true_type = return_type_from_body(node.statements, returns: nil) + false_type = + if (else_clause = node.else_clause) + return_type_from_body(else_clause.statements, returns: nil) + else + Types::Bases::Nil.new(location: nil) + end + + types_to_union_type([true_type, false_type]) + else + untyped + end + end + + def statements_type(node, returns:) + return Types::Bases::Nil.new(location: nil) unless node + + return_nodes = returns || node.find_all { |n| n.is_a?(::Prism::ReturnNode) } + + return_types = return_nodes.map do |return_node| + args = return_node.arguments&.arguments + if args && !args.empty? + literal_to_type(args.first) + else + Types::Bases::Nil.new(location: nil) + end + end + + last_node = node.body.last + last_evaluated = last_node ? literal_to_type(last_node) : Types::Bases::Nil.new(location: nil) + + types_to_union_type([*return_types, last_evaluated]) + end + + def literal_to_type(node) + case node.type + when :string_node + if (unescaped = node.unescaped).ascii_only? + Types::Literal.new(literal: unescaped, location: nil) + else + BuiltinNames::String.instance_type + end + when :interpolated_string_node, :x_string_node, :interpolated_x_string_node + BuiltinNames::String.instance_type + when :symbol_node + if (unescaped = node.unescaped).ascii_only? + Types::Literal.new(literal: unescaped.to_sym, location: nil) + else + BuiltinNames::Symbol.instance_type + end + when :interpolated_symbol_node + BuiltinNames::Symbol.instance_type + when :regular_expression_node, :interpolated_regular_expression_node + BuiltinNames::Regexp.instance_type + when :true_node + Types::Literal.new(literal: true, location: nil) + when :false_node + Types::Literal.new(literal: false, location: nil) + when :nil_node + Types::Bases::Nil.new(location: nil) + when :integer_node + Types::Literal.new(literal: node.value, location: nil) + when :float_node + BuiltinNames::Float.instance_type + when :rational_node + Types::ClassInstance.new(name: TypeName.new(name: :Rational, namespace: Namespace.root), args: [], location: nil) + when :imaginary_node + Types::ClassInstance.new(name: TypeName.new(name: :Complex, namespace: Namespace.root), args: [], location: nil) + when :array_node + if node.elements.empty? + BuiltinNames::Array.instance_type(untyped) + else + BuiltinNames::Array.instance_type(types_to_union_type(node.elements.map { |e| literal_to_type(e) })) + end + when :range_node + types = [node.left, node.right].compact.map { |c| literal_to_type(c) } + BuiltinNames::Range.instance_type(range_element_type(types)) + when :hash_node + hash_type(node.elements) + when :self_node + Types::Bases::Self.new(location: nil) + when :call_node + if node.receiver + case node.name + when :freeze, :tap, :itself, :dup, :clone, :taint, :untaint, :extend + literal_to_type(node.receiver) + else + untyped + end + else + untyped + end + when :parentheses_node + node.body ? literal_to_type(node.body) : Types::Bases::Nil.new(location: nil) + when :statements_node + node.body.last ? literal_to_type(node.body.last) : Types::Bases::Nil.new(location: nil) + when :if_node, :unless_node + if_unless_type(node) || untyped + else + untyped + end + end + + def hash_type(elements) + key_types = [] #: Array[Types::t] + value_types = [] #: Array[Types::t] + + elements.each do |elem| + case elem.type + when :assoc_node + key_types << literal_to_type(elem.key) + value_types << literal_to_type(elem.value) + when :assoc_splat_node + key_types << untyped + value_types << untyped + end + end + + if !key_types.empty? && key_types.all? { |t| t.is_a?(Types::Literal) } + fields = key_types.map { |t| + t.is_a?(Types::Literal) or raise + t.literal + }.zip(value_types).to_h #: Hash[Types::Literal::literal, Types::t] + + Types::Record.new(fields: fields, location: nil) + else + BuiltinNames::Hash.instance_type(types_to_union_type(key_types), types_to_union_type(value_types)) + end + end + + def param_type(node, default: Types::Bases::Any.new(location: nil)) + case node.type + when :integer_node + BuiltinNames::Integer.instance_type + when :float_node + BuiltinNames::Float.instance_type + when :rational_node + Types::ClassInstance.new(name: TypeName.parse("::Rational"), args: [], location: nil) + when :imaginary_node + Types::ClassInstance.new(name: TypeName.parse("::Complex"), args: [], location: nil) + when :symbol_node + BuiltinNames::Symbol.instance_type + when :string_node, :interpolated_string_node + BuiltinNames::String.instance_type + when :nil_node + Types::Optional.new(type: Types::Bases::Any.new(location: nil), location: nil) + when :true_node, :false_node + Types::Bases::Bool.new(location: nil) + when :array_node + BuiltinNames::Array.instance_type(default) + when :hash_node + BuiltinNames::Hash.instance_type(default, default) + else + default + end + end + end + end + end +end diff --git a/lib/rbs/prototype/rb/ruby_vm.rb b/lib/rbs/prototype/rb/ruby_vm.rb new file mode 100644 index 0000000000..fe18a1dfe9 --- /dev/null +++ b/lib/rbs/prototype/rb/ruby_vm.rb @@ -0,0 +1,609 @@ +# frozen_string_literal: true + +module RBS + module Prototype + module RB + class RubyVM < Base + include RubyVMHelpers + + def parse(string) + # @type var comments: Hash[Integer, AST::Comment] + comments = parse_comments(string, include_trailing: false) + process ::RubyVM::AbstractSyntaxTree.parse(string), decls: source_decls, comments: comments, context: Context.initial + end + + def process(node, decls:, comments:, context:) + case node.type + when :CLASS + class_name, super_class_node, *class_body = node.children + super_class_name = const_to_name(super_class_node, context: context) + super_class = + if super_class_name + AST::Declarations::Class::Super.new(name: super_class_name, args: [], location: nil) + else + nil + end + kls = AST::Declarations::Class.new( + name: const_to_name!(class_name), + super_class: super_class, + type_params: [], + members: [], + annotations: [], + location: nil, + comment: comments[node.first_lineno - 1] + ) + + decls.push kls + + new_ctx = context.enter_namespace(kls.name.to_namespace) + each_node class_body do |child| + process child, decls: kls.members, comments: comments, context: new_ctx + end + remove_unnecessary_accessibility_methods! kls.members + sort_members! kls.members + + when :MODULE + module_name, *module_body = node.children + + mod = AST::Declarations::Module.new( + name: const_to_name!(module_name), + type_params: [], + self_types: [], + members: [], + annotations: [], + location: nil, + comment: comments[node.first_lineno - 1] + ) + + decls.push mod + + new_ctx = context.enter_namespace(mod.name.to_namespace) + each_node module_body do |child| + process child, decls: mod.members, comments: comments, context: new_ctx + end + remove_unnecessary_accessibility_methods! mod.members + sort_members! mod.members + + when :SCLASS + this, body = node.children + + if this.type != :SELF + RBS.logger.warn "`class <<` syntax with not-self may be compiled to incorrect code: #{this}" + end + + accessibility = current_accessibility(decls) + + ctx = Context.initial.tap { |ctx| ctx.singleton = true } + process_children(body, decls: decls, comments: comments, context: ctx) + + decls << accessibility + + when :DEFN, :DEFS + # @type var kind: Context::method_kind + + if node.type == :DEFN + def_name, def_body = node.children + kind = context.method_kind + else + _, def_name, def_body = node.children + kind = :singleton + end + + types = [ + MethodType.new( + type_params: [], + type: function_type_from_body(def_body, def_name), + block: block_from_body(def_body), + location: nil + ) + ] + + member = AST::Members::MethodDefinition.new( + name: def_name, + location: nil, + annotations: [], + overloads: types.map {|type| AST::Members::MethodDefinition::Overload.new(annotations: [], method_type: type )}, + kind: kind, + comment: comments[node.first_lineno - 1], + overloading: false, + visibility: nil + ) + + decls.push member unless decls.include?(member) + + new_ctx = context.update(singleton: kind == :singleton, in_def: true) + each_node def_body.children do |child| + process child, decls: decls, comments: comments, context: new_ctx + end + + when :ALIAS + new_name, old_name = node.children.map { |c| literal_to_symbol(c) } + new_name or raise + old_name or raise + member = AST::Members::Alias.new( + new_name: new_name, + old_name: old_name, + kind: context.singleton ? :singleton : :instance, + annotations: [], + location: nil, + comment: comments[node.first_lineno - 1], + ) + decls.push member unless decls.include?(member) + + when :FCALL, :VCALL + args = node.children[1]&.children || [] + + case node.children[0] + when :include + args.each do |arg| + if (name = const_to_name(arg, context: context)) + klass = context.singleton ? AST::Members::Extend : AST::Members::Include + decls << klass.new( + name: name, args: [], annotations: [], + location: nil, comment: comments[node.first_lineno - 1] + ) + end + end + when :prepend + args.each do |arg| + if (name = const_to_name(arg, context: context)) + decls << AST::Members::Prepend.new( + name: name, args: [], annotations: [], + location: nil, comment: comments[node.first_lineno - 1] + ) + end + end + when :extend + args.each do |arg| + if (name = const_to_name(arg, context: context)) + decls << AST::Members::Extend.new( + name: name, args: [], annotations: [], + location: nil, comment: comments[node.first_lineno - 1] + ) + end + end + when :attr_reader + args.each do |arg| + if arg && (name = literal_to_symbol(arg)) + decls << AST::Members::AttrReader.new( + name: name, ivar_name: nil, + type: Types::Bases::Any.new(location: nil), + kind: context.attribute_kind, + location: nil, comment: comments[node.first_lineno - 1], + annotations: [] + ) + end + end + when :attr_accessor + args.each do |arg| + if arg && (name = literal_to_symbol(arg)) + decls << AST::Members::AttrAccessor.new( + name: name, ivar_name: nil, + type: Types::Bases::Any.new(location: nil), + kind: context.attribute_kind, + location: nil, comment: comments[node.first_lineno - 1], + annotations: [] + ) + end + end + when :attr_writer + args.each do |arg| + if arg && (name = literal_to_symbol(arg)) + decls << AST::Members::AttrWriter.new( + name: name, ivar_name: nil, + type: Types::Bases::Any.new(location: nil), + kind: context.attribute_kind, + location: nil, comment: comments[node.first_lineno - 1], + annotations: [] + ) + end + end + when :alias_method + if args[0] && args[1] && (new_name = literal_to_symbol(args[0])) && (old_name = literal_to_symbol(args[1])) + decls << AST::Members::Alias.new( + new_name: new_name, old_name: old_name, + kind: context.singleton ? :singleton : :instance, + annotations: [], location: nil, + comment: comments[node.first_lineno - 1], + ) + end + when :module_function + if args.empty? + context.module_function = true + else + module_func_context = context.update(module_function: true) + args.each do |arg| + if arg && (name = literal_to_symbol(arg)) + if (i, defn = find_def_index_by_name(decls, name)) + if defn.is_a?(AST::Members::MethodDefinition) + decls[i] = defn.update(kind: :singleton_instance) + end + end + elsif arg + process arg, decls: decls, comments: comments, context: module_func_context + end + end + end + when :public, :private + accessibility = __send__(node.children[0]) + if args.empty? + decls << accessibility + else + args.each do |arg| + if arg && (name = literal_to_symbol(arg)) + if (i, _ = find_def_index_by_name(decls, name)) + current = current_accessibility(decls, i) + if current != accessibility + decls.insert(i + 1, current) + decls.insert(i, accessibility) + end + end + end + end + + current = current_accessibility(decls) + decls << accessibility + process_children(node, decls: decls, comments: comments, context: context) + decls << current + end + else + process_children(node, decls: decls, comments: comments, context: context) + end + + when :ITER + # ignore + + when :CDECL + const_name = case + when node.children[0].is_a?(Symbol) + TypeName.new(name: node.children[0], namespace: Namespace.empty) + else + const_to_name!(node.children[0], context: context) + end + + value_node = node.children.last + type = if value_node.nil? || value_node.type == :SELF + Types::Bases::Any.new(location: nil) + else + literal_to_type(value_node) + end + decls << AST::Declarations::Constant.new( + name: const_name, type: type, location: nil, + comment: comments[node.first_lineno - 1], + annotations: [] + ) + + when :IASGN + case [context.singleton, context.in_def] + when [true, true], [false, false] + member = AST::Members::ClassInstanceVariable.new( + name: node.children.first, + type: Types::Bases::Any.new(location: nil), + location: nil, comment: comments[node.first_lineno - 1] + ) + when [false, true] + member = AST::Members::InstanceVariable.new( + name: node.children.first, + type: Types::Bases::Any.new(location: nil), + location: nil, comment: comments[node.first_lineno - 1] + ) + when [true, false] + # Singleton class ivar - RBS can't represent it + else + raise 'unreachable' + end + + decls.push member if member && !decls.include?(member) + + when :CVASGN + member = AST::Members::ClassVariable.new( + name: node.children.first, + type: Types::Bases::Any.new(location: nil), + location: nil, comment: comments[node.first_lineno - 1] + ) + decls.push member unless decls.include?(member) + else + process_children(node, decls: decls, comments: comments, context: context) + end + end + + def process_children(node, decls:, comments:, context:) + each_child node do |child| + process child, decls: decls, comments: comments, context: context + end + end + + def const_to_name!(node, context: nil) + case node.type + when :CONST + TypeName.new(name: node.children[0], namespace: Namespace.empty) + when :COLON2 + if node.children[0] + namespace = const_to_name!(node.children[0], context: context).to_namespace + else + namespace = Namespace.empty + end + TypeName.new(name: node.children[1], namespace: namespace) + when :COLON3 + TypeName.new(name: node.children[0], namespace: Namespace.root) + when :SELF + raise if context.nil? + context.namespace.to_type_name + else + raise + end + end + + def const_to_name(node, context:) + if node + case node.type + when :SELF + context.namespace.to_type_name + when :CONST, :COLON2, :COLON3 + const_to_name!(node) rescue nil + end + end + end + + def literal_to_symbol(node) + case node.type + when :SYM + node.children[0] + when :LIT + node.children[0] if node.children[0].is_a?(Symbol) + when :STR + node.children[0].to_sym + end + end + + def function_type_from_body(node, def_name) + table_node, args_node, *_ = node.children + + pre_num, _pre_init, opt, _first_post, post_num, _post_init, rest, kw, kwrest, _block = args_from_node(args_node) + + return_type = if def_name == :initialize + Types::Bases::Void.new(location: nil) + else + function_return_type_from_body(node) + end + + fun = Types::Function.empty(return_type) + + table_node.take(pre_num).each do |name| + fun.required_positionals << Types::Function::Param.new(name: name, type: untyped) + end + + while opt&.type == :OPT_ARG + lvasgn, opt = opt.children + name = lvasgn.children[0] + fun.optional_positionals << Types::Function::Param.new( + name: name, + type: param_type(lvasgn.children[1]) + ) + end + + if rest + rest_name = rest == :* ? nil : rest + fun = fun.update(rest_positionals: Types::Function::Param.new(name: rest_name, type: untyped)) + end + + table_node.drop(fun.required_positionals.size + fun.optional_positionals.size + (fun.rest_positionals ? 1 : 0)).take(post_num).each do |name| + fun.trailing_positionals << Types::Function::Param.new(name: name, type: untyped) + end + + while kw + lvasgn, kw = kw.children + name, value = lvasgn.children + + case value + when nil, :NODE_SPECIAL_REQUIRED_KEYWORD + fun.required_keywords[name] = Types::Function::Param.new(name: nil, type: untyped) + when ::RubyVM::AbstractSyntaxTree::Node + fun.optional_keywords[name] = Types::Function::Param.new(name: nil, type: param_type(value)) + else + raise "Unexpected keyword arg value: #{value}" + end + end + + if kwrest && kwrest.children.any? + kwrest_name = kwrest.children[0] #: Symbol? + kwrest_name = nil if kwrest_name == :** + fun = fun.update(rest_keywords: Types::Function::Param.new(name: kwrest_name, type: untyped)) + end + + fun + end + + def function_return_type_from_body(node) + body = node.children[2] + body_type(body) + end + + def body_type(node) + return Types::Bases::Nil.new(location: nil) unless node + + case node.type + when :IF, :UNLESS + if_unless_type(node) + when :BLOCK + block_type(node) + else + literal_to_type(node) + end + end + + def if_unless_type(node) + raise unless node.type == :IF || node.type == :UNLESS + + _exp_node, true_node, false_node = node.children + types_to_union_type([body_type(true_node), body_type(false_node)]) + end + + def block_type(node) + raise unless node.type == :BLOCK + + return_stmts = any_node?(node) do |n| + n.type == :RETURN + end&.map do |return_node| + returned_value = return_node.children[0] + returned_value ? literal_to_type(returned_value) : Types::Bases::Nil.new(location: nil) + end || [] + last_node = node.children.last + last_evaluated = last_node ? literal_to_type(last_node) : Types::Bases::Nil.new(location: nil) + types_to_union_type([*return_stmts, last_evaluated]) + end + + def literal_to_type(node) + case node.type + when :STR + lit = node.children[0] + if lit.ascii_only? + Types::Literal.new(literal: lit, location: nil) + else + BuiltinNames::String.instance_type + end + when :DSTR, :XSTR + BuiltinNames::String.instance_type + when :SYM + lit = node.children[0] + if lit.to_s.ascii_only? + Types::Literal.new(literal: lit, location: nil) + else + BuiltinNames::Symbol.instance_type + end + when :DSYM + BuiltinNames::Symbol.instance_type + when :DREGX, :REGX + BuiltinNames::Regexp.instance_type + when :TRUE + Types::Literal.new(literal: true, location: nil) + when :FALSE + Types::Literal.new(literal: false, location: nil) + when :NIL + Types::Bases::Nil.new(location: nil) + when :INTEGER + Types::Literal.new(literal: node.children[0], location: nil) + when :FLOAT + BuiltinNames::Float.instance_type + when :RATIONAL, :IMAGINARY + lit = node.children[0] + type_name = TypeName.new(name: lit.class.name.to_sym, namespace: Namespace.root) + Types::ClassInstance.new(name: type_name, args: [], location: nil) + when :LIT + lit = node.children[0] + case lit + when Symbol + if lit.to_s.ascii_only? + Types::Literal.new(literal: lit, location: nil) + else + BuiltinNames::Symbol.instance_type + end + when Integer + Types::Literal.new(literal: lit, location: nil) + when String + Types::Literal.new(literal: lit, location: nil) + else + type_name = TypeName.new(name: lit.class.name.to_sym, namespace: Namespace.root) + Types::ClassInstance.new(name: type_name, args: [], location: nil) + end + when :ZLIST, :ZARRAY + BuiltinNames::Array.instance_type(untyped) + when :LIST, :ARRAY + elem_types = node.children.compact.map { |e| literal_to_type(e) } + t = types_to_union_type(elem_types) + BuiltinNames::Array.instance_type(t) + when :DOT2, :DOT3 + types = node.children.map { |c| literal_to_type(c) } + type = range_element_type(types) + BuiltinNames::Range.instance_type(type) + when :HASH + list = node.children[0] + if list + children = list.children + children.pop + else + children = [] #: Array[untyped] + end + + key_types = [] #: Array[Types::t] + value_types = [] #: Array[Types::t] + children.each_slice(2) do |k, v| + if k + key_types << literal_to_type(k) + value_types << literal_to_type(v) + else + key_types << untyped + value_types << untyped + end + end + + if !key_types.empty? && key_types.all? { |t| t.is_a?(Types::Literal) } + fields = key_types.map {|t| + t.is_a?(Types::Literal) or raise + t.literal + }.zip(value_types).to_h #: Hash[Types::Literal::literal, Types::t] + Types::Record.new(fields: fields, location: nil) + else + key_type = types_to_union_type(key_types) + value_type = types_to_union_type(value_types) + BuiltinNames::Hash.instance_type(key_type, value_type) + end + when :SELF + Types::Bases::Self.new(location: nil) + when :CALL + receiver, method_name, * = node.children + case method_name + when :freeze, :tap, :itself, :dup, :clone, :taint, :untaint, :extend + literal_to_type(receiver) + else + untyped + end + else + untyped + end + end + + def param_type(node, default: Types::Bases::Any.new(location: nil)) + case node.type + when :INTEGER + BuiltinNames::Integer.instance_type + when :FLOAT + BuiltinNames::Float.instance_type + when :RATIONAL + Types::ClassInstance.new(name: TypeName.parse("::Rational"), args: [], location: nil) + when :IMAGINARY + Types::ClassInstance.new(name: TypeName.parse("::Complex"), args: [], location: nil) + when :LIT + case node.children[0] + when Symbol + BuiltinNames::Symbol.instance_type + when Integer + BuiltinNames::Integer.instance_type + when Float + BuiltinNames::Float.instance_type + else + default + end + when :SYM + BuiltinNames::Symbol.instance_type + when :STR, :DSTR + BuiltinNames::String.instance_type + when :NIL + Types::Optional.new( + type: Types::Bases::Any.new(location: nil), + location: nil + ) + when :TRUE, :FALSE + Types::Bases::Bool.new(location: nil) + when :ARRAY, :LIST + BuiltinNames::Array.instance_type(default) + when :HASH + BuiltinNames::Hash.instance_type(default, default) + else + default + end + end + end + end + end +end diff --git a/lib/rbs/prototype/rbi.rb b/lib/rbs/prototype/rbi.rb index 2359e9529e..a5dc21b1d4 100644 --- a/lib/rbs/prototype/rbi.rb +++ b/lib/rbs/prototype/rbi.rb @@ -2,624 +2,97 @@ module RBS module Prototype - class RBI - include Helpers - - attr_reader :decls - attr_reader :modules - attr_reader :last_sig - - def initialize - @decls = [] - - @modules = [] - end - - def parse(string) - comments = parse_comments(string, include_trailing: true) - process RubyVM::AbstractSyntaxTree.parse(string), comments: comments - end - - def nested_name(name) - (current_namespace + const_to_name(name).to_namespace).to_type_name.relative! - end - - def current_namespace - modules.inject(Namespace.empty) do |parent, mod| - parent + mod.name.to_namespace - end - end - - def push_class(name, super_class, comment:) - class_decl = AST::Declarations::Class.new( - name: nested_name(name), - super_class: super_class && AST::Declarations::Class::Super.new(name: const_to_name(super_class), args: [], location: nil), - type_params: [], - members: [], - annotations: [], - location: nil, - comment: comment - ) - - modules << class_decl - decls << class_decl - - yield - ensure - modules.pop - end - - def push_module(name, comment:) - module_decl = AST::Declarations::Module.new( - name: nested_name(name), - type_params: [], - members: [], - annotations: [], - location: nil, - self_types: [], - comment: comment - ) - - modules << module_decl - decls << module_decl - - yield - ensure - modules.pop - end - - def current_module - modules.last - end - - def current_module! - current_module or raise - end - - def push_sig(node) - if last_sig = @last_sig - last_sig << node + module RBI + def self.new + if ENV['RBS_RUBY_PARSER'] == 'prism' + RBI::Prism.new else - @last_sig = [node] + RBI::RubyVM.new end end - def pop_sig - @last_sig.tap do - @last_sig = nil - end - end - - def join_comments(nodes, comments) - cs = nodes.map {|node| comments[node.first_lineno - 1] }.compact - AST::Comment.new(string: cs.map(&:string).join("\n"), location: nil) - end + class Base + include CommentParser - def process(node, outer: [], comments:) - case node.type - when :CLASS - comment = comments[node.first_lineno - 1] - push_class node.children[0], node.children[1], comment: comment do - process node.children[2], outer: outer + [node], comments: comments - end - when :MODULE - comment = comments[node.first_lineno - 1] - push_module node.children[0], comment: comment do - process node.children[1], outer: outer + [node], comments: comments - end - when :FCALL - case node.children[0] - when :include - each_arg node.children[1] do |arg| - if arg.type == :CONST || arg.type == :COLON2 || arg.type == :COLON3 - name = const_to_name(arg) - include_member = AST::Members::Include.new( - name: name, - args: [], - annotations: [], - location: nil, - comment: nil - ) - current_module!.members << include_member - end - end - when :extend - each_arg node.children[1] do |arg| - if arg.type == :CONST || arg.type == :COLON2 - name = const_to_name(arg) - unless name.to_s == "T::Generic" || name.to_s == "T::Sig" - member = AST::Members::Extend.new( - name: name, - args: [], - annotations: [], - location: nil, - comment: nil - ) - current_module!.members << member - end - end - end - when :sig - out = outer.last or raise - push_sig out.children.last.children.last - when :alias_method - new, old = each_arg(node.children[1]).map {|x| x.children[0] } - current_module!.members << AST::Members::Alias.new( - new_name: new, - old_name: old, - location: nil, - annotations: [], - kind: :instance, - comment: nil - ) - end - when :DEFS - sigs = pop_sig + attr_reader :decls + attr_reader :modules + attr_reader :last_sig - if sigs - comment = join_comments(sigs, comments) - - args = node.children[2] - types = sigs.map {|sig| method_type(args, sig, variables: current_module!.type_params, overloads: sigs.size) }.compact - - current_module!.members << AST::Members::MethodDefinition.new( - name: node.children[1], - location: nil, - annotations: [], - overloads: types.map {|type| AST::Members::MethodDefinition::Overload.new(annotations: [], method_type: type) }, - kind: :singleton, - comment: comment, - overloading: false, - visibility: nil - ) - end - - when :DEFN - sigs = pop_sig - - if sigs - comment = join_comments(sigs, comments) - - args = node.children[1] - types = sigs.map {|sig| method_type(args, sig, variables: current_module!.type_params, overloads: sigs.size) }.compact - - current_module!.members << AST::Members::MethodDefinition.new( - name: node.children[0], - location: nil, - annotations: [], - overloads: types.map {|type| AST::Members::MethodDefinition::Overload.new(annotations: [], method_type: type) }, - kind: :instance, - comment: comment, - overloading: false, - visibility: nil - ) - end - - when :CDECL - if (send = node.children.last) && send.type == :FCALL && send.children[0] == :type_member - unless each_arg(send.children[1]).any? {|node| - node.type == :HASH && - each_arg(node.children[0]).each_slice(2).any? {|a, _| symbol_literal_node?(a) == :fixed } - } - # @type var variance: AST::TypeParam::variance? - if (a0 = each_arg(send.children[1]).to_a[0]) && (v = symbol_literal_node?(a0)) - variance = case v - when :out - :covariant - when :in - :contravariant - end - end - - current_module!.type_params << AST::TypeParam.new( - name: node.children[0], - variance: variance || :invariant, - location: nil, - upper_bound: nil, - lower_bound: nil, - default_type: nil - ) - end - else - name = node.children[0].yield_self do |n| - if n.is_a?(Symbol) - TypeName.new(namespace: current_namespace, name: n) - else - const_to_name(n) - end - end - value_node = node.children.last - type = if value_node && value_node.type == :CALL && value_node.children[1] == :let - type_node = each_arg(value_node.children[2]).to_a[1] - type_of type_node, variables: current_module&.type_params || [] - else - Types::Bases::Any.new(location: nil) - end - decls << AST::Declarations::Constant.new( - name: name, - type: type, - location: nil, - comment: nil, - annotations: [] - ) - end - when :ALIAS - current_module!.members << AST::Members::Alias.new( - new_name: node.children[0].children[0], - old_name: node.children[1].children[0], - location: nil, - annotations: [], - kind: :instance, - comment: nil - ) - else - each_child node do |child| - process child, outer: outer + [node], comments: comments - end + def initialize + @decls = [] + @modules = [] end - end - - def method_type(args_node, type_node, variables:, overloads:) - if type_node - if type_node.type == :CALL - method_type = method_type(args_node, type_node.children[0], variables: variables, overloads: overloads) or raise - else - method_type = MethodType.new( - type: Types::Function.empty(Types::Bases::Any.new(location: nil)), - block: nil, - location: nil, - type_params: [] - ) - end - - name, args = case type_node.type - when :CALL - [ - type_node.children[1], - type_node.children[2] - ] - when :FCALL, :VCALL - [ - type_node.children[0], - type_node.children[1] - ] - end - - case name - when :returns - return_type = each_arg(args).to_a[0] - method_type.update(type: method_type.type.with_return_type(type_of(return_type, variables: variables))) - when :params - if args_node - parse_params(args_node, args, method_type, variables: variables, overloads: overloads) - else - vars = (node_to_hash(each_arg(args).to_a[0]) || {}).transform_values {|value| type_of(value, variables: variables) } - required_positionals = vars.map do |name, type| - Types::Function::Param.new(name: name, type: type) - end - - if method_type.type.is_a?(RBS::Types::Function) - method_type.update(type: method_type.type.update(required_positionals: required_positionals)) - else - method_type - end - end - when :type_parameters - type_params = [] #: Array[AST::TypeParam] - - each_arg args do |node| - if name = symbol_literal_node?(node) - type_params << AST::TypeParam.new( - name: name, - variance: :invariant, - upper_bound: nil, - lower_bound: nil, - location: nil, - default_type: nil - ) - end - end - - method_type.update(type_params: type_params) - when :void - method_type.update(type: method_type.type.with_return_type(Types::Bases::Void.new(location: nil))) - when :proc - method_type - else - method_type - end + def nested_name(name) + (current_namespace + const_to_name(name).to_namespace).to_type_name.relative! end - end - - def parse_params(args_node, args, method_type, variables:, overloads:) - vars = (node_to_hash(each_arg(args).to_a[0]) || {}).transform_values {|value| type_of(value, variables: variables) } - # @type var required_positionals: Array[Types::Function::Param] - required_positionals = [] - # @type var optional_positionals: Array[Types::Function::Param] - optional_positionals = [] - # @type var rest_positionals: Types::Function::Param? - rest_positionals = nil - # @type var trailing_positionals: Array[Types::Function::Param] - trailing_positionals = [] - # @type var required_keywords: Hash[Symbol, Types::Function::Param] - required_keywords = {} - # @type var optional_keywords: Hash[Symbol, Types::Function::Param] - optional_keywords = {} - # @type var rest_keywords: Types::Function::Param? - rest_keywords = nil - - var_names = args_node.children[0] - pre_num, _pre_init, opt, _first_post, post_num, _post_init, rest, kw, kwrest, block = args_node.children[1].children - - pre_num.times.each do |i| - name = var_names[i] - type = vars[name] || Types::Bases::Any.new(location: nil) - required_positionals << Types::Function::Param.new(type: type, name: name) - end - - index = pre_num - while opt - name = var_names[index] - if (type = vars[name]) - optional_positionals << Types::Function::Param.new(type: type, name: name) + def current_namespace + modules.inject(Namespace.empty) do |parent, mod| + parent + mod.name.to_namespace end - index += 1 - opt = opt.children[1] end - if rest - name = var_names[index] - if (type = vars[name]) - rest_positionals = Types::Function::Param.new(type: type, name: name) - end - index += 1 - end - - post_num.times do |i| - name = var_names[i+index] - if (type = vars[name]) - trailing_positionals << Types::Function::Param.new(type: type, name: name) - end - index += 1 - end - - while kw - name, value = kw.children[0].children - if (type = vars[name]) - if value - optional_keywords[name] = Types::Function::Param.new(type: type, name: name) - else - required_keywords[name] = Types::Function::Param.new(type: type, name: name) - end - end - - kw = kw.children[1] - end + def push_class(name, super_class, comment:) + class_decl = AST::Declarations::Class.new( + name: nested_name(name), + super_class: super_class && AST::Declarations::Class::Super.new(name: const_to_name(super_class), args: [], location: nil), + type_params: [], + members: [], + annotations: [], + location: nil, + comment: comment + ) - if kwrest - name = kwrest.children[0] - if (type = vars[name]) - rest_keywords = Types::Function::Param.new(type: type, name: name) - end - end + modules << class_decl + decls << class_decl - method_block = nil - if block - if (type = vars[block]) - if type.is_a?(Types::Proc) - method_block = Types::Block.new(required: true, type: type.type, self_type: nil) - elsif type.is_a?(Types::Bases::Any) - method_block = Types::Block.new( - required: true, - type: Types::Function.empty(Types::Bases::Any.new(location: nil)), - self_type: nil - ) - # Handle an optional block like `T.nilable(T.proc.void)`. - elsif type.is_a?(Types::Optional) && (proc_type = type.type).is_a?(Types::Proc) - method_block = Types::Block.new(required: false, type: proc_type.type, self_type: nil) - else - STDERR.puts "Unexpected block type: #{type}" - PP.pp args_node, STDERR - method_block = Types::Block.new( - required: true, - type: Types::Function.empty(Types::Bases::Any.new(location: nil)), - self_type: nil - ) - end - else - if overloads == 1 - method_block = Types::Block.new( - required: false, - type: Types::Function.empty(Types::Bases::Any.new(location: nil)), - self_type: nil - ) - end - end + yield + ensure + modules.pop end - if method_type.type.is_a?(Types::Function) - method_type.update( - type: method_type.type.update( - required_positionals: required_positionals, - optional_positionals: optional_positionals, - rest_positionals: rest_positionals, - trailing_positionals: trailing_positionals, - required_keywords: required_keywords, - optional_keywords: optional_keywords, - rest_keywords: rest_keywords - ), - block: method_block + def push_module(name, comment:) + module_decl = AST::Declarations::Module.new( + name: nested_name(name), + type_params: [], + members: [], + annotations: [], + location: nil, + self_types: [], + comment: comment ) - else - method_type - end - end - def type_of(type_node, variables:) - type = type_of0(type_node, variables: variables) + modules << module_decl + decls << module_decl - case - when type.is_a?(Types::ClassInstance) && type.name.name == BuiltinNames::BasicObject.name.name - Types::Bases::Any.new(location: nil) - when type.is_a?(Types::ClassInstance) && type.name.to_s == "T::Boolean" - Types::Bases::Bool.new(location: nil) - else - type + yield + ensure + modules.pop end - end - - def type_of0(type_node, variables:) - case - when type_node.type == :CONST - if variables.include?(type_node.children[0]) - Types::Variable.new(name: type_node.children[0], location: nil) - else - Types::ClassInstance.new(name: const_to_name(type_node), args: [], location: nil) - end - when type_node.type == :COLON2 || type_node.type == :COLON3 - Types::ClassInstance.new(name: const_to_name(type_node), args: [], location: nil) - when call_node?(type_node, name: :[], receiver: -> (_) { true }) - # The type_node represents a type application - type = type_of(type_node.children[0], variables: variables) - type.is_a?(Types::ClassInstance) or raise - each_arg(type_node.children[2]) do |arg| - type.args << type_of(arg, variables: variables) - end - - type - when call_node?(type_node, name: :type_parameter) - name = each_arg(type_node.children[2]).to_a[0].children[0] - Types::Variable.new(name: name, location: nil) - when call_node?(type_node, name: :any) - types = each_arg(type_node.children[2]).to_a.map {|node| type_of(node, variables: variables) } - Types::Union.new(types: types, location: nil) - when call_node?(type_node, name: :all) - types = each_arg(type_node.children[2]).to_a.map {|node| type_of(node, variables: variables) } - Types::Intersection.new(types: types, location: nil) - when call_node?(type_node, name: :untyped) - Types::Bases::Any.new(location: nil) - when call_node?(type_node, name: :nilable) - type = type_of each_arg(type_node.children[2]).to_a[0], variables: variables - Types::Optional.new(type: type, location: nil) - when call_node?(type_node, name: :self_type) - Types::Bases::Self.new(location: nil) - when call_node?(type_node, name: :attached_class) - Types::Bases::Instance.new(location: nil) - when call_node?(type_node, name: :noreturn) - Types::Bases::Bottom.new(location: nil) - when call_node?(type_node, name: :class_of) - type = type_of each_arg(type_node.children[2]).to_a[0], variables: variables - case type - when Types::ClassInstance - Types::ClassSingleton.new(name: type.name, location: nil) - else - STDERR.puts "Unexpected type for `class_of`: #{type}" - Types::Bases::Any.new(location: nil) - end - when type_node.type == :ARRAY, type_node.type == :LIST - types = each_arg(type_node).map {|node| type_of(node, variables: variables) } - Types::Tuple.new(types: types, location: nil) - else - if proc_type?(type_node) - method_type = method_type(nil, type_node, variables: variables, overloads: 1) or raise - Types::Proc.new(type: method_type.type, block: nil, location: nil, self_type: nil) - else - STDERR.puts "Unexpected type_node:" - PP.pp type_node, STDERR - Types::Bases::Any.new(location: nil) - end + def current_module + modules.last end - end - def proc_type?(type_node) - if call_node?(type_node, name: :proc) - true - else - type_node.type == :CALL && proc_type?(type_node.children[0]) + def current_module! + current_module or raise end - end - - def call_node?(node, name:, receiver: -> (node) { node.type == :CONST && node.children[0] == :T }, args: -> (node) { true }) - node.type == :CALL && receiver[node.children[0]] && name == node.children[1] && args[node.children[2]] - end - - def const_to_name(node) - case node.type - when :CONST - TypeName.new(name: node.children[0], namespace: Namespace.empty) - when :COLON2 - if node.children[0] - namespace = const_to_name(node.children[0]).to_namespace - else - namespace = Namespace.empty - end - type_name = TypeName.new(name: node.children[1], namespace: namespace) - - case type_name.to_s - when "T::Array" - BuiltinNames::Array.name - when "T::Hash" - BuiltinNames::Hash.name - when "T::Range" - BuiltinNames::Range.name - when "T::Enumerator" - BuiltinNames::Enumerator.name - when "T::Enumerable" - BuiltinNames::Enumerable.name - when "T::Set" - BuiltinNames::Set.name + def push_sig(node) + if last_sig = @last_sig + last_sig << node else - type_name - end - when :COLON3 - TypeName.new(name: node.children[0], namespace: Namespace.root) - else - raise "Unexpected node type: #{node.type}" - end - end - - def each_arg(array, &block) - if block_given? - if array&.type == :ARRAY || array&.type == :LIST - array.children.each do |arg| - if arg - yield arg - end - end + @last_sig = [node] end - else - enum_for :each_arg, array end - end - def each_child(node) - node.children.each do |child| - if child.is_a?(RubyVM::AbstractSyntaxTree::Node) - yield child + def pop_sig + @last_sig.tap do + @last_sig = nil end end end - - def node_to_hash(node) - if node&.type == :HASH - # @type var hash: Hash[Symbol, untyped] - hash = {} - - each_arg(node.children[0]).each_slice(2) do |var, type| - var or raise - - if (name = symbol_literal_node?(var)) && type - hash[name] = type - end - end - - hash - end - end end end end diff --git a/lib/rbs/prototype/rbi/prism.rb b/lib/rbs/prototype/rbi/prism.rb new file mode 100644 index 0000000000..367459b6bb --- /dev/null +++ b/lib/rbs/prototype/rbi/prism.rb @@ -0,0 +1,558 @@ +# frozen_string_literal: true + +module RBS + module Prototype + module RBI + class Prism < Base + def parse(string) + result = ::Prism.parse(string) + comments = build_comments_prism(result.comments, include_trailing: true) + process result.value, comments: comments + end + + def process(node, outer: [], comments:) + case node.type + when :program_node + process(node.statements, outer: outer, comments: comments) + when :statements_node + node.body.each do |child| + process(child, outer: outer + [node], comments: comments) + end + when :begin_node + if node.statements + process(node.statements, outer: outer + [node], comments: comments) + end + when :class_node + comment = comments[node.location.start_line - 1] + push_class(node.constant_path, node.superclass, comment: comment) do + if node.body + process(node.body, outer: outer + [node], comments: comments) + end + end + when :module_node + comment = comments[node.location.start_line - 1] + push_module(node.constant_path, comment: comment) do + if node.body + process(node.body, outer: outer + [node], comments: comments) + end + end + when :call_node + if node.receiver.nil? && node.block.nil? + handle_fcall(node, outer: outer, comments: comments) + elsif node.receiver.nil? && node.block + if node.name == :sig + handle_sig(node) + else + if node.block.is_a?(::Prism::BlockNode) && node.block.body + process(node.block.body, outer: outer + [node], comments: comments) + end + end + else + node.each_child_node do |child| + process(child, outer: outer + [node], comments: comments) + end + end + when :def_node + sigs = pop_sig + + if sigs + comment = join_comments(sigs, comments) + kind = node.receiver ? :singleton : :instance #: AST::Members::MethodDefinition::kind + types = sigs.map { |sig| method_type(node, sig, variables: current_module!.type_params, overloads: sigs.size) }.compact + + current_module!.members << AST::Members::MethodDefinition.new( + name: node.name, + location: nil, + annotations: [], + overloads: types.map { |type| AST::Members::MethodDefinition::Overload.new(annotations: [], method_type: type) }, + kind: kind, + comment: comment, + overloading: false, + visibility: nil + ) + end + when :constant_write_node + handle_cdecl(node) + when :multi_write_node + node.lefts.each do |target| + if target.type == :constant_target_node + name = + if current_module + TypeName.new(namespace: current_namespace, name: target.name) + else + TypeName.new(namespace: Namespace.empty, name: target.name) + end + + decls << AST::Declarations::Constant.new( + name: name, type: Types::Bases::Any.new(location: nil), + location: nil, comment: nil, annotations: [] + ) + end + end + when :alias_method_node + current_module!.members << AST::Members::Alias.new( + new_name: node.new_name.unescaped.to_sym, + old_name: node.old_name.unescaped.to_sym, + location: nil, + annotations: [], + kind: :instance, + comment: nil + ) + else + node.each_child_node do |child| + process(child, outer: outer + [node], comments: comments) + end + end + end + + private + + def handle_fcall(node, outer:, comments:) + args = node.arguments&.arguments || [] + + case node.name + when :include + args.each do |arg| + case arg.type + when :constant_read_node, :constant_path_node + name = const_to_name(arg) + current_module!.members << AST::Members::Include.new( + name: name, args: [], annotations: [], + location: nil, comment: nil + ) + end + end + when :extend + args.each do |arg| + case arg.type + when :constant_read_node, :constant_path_node + name = const_to_name(arg) + unless name.to_s == "T::Generic" || name.to_s == "T::Sig" + current_module!.members << AST::Members::Extend.new( + name: name, args: [], annotations: [], + location: nil, comment: nil + ) + end + end + end + when :alias_method + if args.size >= 2 + new_name = symbol_value(args[0]) + old_name = symbol_value(args[1]) + if new_name && old_name + current_module!.members << AST::Members::Alias.new( + new_name: new_name, old_name: old_name, + location: nil, annotations: [], + kind: :instance, comment: nil + ) + end + end + end + end + + def handle_sig(node) + block = node.block + return unless block.is_a?(::Prism::BlockNode) + + body = block.body + return unless body + + sig_chain = body.is_a?(::Prism::StatementsNode) ? body.body.last : body + push_sig(sig_chain) if sig_chain + end + + def join_comments(sig_nodes, comments) + cs = sig_nodes.map { |node| comments[node.location.start_line - 1] }.compact + AST::Comment.new(string: cs.map(&:string).join("\n"), location: nil) + end + + def method_type(def_node, sig_node, variables:, overloads:) + return nil unless sig_node + + method_type = MethodType.new( + type: Types::Function.empty(Types::Bases::Any.new(location: nil)), + block: nil, + location: nil, + type_params: [] + ) + + walk_sig_chain(def_node, sig_node, method_type, variables: variables, overloads: overloads) + end + + # Walk a sig chain recursively. The chain looks like: + # returns(String) - call_node name=returns, receiver=nil + # params(x: T).returns(Y) - call_node name=returns, receiver=call_node(name=params) + def walk_sig_chain(def_node, node, method_type, variables:, overloads:) + return method_type unless node + + case node.type + when :call_node + if node.receiver&.type == :call_node + method_type = walk_sig_chain(def_node, node.receiver, method_type, variables: variables, overloads: overloads) + end + + args = node.arguments&.arguments || [] + + case node.name + when :returns + if args[0] + return_type = type_of(args[0], variables: variables) + method_type = method_type.update(type: method_type.type.with_return_type(return_type)) + end + when :params + if def_node + method_type = parse_params(def_node, args, method_type, variables: variables, overloads: overloads) + else + hash = args_to_hash(args[0], variables: variables) + required_positionals = hash.map do |name, type| + Types::Function::Param.new(name: name, type: type) + end + if method_type.type.is_a?(RBS::Types::Function) + method_type = method_type.update(type: method_type.type.update(required_positionals: required_positionals)) + end + end + when :type_parameters + type_params = args.filter_map do |arg| + if (name = symbol_value(arg)) + AST::TypeParam.new( + name: name, variance: :invariant, + upper_bound: nil, lower_bound: nil, + location: nil, default_type: nil + ) + end + end + method_type = method_type.update(type_params: type_params) + when :void + method_type = method_type.update(type: method_type.type.with_return_type(Types::Bases::Void.new(location: nil))) + when :proc + # T.proc - continue, the chain will fill in params/returns + end + end + + method_type + end + + def parse_params(def_node, sig_args, method_type, variables:, overloads:) + vars = args_to_hash(sig_args[0], variables: variables) + params = def_node.parameters + + required_positionals = [] #: Array[Types::Function::Param] + optional_positionals = [] #: Array[Types::Function::Param] + rest_positionals = nil #: Types::Function::Param | nil + trailing_positionals = [] #: Array[Types::Function::Param] + required_keywords = {} #: Hash[Symbol, Types::Function::Param] + optional_keywords = {} #: Hash[Symbol, Types::Function::Param] + rest_keywords = nil #: Types::Function::Param | nil + method_block = nil #: Types::Block | nil + + if params + params.requireds.each do |req| #: Prism::RequiredParameterNode | Prism::MultiTargetNode + name = req.is_a?(::Prism::RequiredParameterNode) ? req.name : nil + type = (name && vars[name]) || Types::Bases::Any.new(location: nil) + required_positionals << Types::Function::Param.new(type: type, name: name) + end + + params.optionals.each do |opt| + type = vars[opt.name] + if type + optional_positionals << Types::Function::Param.new(type: type, name: opt.name) + end + end + + if (rest = params.rest).is_a?(::Prism::RestParameterNode) && (rest_name = rest.name) + if (type = vars[rest_name]) + rest_positionals = Types::Function::Param.new(type: type, name: rest_name) + end + end + + params.posts.each do |post| #: Prism::RequiredParameterNode | Prism::MultiTargetNode + name = post.is_a?(::Prism::RequiredParameterNode) ? post.name : nil + if name && (type = vars[name]) + trailing_positionals << Types::Function::Param.new(type: type, name: name) + end + end + + params.keywords.each do |kw| + case kw.type + when :required_keyword_parameter_node + if (type = vars[kw.name]) + required_keywords[kw.name] = Types::Function::Param.new(type: type, name: kw.name) + end + when :optional_keyword_parameter_node + if (type = vars[kw.name]) + optional_keywords[kw.name] = Types::Function::Param.new(type: type, name: kw.name) + end + end + end + + if (keyword_rest = params.keyword_rest).is_a?(::Prism::KeywordRestParameterNode) && (kw_rest_name = keyword_rest.name) + if (type = vars[kw_rest_name]) + rest_keywords = Types::Function::Param.new(type: type, name: kw_rest_name) + end + end + + if (block_param = params.block).is_a?(::Prism::BlockParameterNode) + block_name = block_param.name + if block_name && (type = vars[block_name]) + if type.is_a?(Types::Proc) + method_block = Types::Block.new(required: true, type: type.type, self_type: nil) + elsif type.is_a?(Types::Bases::Any) + method_block = Types::Block.new( + required: true, + type: Types::Function.empty(Types::Bases::Any.new(location: nil)), + self_type: nil + ) + elsif type.is_a?(Types::Optional) && (proc_type = type.type).is_a?(Types::Proc) + method_block = Types::Block.new(required: false, type: proc_type.type, self_type: nil) + else + STDERR.puts "Unexpected block type: #{type}" + method_block = Types::Block.new( + required: true, + type: Types::Function.empty(Types::Bases::Any.new(location: nil)), + self_type: nil + ) + end + elsif overloads == 1 + method_block = Types::Block.new( + required: false, + type: Types::Function.empty(Types::Bases::Any.new(location: nil)), + self_type: nil + ) + end + end + end + + if method_type.type.is_a?(Types::Function) + method_type.update( + type: method_type.type.update( + required_positionals: required_positionals, + optional_positionals: optional_positionals, + rest_positionals: rest_positionals, + trailing_positionals: trailing_positionals, + required_keywords: required_keywords, + optional_keywords: optional_keywords, + rest_keywords: rest_keywords + ), + block: method_block + ) + else + method_type + end + end + + def args_to_hash(node, variables:) + return {} unless node #: Hash[Symbol, Types::t] + + case node.type + when :keyword_hash_node, :hash_node + hash = {} #: Hash[Symbol, Types::t] + node.elements.each do |elem| + if elem.is_a?(::Prism::AssocNode) && (name = symbol_value(elem.key)) + hash[name] = type_of(elem.value, variables: variables) + end + end + hash + else + {} + end + end + + def type_of(node, variables:) + type = type_of0(node, variables: variables) + + case + when type.is_a?(Types::ClassInstance) && type.name.name == BuiltinNames::BasicObject.name.name + Types::Bases::Any.new(location: nil) + when type.is_a?(Types::ClassInstance) && type.name.to_s == "T::Boolean" + Types::Bases::Bool.new(location: nil) + else + type + end + end + + def type_of0(node, variables:) + case node.type + when :constant_read_node + if variables.any? { |tp| tp.name == node.name } + Types::Variable.new(name: node.name, location: nil) + else + Types::ClassInstance.new(name: const_to_name(node), args: [], location: nil) + end + + when :constant_path_node + Types::ClassInstance.new(name: const_to_name(node), args: [], location: nil) + + when :call_node + if t_call?(node) + handle_t_call(node, variables: variables) + elsif node.name == :[] && node.receiver + type = type_of(node.receiver, variables: variables) + type.is_a?(Types::ClassInstance) or raise + + (node.arguments&.arguments || []).each do |arg| + type.args << type_of(arg, variables: variables) + end + + type + elsif proc_type?(node) + mt = walk_sig_chain(nil, node, MethodType.new( + type: Types::Function.empty(Types::Bases::Any.new(location: nil)), + block: nil, location: nil, type_params: [] + ), variables: variables, overloads: 1) + Types::Proc.new(type: mt.type, block: nil, location: nil, self_type: nil) + else + STDERR.puts "Unexpected type_node:" + PP.pp node, STDERR + Types::Bases::Any.new(location: nil) + end + + when :array_node + types = node.elements.map { |e| type_of(e, variables: variables) } + Types::Tuple.new(types: types, location: nil) + + else + STDERR.puts "Unexpected type_node:" + PP.pp node, STDERR + Types::Bases::Any.new(location: nil) + end + end + + def t_call?(node) + node.type == :call_node && node.receiver&.type == :constant_read_node && node.receiver.name == :T + end + + def handle_t_call(node, variables:) + args = node.arguments&.arguments || [] + + case node.name + when :any + types = args.map { |a| type_of(a, variables: variables) } + Types::Union.new(types: types, location: nil) + when :all + types = args.map { |a| type_of(a, variables: variables) } + Types::Intersection.new(types: types, location: nil) + when :untyped + Types::Bases::Any.new(location: nil) + when :nilable + type = type_of(args[0], variables: variables) + Types::Optional.new(type: type, location: nil) + when :self_type + Types::Bases::Self.new(location: nil) + when :attached_class + Types::Bases::Instance.new(location: nil) + when :noreturn + Types::Bases::Bottom.new(location: nil) + when :class_of + type = type_of(args[0], variables: variables) + if type.is_a?(Types::ClassInstance) + Types::ClassSingleton.new(name: type.name, location: nil) + else + STDERR.puts "Unexpected type for `class_of`: #{type}" + Types::Bases::Any.new(location: nil) + end + when :type_parameter + name = symbol_value(args[0]) + Types::Variable.new(name: name || raise, location: nil) + when :proc + mt = walk_sig_chain(nil, node, MethodType.new( + type: Types::Function.empty(Types::Bases::Any.new(location: nil)), + block: nil, location: nil, type_params: [] + ), variables: variables, overloads: 1) + Types::Proc.new(type: mt.type, block: nil, location: nil, self_type: nil) + else + Types::Bases::Any.new(location: nil) + end + end + + def proc_type?(node) + return true if t_call?(node) && node.name == :proc + node.type == :call_node && node.receiver && proc_type?(node.receiver) + end + + def handle_cdecl(node) + value = node.value + + if value.is_a?(::Prism::CallNode) && value.receiver.nil? && value.name == :type_member + args = value.arguments&.arguments || [] + has_fixed = + args.any? do |a| + (a.is_a?(::Prism::KeywordHashNode) || a.is_a?(::Prism::HashNode)) && + a.elements.any? { |e| e.is_a?(::Prism::AssocNode) && symbol_value(e.key) == :fixed } + end + + unless has_fixed + variance = :invariant #: AST::TypeParam::variance + if args[0] && (v = symbol_value(args[0])) + variance = + case v + when :out then :covariant #: AST::TypeParam::variance + when :in then :contravariant #: AST::TypeParam::variance + else :invariant #: AST::TypeParam::variance + end + end + + current_module!.type_params << AST::TypeParam.new( + name: node.name, + variance: variance, + location: nil, + upper_bound: nil, + lower_bound: nil, + default_type: nil + ) + end + else + const_name = TypeName.new(namespace: current_namespace, name: node.name) + + type = + if value.is_a?(::Prism::CallNode) && (recv = value.receiver).is_a?(::Prism::ConstantReadNode) && + recv.name == :T && value.name == :let + type_arg = value.arguments&.arguments&.[](1) + if type_arg + type_of(type_arg, variables: current_module&.type_params || []) + else + Types::Bases::Any.new(location: nil) + end + else + Types::Bases::Any.new(location: nil) + end + + decls << AST::Declarations::Constant.new( + name: const_name, type: type, + location: nil, comment: nil, annotations: [] + ) + end + end + + def const_to_name(node) + case node.type + when :constant_read_node + TypeName.new(name: node.name, namespace: Namespace.empty) + when :constant_path_node + if node.parent.nil? + TypeName.new(name: node.name || raise, namespace: Namespace.root) + else + namespace = const_to_name(node.parent).to_namespace + type_name = TypeName.new(name: node.name || raise, namespace: namespace) + + case type_name.to_s + when "T::Array" then BuiltinNames::Array.name + when "T::Hash" then BuiltinNames::Hash.name + when "T::Range" then BuiltinNames::Range.name + when "T::Enumerator" then BuiltinNames::Enumerator.name + when "T::Enumerable" then BuiltinNames::Enumerable.name + when "T::Set" then BuiltinNames::Set.name + else type_name + end + end + else + raise "Unexpected node type for const: #{node.type}" + end + end + + def symbol_value(node) + node.unescaped.to_sym if node.is_a?(::Prism::SymbolNode) + end + end + end + end +end diff --git a/lib/rbs/prototype/rbi/ruby_vm.rb b/lib/rbs/prototype/rbi/ruby_vm.rb new file mode 100644 index 0000000000..80ddfb2828 --- /dev/null +++ b/lib/rbs/prototype/rbi/ruby_vm.rb @@ -0,0 +1,550 @@ +# frozen_string_literal: true + +module RBS + module Prototype + module RBI + class RubyVM < Base + include RubyVMHelpers + + def parse(string) + comments = parse_comments(string, include_trailing: true) + process ::RubyVM::AbstractSyntaxTree.parse(string), comments: comments + end + + def join_comments(nodes, comments) + cs = nodes.map {|node| comments[node.first_lineno - 1] }.compact + AST::Comment.new(string: cs.map(&:string).join("\n"), location: nil) + end + + def process(node, outer: [], comments:) + case node.type + when :CLASS + comment = comments[node.first_lineno - 1] + push_class node.children[0], node.children[1], comment: comment do + process node.children[2], outer: outer + [node], comments: comments + end + when :MODULE + comment = comments[node.first_lineno - 1] + push_module node.children[0], comment: comment do + process node.children[1], outer: outer + [node], comments: comments + end + when :FCALL + case node.children[0] + when :include + each_arg node.children[1] do |arg| + if arg.type == :CONST || arg.type == :COLON2 || arg.type == :COLON3 + name = const_to_name(arg) + include_member = AST::Members::Include.new( + name: name, + args: [], + annotations: [], + location: nil, + comment: nil + ) + current_module!.members << include_member + end + end + when :extend + each_arg node.children[1] do |arg| + if arg.type == :CONST || arg.type == :COLON2 + name = const_to_name(arg) + unless name.to_s == "T::Generic" || name.to_s == "T::Sig" + member = AST::Members::Extend.new( + name: name, + args: [], + annotations: [], + location: nil, + comment: nil + ) + current_module!.members << member + end + end + end + when :sig + out = outer.last or raise + push_sig out.children.last.children.last + when :alias_method + new, old = each_arg(node.children[1]).map {|x| x.children[0] } + current_module!.members << AST::Members::Alias.new( + new_name: new, + old_name: old, + location: nil, + annotations: [], + kind: :instance, + comment: nil + ) + end + when :DEFS + sigs = pop_sig + + if sigs + comment = join_comments(sigs, comments) + + args = node.children[2] + types = sigs.map {|sig| method_type(args, sig, variables: current_module!.type_params, overloads: sigs.size) }.compact + + current_module!.members << AST::Members::MethodDefinition.new( + name: node.children[1], + location: nil, + annotations: [], + overloads: types.map {|type| AST::Members::MethodDefinition::Overload.new(annotations: [], method_type: type) }, + kind: :singleton, + comment: comment, + overloading: false, + visibility: nil + ) + end + + when :DEFN + sigs = pop_sig + + if sigs + comment = join_comments(sigs, comments) + + args = node.children[1] + types = sigs.map {|sig| method_type(args, sig, variables: current_module!.type_params, overloads: sigs.size) }.compact + + current_module!.members << AST::Members::MethodDefinition.new( + name: node.children[0], + location: nil, + annotations: [], + overloads: types.map {|type| AST::Members::MethodDefinition::Overload.new(annotations: [], method_type: type) }, + kind: :instance, + comment: comment, + overloading: false, + visibility: nil + ) + end + + when :CDECL + if (send = node.children.last) && send.type == :FCALL && send.children[0] == :type_member + unless each_arg(send.children[1]).any? {|node| + node.type == :HASH && + each_arg(node.children[0]).each_slice(2).any? {|a, _| symbol_literal_node?(a) == :fixed } + } + # @type var variance: AST::TypeParam::variance? + if (a0 = each_arg(send.children[1]).to_a[0]) && (v = symbol_literal_node?(a0)) + variance = case v + when :out + :covariant + when :in + :contravariant + end + end + + current_module!.type_params << AST::TypeParam.new( + name: node.children[0], + variance: variance || :invariant, + location: nil, + upper_bound: nil, + lower_bound: nil, + default_type: nil + ) + end + else + name = node.children[0].yield_self do |n| + if n.is_a?(Symbol) + TypeName.new(namespace: current_namespace, name: n) + else + const_to_name(n) + end + end + value_node = node.children.last + type = if value_node && value_node.type == :CALL && value_node.children[1] == :let + type_node = each_arg(value_node.children[2]).to_a[1] + type_of type_node, variables: current_module&.type_params || [] + else + Types::Bases::Any.new(location: nil) + end + decls << AST::Declarations::Constant.new( + name: name, + type: type, + location: nil, + comment: nil, + annotations: [] + ) + end + when :ALIAS + current_module!.members << AST::Members::Alias.new( + new_name: node.children[0].children[0], + old_name: node.children[1].children[0], + location: nil, + annotations: [], + kind: :instance, + comment: nil + ) + else + each_child node do |child| + process child, outer: outer + [node], comments: comments + end + end + end + + def method_type(args_node, type_node, variables:, overloads:) + if type_node + if type_node.type == :CALL + method_type = method_type(args_node, type_node.children[0], variables: variables, overloads: overloads) or raise + else + method_type = MethodType.new( + type: Types::Function.empty(Types::Bases::Any.new(location: nil)), + block: nil, + location: nil, + type_params: [] + ) + end + + name, args = case type_node.type + when :CALL + [ + type_node.children[1], + type_node.children[2] + ] + when :FCALL, :VCALL + [ + type_node.children[0], + type_node.children[1] + ] + end + + case name + when :returns + return_type = each_arg(args).to_a[0] + method_type.update(type: method_type.type.with_return_type(type_of(return_type, variables: variables))) + when :params + if args_node + parse_params(args_node, args, method_type, variables: variables, overloads: overloads) + else + vars = (node_to_hash(each_arg(args).to_a[0]) || {}).transform_values {|value| type_of(value, variables: variables) } + + required_positionals = vars.map do |name, type| + Types::Function::Param.new(name: name, type: type) + end + + if method_type.type.is_a?(RBS::Types::Function) + method_type.update(type: method_type.type.update(required_positionals: required_positionals)) + else + method_type + end + end + when :type_parameters + type_params = [] #: Array[AST::TypeParam] + + each_arg args do |node| + if name = symbol_literal_node?(node) + type_params << AST::TypeParam.new( + name: name, + variance: :invariant, + upper_bound: nil, + lower_bound: nil, + location: nil, + default_type: nil + ) + end + end + + method_type.update(type_params: type_params) + when :void + method_type.update(type: method_type.type.with_return_type(Types::Bases::Void.new(location: nil))) + when :proc + method_type + else + method_type + end + end + end + + def parse_params(args_node, args, method_type, variables:, overloads:) + vars = (node_to_hash(each_arg(args).to_a[0]) || {}).transform_values {|value| type_of(value, variables: variables) } + + required_positionals = [] #: Array[Types::Function::Param] + optional_positionals = [] #: Array[Types::Function::Param] + rest_positionals = nil #: Types::Function::Param | nil + trailing_positionals = [] #: Array[Types::Function::Param] + required_keywords = {} #: Hash[Symbol, Types::Function::Param] + optional_keywords = {} #: Hash[Symbol, Types::Function::Param] + rest_keywords = nil #: Types::Function::Param | nil + + var_names = args_node.children[0] + pre_num, _pre_init, opt, _first_post, post_num, _post_init, rest, kw, kwrest, block = args_node.children[1].children + + pre_num.times.each do |i| + name = var_names[i] + type = vars[name] || Types::Bases::Any.new(location: nil) + required_positionals << Types::Function::Param.new(type: type, name: name) + end + + index = pre_num + while opt + name = var_names[index] + if (type = vars[name]) + optional_positionals << Types::Function::Param.new(type: type, name: name) + end + index += 1 + opt = opt.children[1] + end + + if rest + name = var_names[index] + if (type = vars[name]) + rest_positionals = Types::Function::Param.new(type: type, name: name) + end + index += 1 + end + + post_num.times do |i| + name = var_names[i+index] + if (type = vars[name]) + trailing_positionals << Types::Function::Param.new(type: type, name: name) + end + index += 1 + end + + while kw + name, value = kw.children[0].children + if (type = vars[name]) + if value + optional_keywords[name] = Types::Function::Param.new(type: type, name: name) + else + required_keywords[name] = Types::Function::Param.new(type: type, name: name) + end + end + + kw = kw.children[1] + end + + if kwrest + name = kwrest.children[0] + if (type = vars[name]) + rest_keywords = Types::Function::Param.new(type: type, name: name) + end + end + + method_block = nil + if block + if (type = vars[block]) + if type.is_a?(Types::Proc) + method_block = Types::Block.new(required: true, type: type.type, self_type: nil) + elsif type.is_a?(Types::Bases::Any) + method_block = Types::Block.new( + required: true, + type: Types::Function.empty(Types::Bases::Any.new(location: nil)), + self_type: nil + ) + elsif type.is_a?(Types::Optional) && (proc_type = type.type).is_a?(Types::Proc) + method_block = Types::Block.new(required: false, type: proc_type.type, self_type: nil) + else + STDERR.puts "Unexpected block type: #{type}" + PP.pp args_node, STDERR + method_block = Types::Block.new( + required: true, + type: Types::Function.empty(Types::Bases::Any.new(location: nil)), + self_type: nil + ) + end + else + if overloads == 1 + method_block = Types::Block.new( + required: false, + type: Types::Function.empty(Types::Bases::Any.new(location: nil)), + self_type: nil + ) + end + end + end + + if method_type.type.is_a?(Types::Function) + method_type.update( + type: method_type.type.update( + required_positionals: required_positionals, + optional_positionals: optional_positionals, + rest_positionals: rest_positionals, + trailing_positionals: trailing_positionals, + required_keywords: required_keywords, + optional_keywords: optional_keywords, + rest_keywords: rest_keywords + ), + block: method_block + ) + else + method_type + end + end + + def type_of(type_node, variables:) + type = type_of0(type_node, variables: variables) + + case + when type.is_a?(Types::ClassInstance) && type.name.name == BuiltinNames::BasicObject.name.name + Types::Bases::Any.new(location: nil) + when type.is_a?(Types::ClassInstance) && type.name.to_s == "T::Boolean" + Types::Bases::Bool.new(location: nil) + else + type + end + end + + def type_of0(type_node, variables:) + case + when type_node.type == :CONST + if variables.include?(type_node.children[0]) + Types::Variable.new(name: type_node.children[0], location: nil) + else + Types::ClassInstance.new(name: const_to_name(type_node), args: [], location: nil) + end + when type_node.type == :COLON2 || type_node.type == :COLON3 + Types::ClassInstance.new(name: const_to_name(type_node), args: [], location: nil) + when call_node?(type_node, name: :[], receiver: -> (_) { true }) + type = type_of(type_node.children[0], variables: variables) + type.is_a?(Types::ClassInstance) or raise + + each_arg(type_node.children[2]) do |arg| + type.args << type_of(arg, variables: variables) + end + + type + when call_node?(type_node, name: :type_parameter) + name = each_arg(type_node.children[2]).to_a[0].children[0] + Types::Variable.new(name: name, location: nil) + when call_node?(type_node, name: :any) + types = each_arg(type_node.children[2]).to_a.map {|node| type_of(node, variables: variables) } + Types::Union.new(types: types, location: nil) + when call_node?(type_node, name: :all) + types = each_arg(type_node.children[2]).to_a.map {|node| type_of(node, variables: variables) } + Types::Intersection.new(types: types, location: nil) + when call_node?(type_node, name: :untyped) + Types::Bases::Any.new(location: nil) + when call_node?(type_node, name: :nilable) + type = type_of each_arg(type_node.children[2]).to_a[0], variables: variables + Types::Optional.new(type: type, location: nil) + when call_node?(type_node, name: :self_type) + Types::Bases::Self.new(location: nil) + when call_node?(type_node, name: :attached_class) + Types::Bases::Instance.new(location: nil) + when call_node?(type_node, name: :noreturn) + Types::Bases::Bottom.new(location: nil) + when call_node?(type_node, name: :class_of) + type = type_of each_arg(type_node.children[2]).to_a[0], variables: variables + case type + when Types::ClassInstance + Types::ClassSingleton.new(name: type.name, location: nil) + else + STDERR.puts "Unexpected type for `class_of`: #{type}" + Types::Bases::Any.new(location: nil) + end + when type_node.type == :ARRAY, type_node.type == :LIST + types = each_arg(type_node).map {|node| type_of(node, variables: variables) } + Types::Tuple.new(types: types, location: nil) + else + if proc_type?(type_node) + method_type = method_type(nil, type_node, variables: variables, overloads: 1) or raise + Types::Proc.new(type: method_type.type, block: nil, location: nil, self_type: nil) + else + STDERR.puts "Unexpected type_node:" + PP.pp type_node, STDERR + Types::Bases::Any.new(location: nil) + end + end + end + + def proc_type?(type_node) + if call_node?(type_node, name: :proc) + true + else + type_node.type == :CALL && proc_type?(type_node.children[0]) + end + end + + def call_node?(node, name:, receiver: -> (node) { node.type == :CONST && node.children[0] == :T }, args: -> (node) { true }) + node.type == :CALL && receiver[node.children[0]] && name == node.children[1] && args[node.children[2]] + end + + def const_to_name(node) + case node.type + when :CONST + TypeName.new(name: node.children[0], namespace: Namespace.empty) + when :COLON2 + if node.children[0] + namespace = const_to_name(node.children[0]).to_namespace + else + namespace = Namespace.empty + end + + type_name = TypeName.new(name: node.children[1], namespace: namespace) + + case type_name.to_s + when "T::Array" then BuiltinNames::Array.name + when "T::Hash" then BuiltinNames::Hash.name + when "T::Range" then BuiltinNames::Range.name + when "T::Enumerator" then BuiltinNames::Enumerator.name + when "T::Enumerable" then BuiltinNames::Enumerable.name + when "T::Set" then BuiltinNames::Set.name + else type_name + end + when :COLON3 + TypeName.new(name: node.children[0], namespace: Namespace.root) + when :constant_read_node + TypeName.new(name: node.name, namespace: Namespace.empty) + when :constant_path_node + if node.parent.nil? + TypeName.new(name: node.name || raise, namespace: Namespace.root) + else + namespace = const_to_name(node.parent).to_namespace + type_name = TypeName.new(name: node.name || raise, namespace: namespace) + + case type_name.to_s + when "T::Array" then BuiltinNames::Array.name + when "T::Hash" then BuiltinNames::Hash.name + when "T::Range" then BuiltinNames::Range.name + when "T::Enumerator" then BuiltinNames::Enumerator.name + when "T::Enumerable" then BuiltinNames::Enumerable.name + when "T::Set" then BuiltinNames::Set.name + else type_name + end + end + else + raise "Unexpected node type: #{node.type}" + end + end + + def each_arg(array, &block) + if block_given? + if array && (array.type == :ARRAY || array.type == :LIST) + array.children.each do |arg| + if arg + yield arg + end + end + end + else + enum_for :each_arg, array + end + end + + def each_child(node) + node.children.each do |child| + if child.is_a?(::RubyVM::AbstractSyntaxTree::Node) + yield child + end + end + end + + def node_to_hash(node) + if node && node.type == :HASH + # @type var hash: Hash[Symbol, untyped] + hash = {} + + each_arg(node.children[0]).each_slice(2) do |var, type| + var or raise + + if (name = symbol_literal_node?(var)) && type + hash[name] = type + end + end + + hash + end + end + end + end + end +end diff --git a/lib/rbs/prototype/helpers.rb b/lib/rbs/prototype/ruby_vm_helpers.rb similarity index 61% rename from lib/rbs/prototype/helpers.rb rename to lib/rbs/prototype/ruby_vm_helpers.rb index df28a8eed9..1c62a68395 100644 --- a/lib/rbs/prototype/helpers.rb +++ b/lib/rbs/prototype/ruby_vm_helpers.rb @@ -2,66 +2,11 @@ module RBS module Prototype - module Helpers + # RubyVM::AbstractSyntaxTree helper methods shared between RB::RubyVM, + # RBI::RubyVM, NodeUsage, and Runtime. + module RubyVMHelpers private - # Prism can't parse Ruby 3.2 code - if RUBY_VERSION >= "3.3" - def parse_comments(string, include_trailing:) - Prism.parse_comments(string, version: "current").yield_self do |prism_comments| # steep:ignore UnexpectedKeywordArgument - prism_comments.each_with_object({}) do |comment, hash| #$ Hash[Integer, AST::Comment] - # Skip EmbDoc comments - next unless comment.is_a?(Prism::InlineComment) - # skip like `module Foo # :nodoc:` - next if comment.trailing? && !include_trailing - - line = comment.location.start_line - body = "#{comment.location.slice}\n" - body = body[2..-1] or raise - body = "\n" if body.empty? - - comment = AST::Comment.new(string: body, location: nil) - if prev_comment = hash.delete(line - 1) - hash[line] = AST::Comment.new(string: prev_comment.string + comment.string, - location: nil) - else - hash[line] = comment - end - end - end - end - else - require "ripper" - def parse_comments(string, include_trailing:) - Ripper.lex(string).yield_self do |tokens| - code_lines = {} #: Hash[Integer, bool] - tokens.each.with_object({}) do |token, hash| #$ Hash[Integer, AST::Comment] - case token[1] - when :on_sp, :on_ignored_nl - # skip - when :on_comment - line = token[0][0] - # skip like `module Foo # :nodoc:` - next if code_lines[line] && !include_trailing - body = token[2][2..-1] or raise - - body = "\n" if body.empty? - - comment = AST::Comment.new(string: body, location: nil) - if prev_comment = hash.delete(line - 1) - hash[line] = AST::Comment.new(string: prev_comment.string + comment.string, - location: nil) - else - hash[line] = comment - end - else - code_lines[token[0][0]] = true - end - end - end - end - end - def block_from_body(node) _, args_node, body_node = node.children _pre_num, _pre_init, _opt, _first_post, _post_num, _post_init, _rest, _kw, _kwrest, block_var = args_from_node(args_node) @@ -133,7 +78,6 @@ def block_from_body(node) function = Types::UntypedFunction.new(return_type: untyped) end - Types::Block.new(required: required, type: function, self_type: nil) end end @@ -188,10 +132,6 @@ def symbol_literal_node?(node) node.children[0] end end - - def untyped - @untyped ||= Types::Bases::Any.new(location: nil) - end end end end diff --git a/lib/rbs/prototype/runtime.rb b/lib/rbs/prototype/runtime.rb index 45114ec2ef..bb7ca49a6a 100644 --- a/lib/rbs/prototype/runtime.rb +++ b/lib/rbs/prototype/runtime.rb @@ -66,7 +66,6 @@ def mixin_decls(type_name) end private_constant :Todo - include Prototype::Helpers include Runtime::Helpers attr_reader :patterns @@ -659,18 +658,40 @@ def type_params(mod) end end - def block_from_ast_of(method) - begin - ast = RubyVM::AbstractSyntaxTree.of(method) - rescue ArgumentError - return # When the method is defined in eval - rescue RuntimeError => error - raise unless error.message.include?("prism") - return # When the method was compiled by prism + if ENV['RBS_RUBY_PARSER'] == 'prism' + def block_from_ast_of(method) + iseq = RubyVM::InstructionSequence.of(method) + return unless iseq + return unless (path = iseq.absolute_path) + + node_id = iseq.to_a[4][:node_id] + result = Prism.parse_file(path) + def_node = result.value.breadth_first_search { |n| n.node_id == node_id && n.is_a?(::Prism::DefNode) } #: Prism::DefNode? + if def_node + prism = RB::Prism.new + prism.block_from_def(def_node, prism.build_body_info(def_node.body)) + end + end + else + include Prototype::RubyVMHelpers + + def untyped + @untyped ||= Types::Bases::Any.new(location: nil) end - if ast && ast.type == :SCOPE - block_from_body(ast) + def block_from_ast_of(method) + begin + ast = RubyVM::AbstractSyntaxTree.of(method) + rescue ArgumentError + return # When the method is defined in eval + rescue RuntimeError => error + raise unless error.message.include?("prism") + return # When the method was compiled by prism + end + + if ast && ast.type == :SCOPE + block_from_body(ast) + end end end end diff --git a/lib/rbs/prototype/runtime/value_object_generator.rb b/lib/rbs/prototype/runtime/value_object_generator.rb index ab984e1311..5e2de811a2 100644 --- a/lib/rbs/prototype/runtime/value_object_generator.rb +++ b/lib/rbs/prototype/runtime/value_object_generator.rb @@ -6,8 +6,6 @@ module RBS module Prototype class Runtime class ValueObjectBase - include Helpers - def initialize(target_class) @target_class = target_class end @@ -30,6 +28,10 @@ def build_decl private + def untyped + @untyped ||= Types::Bases::Any.new(location: nil) + end + # def self.members: () -> [ :foo, :bar ] # def members: () -> [ :foo, :bar ] def build_s_members diff --git a/sig/prototype/helpers.rbs b/sig/prototype/helpers.rbs index d8bbf54cc7..f306204e5c 100644 --- a/sig/prototype/helpers.rbs +++ b/sig/prototype/helpers.rbs @@ -1,9 +1,13 @@ module RBS module Prototype - module Helpers - type node = RubyVM::AbstractSyntaxTree::Node + module CommentParser + def build_comments_prism: (Array[Prism::Comment] comments, include_trailing: bool) -> Hash[Integer, AST::Comment] def parse_comments: (String, include_trailing: bool) -> Hash[Integer, AST::Comment] + end + + module RubyVMHelpers + type node = RubyVM::AbstractSyntaxTree::Node def block_from_body: (node) -> Types::Block? @@ -15,8 +19,6 @@ module RBS def keyword_hash?: (node) -> bool - # Returns a symbol if the node is a symbol literal node - # def symbol_literal_node?: (node) -> Symbol? def args_from_node: (node?) -> Array[untyped] diff --git a/sig/prototype/node_usage.rbs b/sig/prototype/node_usage.rbs index 6a72f8346b..caa13b343d 100644 --- a/sig/prototype/node_usage.rbs +++ b/sig/prototype/node_usage.rbs @@ -1,8 +1,8 @@ module RBS module Prototype class NodeUsage - include Helpers - + include RubyVMHelpers + type node = RubyVM::AbstractSyntaxTree::Node attr_reader node: node diff --git a/sig/prototype/rb.rbs b/sig/prototype/rb.rbs index e575fc4ee1..263329720c 100644 --- a/sig/prototype/rb.rbs +++ b/sig/prototype/rb.rbs @@ -1,7 +1,7 @@ module RBS module Prototype - class RB - include Helpers + module RB + def self.new: () -> Base class Context type method_kind = :singleton | :singleton_instance | :instance @@ -27,70 +27,132 @@ module RBS def self.initial: (?namespace: Namespace) -> Context end - type decl = AST::Declarations::t | AST::Members::t + class Base + include CommentParser - attr_reader source_decls: Array[decl] + type decl = AST::Declarations::t | AST::Members::t - def initialize: () -> void + attr_reader source_decls: Array[decl] - def decls: () -> Array[AST::Declarations::t] + def initialize: () -> void - def parse: (String) -> void + def decls: () -> Array[AST::Declarations::t] - def process: (untyped node, decls: Array[AST::Declarations::t | AST::Members::t], comments: Hash[Integer, AST::Comment], context: Context) -> void + def parse: (String) -> void - def process_children: (RubyVM::AbstractSyntaxTree::Node node, decls: Array[decl], comments: Hash[Integer, AST::Comment], context: Context) -> void + def types_to_union_type: (Array[Types::t] types) -> Types::t - # Returns a type name that represents the name of the constant. - # `node` must be _constant_ node, `CONST`, `COLON2`, or `COLON3` node. - # - def const_to_name!: (RubyVM::AbstractSyntaxTree::Node node, ?context: Context?) -> TypeName + def range_element_type: (Array[Types::t] types) -> Types::t - # Returns a type name that represents the name of the constant. - # `node` can be `SELF` for `extend self` pattern. - # - def const_to_name: (RubyVM::AbstractSyntaxTree::Node? node, context: Context) -> TypeName? + def untyped: () -> Types::Bases::Any - def literal_to_symbol: (RubyVM::AbstractSyntaxTree::Node node) -> Symbol? + @untyped: Types::Bases::Any - def function_type_from_body: (RubyVM::AbstractSyntaxTree::Node node, Symbol def_name) -> Types::Function + def private: () -> AST::Members::Private - def function_return_type_from_body: (RubyVM::AbstractSyntaxTree::Node node) -> Types::t + @private: AST::Members::Private? - def body_type: (RubyVM::AbstractSyntaxTree::Node node) -> Types::t + def public: () -> AST::Members::Public - def if_unless_type: (RubyVM::AbstractSyntaxTree::Node node) -> Types::t + @public: AST::Members::Public? - def block_type: (RubyVM::AbstractSyntaxTree::Node node) -> Types::t + def current_accessibility: (Array[decl] decls, ?Integer index) -> (AST::Members::Private | AST::Members::Public) - def literal_to_type: (RubyVM::AbstractSyntaxTree::Node node) -> Types::t + def remove_unnecessary_accessibility_methods!: (Array[decl]) -> void - def types_to_union_type: (Array[Types::t] types) -> Types::t + def is_accessibility?: (decl) -> bool - def range_element_type: (Array[Types::t] types) -> Types::t + def find_def_index_by_name: (Array[decl] decls, Symbol name) -> [Integer, AST::Members::MethodDefinition | AST::Members::AttrReader | AST::Members::AttrWriter]? - def param_type: (RubyVM::AbstractSyntaxTree::Node node, ?default: Types::Bases::Any) -> Types::t + def sort_members!: (Array[decl] decls) -> void + end + + class RubyVM < Base + include RubyVMHelpers + + type node = RubyVM::AbstractSyntaxTree::Node + + private + + def process: (node, decls: Array[Base::decl], comments: Hash[Integer, AST::Comment], context: Context) -> void + + def process_children: (node, decls: Array[Base::decl], comments: Hash[Integer, AST::Comment], context: Context) -> void + + def const_to_name!: (node, ?context: Context?) -> TypeName + + def const_to_name: (node?, context: Context) -> TypeName? - # backward compatible - alias node_type param_type + def literal_to_symbol: (node) -> Symbol? + + def function_type_from_body: (node, Symbol def_name) -> Types::Function + + def function_return_type_from_body: (node) -> Types::t + + def body_type: (node?) -> Types::t + + def if_unless_type: (node) -> Types::t + + def block_type: (node) -> Types::t + + def literal_to_type: (node) -> Types::t + + def param_type: (node, ?default: Types::Bases::Any) -> Types::t + end - def private: () -> AST::Members::Private + class Prism < Base + class NodeUsage + attr_reader conditional_nodes: Set[untyped] - @private: AST::Members::Private? + def initialize: (Prism::Node) -> void - def public: () -> AST::Members::Public + def each_conditional_node: () { (untyped) -> void } -> void + | () -> Enumerator[untyped, void] - @public: AST::Members::Public? + private - def current_accessibility: (Array[decl] decls, ?Integer index) -> (AST::Members::Private | AST::Members::Public) + def calculate: (untyped, conditional: bool) -> void - def remove_unnecessary_accessibility_methods!: (Array[decl]) -> void + def calculate_statements: (untyped, conditional: bool) -> void + end - def is_accessibility?: (decl) -> bool + class BodyInfo + attr_reader yields: Array[Prism::YieldNode] + attr_reader has_block_given: bool + attr_reader returns: Array[Prism::ReturnNode]? - def find_def_index_by_name: (Array[decl] decls, Symbol name) -> [Integer, AST::Members::MethodDefinition | AST::Members::AttrReader | AST::Members::AttrWriter]? + def initialize: (yields: Array[Prism::YieldNode], has_block_given: bool, returns: Array[Prism::ReturnNode]?) -> void + end - def sort_members!: (Array[decl] decls) -> void + def build_body_info: (Prism::Node?) -> BodyInfo + + def block_from_def: (Prism::DefNode, BodyInfo) -> Types::Block? + + private + + def process: (untyped, decls: Array[Base::decl], comments: Hash[Integer, AST::Comment], context: Context) -> void + + def const_to_name!: (untyped, ?context: Context?) -> TypeName + + def const_to_name: (untyped, context: Context) -> TypeName? + + def symbol_value: (untyped) -> Symbol? + + def function_type: (Prism::DefNode, BodyInfo) -> Types::Function + + def keyword_hash?: (untyped) -> bool + + def return_type_from_body: (untyped, returns: Array[Prism::ReturnNode]?) -> Types::t + + def if_unless_type: (untyped) -> Types::t + + def statements_type: (untyped, returns: Array[Prism::ReturnNode]?) -> Types::t + + def literal_to_type: (untyped) -> Types::t + + def hash_type: (Array[untyped]) -> Types::t + + def param_type: (untyped, ?default: Types::Bases::Any) -> Types::t + end end end end diff --git a/sig/prototype/rbi.rbs b/sig/prototype/rbi.rbs index d8b37e9e83..02b7cad4a1 100644 --- a/sig/prototype/rbi.rbs +++ b/sig/prototype/rbi.rbs @@ -1,75 +1,108 @@ module RBS module Prototype - class RBI - include Helpers - - attr_reader decls: Array[AST::Declarations::t] + module RBI + def self.new: () -> Base - type module_decl = AST::Declarations::Class | AST::Declarations::Module + class Base + include CommentParser - # A stack representing the module nesting structure in the Ruby code - attr_reader modules: Array[module_decl] + type module_decl = AST::Declarations::Class | AST::Declarations::Module - # Last subsequent `sig` calls - attr_reader last_sig: Array[RubyVM::AbstractSyntaxTree::Node]? + attr_reader decls: Array[AST::Declarations::t] - def initialize: () -> void + attr_reader modules: Array[module_decl] - def parse: (String) -> void + attr_reader last_sig: Array[untyped]? - def nested_name: (RubyVM::AbstractSyntaxTree::Node name) -> TypeName + def initialize: () -> void - def current_namespace: () -> Namespace + def parse: (String) -> void - def push_class: ( - RubyVM::AbstractSyntaxTree::Node name, - RubyVM::AbstractSyntaxTree::Node super_class, - comment: AST::Comment? - ) { () -> void } -> void + def nested_name: (untyped name) -> TypeName - def push_module: (RubyVM::AbstractSyntaxTree::Node name, comment: AST::Comment?) { () -> void } -> void + def current_namespace: () -> Namespace - # The inner most module/class definition, returns `nil` on toplevel - def current_module: () -> module_decl? + def push_class: (untyped name, untyped super_class, comment: AST::Comment?) { () -> void } -> void - # The inner most module/class definition, raises on toplevel - def current_module!: () -> module_decl + def push_module: (untyped name, comment: AST::Comment?) { () -> void } -> void - # Put a `sig` call to current list. - def push_sig: (RubyVM::AbstractSyntaxTree::Node node) -> void + def current_module: () -> module_decl? - # Clear the `sig` call list - def pop_sig: () -> Array[RubyVM::AbstractSyntaxTree::Node]? + def current_module!: () -> module_decl - def join_comments: (Array[RubyVM::AbstractSyntaxTree::Node] nodes, Hash[Integer, AST::Comment] comments) -> AST::Comment + def push_sig: (untyped node) -> void - def process: (RubyVM::AbstractSyntaxTree::Node node, comments: Hash[Integer, AST::Comment], ?outer: Array[RubyVM::AbstractSyntaxTree::Node]) -> void + def pop_sig: () -> Array[untyped]? - def method_type: (RubyVM::AbstractSyntaxTree::Node? args_node, RubyVM::AbstractSyntaxTree::Node? type_node, variables: Array[AST::TypeParam], overloads: Integer) -> MethodType? + def const_to_name: (untyped node) -> TypeName + end - def parse_params: (RubyVM::AbstractSyntaxTree::Node args_node, RubyVM::AbstractSyntaxTree::Node args, MethodType method_type, variables: Array[AST::TypeParam], overloads: Integer) -> MethodType + class RubyVM < Base + include RubyVMHelpers - def type_of: (RubyVM::AbstractSyntaxTree::Node type_node, variables: Array[AST::TypeParam]) -> Types::t + type node = RubyVM::AbstractSyntaxTree::Node - def type_of0: (RubyVM::AbstractSyntaxTree::Node type_node, variables: Array[AST::TypeParam]) -> Types::t + def join_comments: (Array[node], Hash[Integer, AST::Comment]) -> AST::Comment - def proc_type?: (RubyVM::AbstractSyntaxTree::Node type_node) -> bool + def process: (node, comments: Hash[Integer, AST::Comment], ?outer: Array[node]) -> void - def call_node?: (RubyVM::AbstractSyntaxTree::Node node, name: Symbol, ?receiver: ^(RubyVM::AbstractSyntaxTree::Node) -> bool, ?args: ^(RubyVM::AbstractSyntaxTree::Node) -> bool) -> bool + def method_type: (node?, node?, variables: Array[AST::TypeParam], overloads: Integer) -> MethodType? - # Receives a constant node and returns `TypeName` instance - def const_to_name: (RubyVM::AbstractSyntaxTree::Node node) -> TypeName + def parse_params: (node, node, MethodType, variables: Array[AST::TypeParam], overloads: Integer) -> MethodType - # Receives `:ARRAY` or `:LIST` node and yields the child nodes. - def each_arg: (RubyVM::AbstractSyntaxTree::Node array) { (RubyVM::AbstractSyntaxTree::Node) -> void } -> void - | (RubyVM::AbstractSyntaxTree::Node array) -> Enumerator[RubyVM::AbstractSyntaxTree::Node, void] + def type_of: (node, variables: Array[AST::TypeParam]) -> Types::t - # Receives node and yields the child nodes. - def each_child: (RubyVM::AbstractSyntaxTree::Node node) { (RubyVM::AbstractSyntaxTree::Node) -> void } -> void - | (RubyVM::AbstractSyntaxTree::Node node) -> Enumerator[RubyVM::AbstractSyntaxTree::Node, void] + def type_of0: (node, variables: Array[AST::TypeParam]) -> Types::t - # Receives a keyword `:HASH` node and returns hash instance. - def node_to_hash: (RubyVM::AbstractSyntaxTree::Node node) -> Hash[Symbol, RubyVM::AbstractSyntaxTree::Node]? + def proc_type?: (node) -> bool + + def call_node?: (node, name: Symbol, ?receiver: ^(node) -> bool, ?args: ^(node) -> bool) -> bool + + def each_arg: (node?) { (node) -> void } -> void + | (node?) -> Enumerator[node, void] + + def each_child: (node) { (node) -> void } -> void + + def node_to_hash: (node?) -> Hash[Symbol, node]? + + def untyped: () -> Types::Bases::Any + + @untyped: Types::Bases::Any + end + + class Prism < Base + def process: (untyped, ?outer: Array[untyped], comments: Hash[Integer, AST::Comment]) -> void + + private + + def handle_fcall: (untyped, outer: Array[untyped], comments: Hash[Integer, AST::Comment]) -> void + + def handle_sig: (untyped) -> void + + def join_comments: (Array[untyped], Hash[Integer, AST::Comment]) -> AST::Comment + + def method_type: (Prism::DefNode?, untyped, variables: Array[AST::TypeParam], overloads: Integer) -> MethodType? + + def walk_sig_chain: (Prism::DefNode?, untyped, MethodType, variables: Array[AST::TypeParam], overloads: Integer) -> MethodType + + def parse_params: (Prism::DefNode, Array[untyped], MethodType, variables: Array[AST::TypeParam], overloads: Integer) -> MethodType + + def args_to_hash: (untyped, variables: Array[AST::TypeParam]) -> Hash[Symbol, Types::t] + + def type_of: (untyped, variables: Array[AST::TypeParam]) -> Types::t + + def type_of0: (untyped, variables: Array[AST::TypeParam]) -> Types::t + + def t_call?: (untyped) -> bool + + def handle_t_call: (untyped, variables: Array[AST::TypeParam]) -> Types::t + + def proc_type?: (untyped) -> bool + + def handle_cdecl: (Prism::ConstantWriteNode) -> void + + def symbol_value: (untyped) -> Symbol? + end end end end diff --git a/sig/prototype/runtime.rbs b/sig/prototype/runtime.rbs index 0221518be7..9ecb057073 100644 --- a/sig/prototype/runtime.rbs +++ b/sig/prototype/runtime.rbs @@ -28,7 +28,7 @@ module RBS end class ValueObjectBase - include Helpers + include Runtime::Helpers # @target_class stores the singleton object of `Data` or `Struct` subclass @target_class: untyped @@ -37,6 +37,10 @@ module RBS private + def untyped: () -> Types::Bases::Any + + @untyped: Types::Bases::Any + def build_member_accessors: (untyped ast_members_class) -> untyped def build_s_members: () -> Array[AST::Members::MethodDefinition] @@ -111,7 +115,8 @@ module RBS @todo_object: Todo? - include Helpers + include RubyVMHelpers + include Runtime::Helpers attr_reader patterns: Array[String] @@ -172,8 +177,6 @@ module RBS def block_from_ast_of: (UnboundMethod method) -> Types::Block? - def block_from_body: (RubyVM::AbstractSyntaxTree::Node) -> Types::Block? - def can_alias?: (Module, UnboundMethod) -> bool def type_params: (Module) -> Array[AST::TypeParam] diff --git a/steep/patch.rbs b/steep/patch.rbs index 4bbacff423..a875075de1 100644 --- a/steep/patch.rbs +++ b/steep/patch.rbs @@ -7,3 +7,9 @@ end class Proc def ruby2_keywords: () -> self end + +class RubyVM + class InstructionSequence + def self.of: (Method | UnboundMethod | Proc) -> RubyVM::InstructionSequence? + end +end diff --git a/test/rbs/rb_prototype_test.rb b/test/rbs/rb_prototype_test.rb index 2722347ece..88fd6f58d0 100644 --- a/test/rbs/rb_prototype_test.rb +++ b/test/rbs/rb_prototype_test.rb @@ -1021,7 +1021,7 @@ def message: (untyped message) -> untyped end def test_literal_to_type - parser = RBS::Prototype::RB.new + parser = RBS::Prototype::RB::RubyVM.new [ [%{"abc"}, %{"abc"}], [%{:abc}, %{:abc}], @@ -1038,7 +1038,7 @@ def test_literal_to_type end def test_const_to_name - parser = RBS::Prototype::RB.new + parser = RBS::Prototype::RB::RubyVM.new [ ["self", RBS::TypeName.parse("::Foo")], ["Bar", RBS::TypeName.parse("Bar")],