summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--parse.y12
-rw-r--r--prism/prism.c14
-rw-r--r--test/prism/errors/dont_allow_return_inside_sclass_body.txt3
-rw-r--r--test/ruby/test_class.rb1
4 files changed, 25 insertions, 5 deletions
diff --git a/parse.y b/parse.y
index 37f30a4881..cd76194df8 100644
--- a/parse.y
+++ b/parse.y
@@ -318,6 +318,7 @@ struct lex_context {
unsigned int in_class: 1;
BITFIELD(enum rb_parser_shareability, shareable_constant_value, 2);
BITFIELD(enum rescue_context, in_rescue, 2);
+ unsigned int cant_return: 1;
};
typedef struct RNode_DEF_TEMP rb_node_def_temp_t;
@@ -1696,12 +1697,17 @@ endless_method_name(struct parser_params *p, ID mid, const YYLTYPE *loc)
#define begin_definition(k, loc_beg, loc_end) \
do { \
if (!(p->ctxt.in_class = (k)[0] != 0)) { \
+ /* singleton class */ \
+ p->ctxt.cant_return = !p->ctxt.in_def; \
p->ctxt.in_def = 0; \
} \
else if (p->ctxt.in_def) { \
YYLTYPE loc = code_loc_gen(loc_beg, loc_end); \
yyerror1(&loc, k " definition in method body"); \
} \
+ else { \
+ p->ctxt.cant_return = 1; \
+ } \
local_push(p, 0); \
} while (0)
@@ -3400,6 +3406,7 @@ def_name : fname
local_push(p, 0);
p->ctxt.in_def = 1;
p->ctxt.in_rescue = before_rescue;
+ p->ctxt.cant_return = 0;
$$ = $1;
}
;
@@ -4628,6 +4635,7 @@ primary : literal
/*% ripper: class!($:cpath, $:superclass, $:bodystmt) %*/
local_pop(p);
p->ctxt.in_class = $k_class.in_class;
+ p->ctxt.cant_return = $k_class.cant_return;
p->ctxt.shareable_constant_value = $k_class.shareable_constant_value;
}
| k_class tLSHFT expr_value
@@ -4646,6 +4654,7 @@ primary : literal
local_pop(p);
p->ctxt.in_def = $k_class.in_def;
p->ctxt.in_class = $k_class.in_class;
+ p->ctxt.cant_return = $k_class.cant_return;
p->ctxt.shareable_constant_value = $k_class.shareable_constant_value;
}
| k_module cpath
@@ -4662,6 +4671,7 @@ primary : literal
/*% ripper: module!($:cpath, $:bodystmt) %*/
local_pop(p);
p->ctxt.in_class = $k_module.in_class;
+ p->ctxt.cant_return = $k_module.cant_return;
p->ctxt.shareable_constant_value = $k_module.shareable_constant_value;
}
| defn_head[head]
@@ -4890,7 +4900,7 @@ k_end : keyword_end
k_return : keyword_return
{
- if (p->ctxt.in_class && !p->ctxt.in_def && !dyna_in_block(p))
+ if (p->ctxt.cant_return && !dyna_in_block(p))
yyerror1(&@1, "Invalid return in class/module body");
}
;
diff --git a/prism/prism.c b/prism/prism.c
index 9168c520b7..a4d314453a 100644
--- a/prism/prism.c
+++ b/prism/prism.c
@@ -15342,6 +15342,7 @@ parse_arguments_list(pm_parser_t *parser, pm_arguments_t *arguments, bool accept
*/
static void
parse_return(pm_parser_t *parser, pm_node_t *node) {
+ bool in_sclass = false;
for (pm_context_node_t *context_node = parser->current_context; context_node != NULL; context_node = context_node->prev) {
switch (context_node->context) {
case PM_CONTEXT_BEGIN_ELSE:
@@ -15366,10 +15367,6 @@ parse_return(pm_parser_t *parser, pm_node_t *node) {
case PM_CONTEXT_PREDICATE:
case PM_CONTEXT_PREEXE:
case PM_CONTEXT_RESCUE_MODIFIER:
- case PM_CONTEXT_SCLASS_ELSE:
- case PM_CONTEXT_SCLASS_ENSURE:
- case PM_CONTEXT_SCLASS_RESCUE:
- case PM_CONTEXT_SCLASS:
case PM_CONTEXT_TERNARY:
case PM_CONTEXT_UNLESS:
case PM_CONTEXT_UNTIL:
@@ -15377,6 +15374,12 @@ parse_return(pm_parser_t *parser, pm_node_t *node) {
// Keep iterating up the lists of contexts, because returns can
// see through these.
continue;
+ case PM_CONTEXT_SCLASS_ELSE:
+ case PM_CONTEXT_SCLASS_ENSURE:
+ case PM_CONTEXT_SCLASS_RESCUE:
+ case PM_CONTEXT_SCLASS:
+ in_sclass = true;
+ continue;
case PM_CONTEXT_CLASS_ELSE:
case PM_CONTEXT_CLASS_ENSURE:
case PM_CONTEXT_CLASS_RESCUE:
@@ -15411,6 +15414,9 @@ parse_return(pm_parser_t *parser, pm_node_t *node) {
break;
}
}
+ if (in_sclass) {
+ pm_parser_err_node(parser, node, PM_ERR_RETURN_INVALID);
+ }
}
/**
diff --git a/test/prism/errors/dont_allow_return_inside_sclass_body.txt b/test/prism/errors/dont_allow_return_inside_sclass_body.txt
new file mode 100644
index 0000000000..c29fe01728
--- /dev/null
+++ b/test/prism/errors/dont_allow_return_inside_sclass_body.txt
@@ -0,0 +1,3 @@
+class << A; return; end
+ ^~~~~~ Invalid return in class/module body
+
diff --git a/test/ruby/test_class.rb b/test/ruby/test_class.rb
index 710b8a6f7b..38a6e9eb9f 100644
--- a/test/ruby/test_class.rb
+++ b/test/ruby/test_class.rb
@@ -316,6 +316,7 @@ class TestClass < Test::Unit::TestCase
def test_invalid_return_from_class_definition
assert_syntax_error("class C; return; end", /Invalid return/)
+ assert_syntax_error("class << Object; return; end", /Invalid return/)
end
def test_invalid_yield_from_class_definition