From 68a065935a857c818bfcc30331af2049834a34b0 2021-05-21 20:33:08 From: MH Date: 2021-05-21 20:33:08 Subject: [PATCH] prepare for casting expressions --- diff --git a/src/protocol/ast.rs b/src/protocol/ast.rs index 674196216068f1b7af7f34f77e500b996290b6cc..55adefdcc4138913b32bcf0e53a368c99198ba3a 100644 --- a/src/protocol/ast.rs +++ b/src/protocol/ast.rs @@ -154,6 +154,7 @@ define_new_ast_id!(IndexingExpressionId, ExpressionId, index(IndexingExpression, define_new_ast_id!(SlicingExpressionId, ExpressionId, index(SlicingExpression, Expression::Slicing, expressions), alloc(alloc_slicing_expression)); define_new_ast_id!(SelectExpressionId, ExpressionId, index(SelectExpression, Expression::Select, expressions), alloc(alloc_select_expression)); define_new_ast_id!(LiteralExpressionId, ExpressionId, index(LiteralExpression, Expression::Literal, expressions), alloc(alloc_literal_expression)); +define_new_ast_id!(CastExpressionId, ExpressionId, index(CastExpression, Expression::Cast, expressions), alloc(alloc_cast_expression)); define_new_ast_id!(CallExpressionId, ExpressionId, index(CallExpression, Expression::Call, expressions), alloc(alloc_call_expression)); define_new_ast_id!(VariableExpressionId, ExpressionId, index(VariableExpression, Expression::Variable, expressions), alloc(alloc_variable_expression)); @@ -1404,6 +1405,7 @@ pub enum Expression { Slicing(SlicingExpression), Select(SelectExpression), Literal(LiteralExpression), + Cast(CastExpression), Call(CallExpression), Variable(VariableExpression), } @@ -1486,6 +1488,7 @@ impl Expression { Expression::Slicing(expr) => expr.span, Expression::Select(expr) => expr.span, Expression::Literal(expr) => expr.span, + Expression::Cast(expr) => expr.span, Expression::Call(expr) => expr.span, Expression::Variable(expr) => expr.identifier.span, } @@ -1502,6 +1505,7 @@ impl Expression { Expression::Slicing(expr) => &expr.parent, Expression::Select(expr) => &expr.parent, Expression::Literal(expr) => &expr.parent, + Expression::Cast(expr) => &expr.parent, Expression::Call(expr) => &expr.parent, Expression::Variable(expr) => &expr.parent, } @@ -1526,6 +1530,7 @@ impl Expression { Expression::Slicing(expr) => expr.parent = parent, Expression::Select(expr) => expr.parent = parent, Expression::Literal(expr) => expr.parent = parent, + Expression::Cast(expr) => expr.parent = parent, Expression::Call(expr) => expr.parent = parent, Expression::Variable(expr) => expr.parent = parent, } @@ -1542,6 +1547,7 @@ impl Expression { Expression::Slicing(expr) => expr.unique_id_in_definition, Expression::Select(expr) => expr.unique_id_in_definition, Expression::Literal(expr) => expr.unique_id_in_definition, + Expression::Cast(expr) => expr.unique_id_in_definition, Expression::Call(expr) => expr.unique_id_in_definition, Expression::Variable(expr) => expr.unique_id_in_definition, } @@ -1698,6 +1704,18 @@ pub struct SelectExpression { pub unique_id_in_definition: i32, } +#[derive(Debug, Clone)] +pub struct CastExpression { + pub this: CastExpressionId, + // Parsing + pub span: InputSpan, // of the "cast" keyword, + pub to_type: ParserType, + pub subject: ExpressionId, + // Validator/linker + pub parent: ExpressionParent, + pub unique_id_in_definition: i32, +} + #[derive(Debug, Clone)] pub struct CallExpression { pub this: CallExpressionId, diff --git a/src/protocol/ast_printer.rs b/src/protocol/ast_printer.rs index a788c9be9fb03ca95f022fc9e4c0b424814e366a..1f1a37a352675b6071b3d04b9a3e494270d0c9ec 100644 --- a/src/protocol/ast_printer.rs +++ b/src/protocol/ast_printer.rs @@ -704,6 +704,9 @@ impl ASTWriter { self.kv(indent2).with_s_key("Parent") .with_custom_val(|v| write_expression_parent(v, &expr.parent)); }, + Expression::Cast(expr) => { + todo!("print casting expression") + } Expression::Call(expr) => { self.kv(indent).with_id(PREFIX_CALL_EXPR_ID, expr.this.0.index) .with_s_key("CallExpr"); diff --git a/src/protocol/eval/executor.rs b/src/protocol/eval/executor.rs index 46f86751474def60f6ca98baf04347136ca27d46..dbca1094afe3719572c8384fd5b8c8265a24ddbf 100644 --- a/src/protocol/eval/executor.rs +++ b/src/protocol/eval/executor.rs @@ -152,6 +152,9 @@ impl Frame { } } }, + Expression::Cast(_) => { + todo!("casting expression"); + } Expression::Call(expr) => { for arg_expr_id in &expr.arguments { self.expr_stack.push_back(ExprInstruction::PushValToFront); @@ -488,6 +491,9 @@ impl Prompt { cur_frame.expr_values.push_back(value); }, + Expression::Cast(_) => { + todo!("casting expression evaluation"); + } Expression::Call(expr) => { // Push a new frame. Note that all expressions have // been pushed to the front, so they're in the order diff --git a/src/protocol/parser/depth_visitor.rs b/src/protocol/parser/depth_visitor.rs index 90505722b27003ecc0317a6beff8d72c8b6cff13..b29f0c49937e61d86b194aa8f78d80c0fa466460 100644 --- a/src/protocol/parser/depth_visitor.rs +++ b/src/protocol/parser/depth_visitor.rs @@ -166,6 +166,10 @@ pub(crate) trait Visitor: Sized { fn visit_select_expression(&mut self, h: &mut Heap, expr: SelectExpressionId) -> VisitorResult { recursive_select_expression(self, h, expr) } + fn visit_cast_expression(&mut self, h: &mut Heap, expr: CastExpressionId) -> VisitorResult { + let subject = h[expr].subject; + self.visit_expression(h, subject) + } fn visit_call_expression(&mut self, h: &mut Heap, expr: CallExpressionId) -> VisitorResult { recursive_call_expression(self, h, expr) } @@ -400,6 +404,7 @@ fn recursive_expression( Expression::Slicing(expr) => this.visit_slicing_expression(h, expr.this), Expression::Select(expr) => this.visit_select_expression(h, expr.this), Expression::Literal(expr) => this.visit_constant_expression(h, expr.this), + Expression::Cast(expr) => this.visit_cast_expression(h, expr.this), Expression::Call(expr) => this.visit_call_expression(h, expr.this), Expression::Variable(expr) => this.visit_variable_expression(h, expr.this), } diff --git a/src/protocol/parser/pass_definitions.rs b/src/protocol/parser/pass_definitions.rs index f9f4826427962b1e0e90dbe34b7a94ea5c089dc0..e6d6204d97a66a4f9a0493382f8c15fd4887cf04 100644 --- a/src/protocol/parser/pass_definitions.rs +++ b/src/protocol/parser/pass_definitions.rs @@ -704,7 +704,7 @@ impl PassDefinitions { let poly_vars = ctx.heap[definition_id].poly_vars(); consume_parser_type( &module.source, iter, &ctx.symbols, &ctx.heap, - &poly_vars, SymbolScope::Module(module.root_id), definition_id, + poly_vars, SymbolScope::Module(module.root_id), definition_id, true, 1 )? } else { @@ -1451,13 +1451,45 @@ impl PassDefinitions { _ => unreachable!(), }; - ctx.heap.alloc_literal_expression(|this| LiteralExpression{ + ctx.heap.alloc_literal_expression(|this| LiteralExpression { this, span: ident_span, value, parent: ExpressionParent::None, unique_id_in_definition: -1, }).upcast() + } else if ident_text == KW_CAST { + // Casting expression + iter.consume(); + let to_type = if Some(TokenKind::OpenAngle) == iter.next() { + iter.consume(); + let definition_id = self.cur_definition; + let poly_vars = ctx.heap[definition_id].poly_vars(); + consume_parser_type( + &module.source, iter, &ctx.symbols, &ctx.heap, + poly_vars, SymbolScope::Module(module.root_id), definition_id, + true, 1 + )? + } else { + // Automatic casting with inferred target type + ParserType{ elements: vec![ParserTypeElement{ + full_span: ident_span, // TODO: @Span fix + variant: ParserTypeVariant::Inferred, + }]} + }; + + consume_token(&module.source, iter, TokenKind::OpenParen)?; + let subject = self.consume_expression(module, iter, ctx)?; + consume_token(&module.source, iter, TokenKind::CloseParen)?; + + ctx.heap.alloc_cast_expression(|this| CastExpression{ + this, + span: ident_span, + to_type, + subject, + parent: ExpressionParent::None, + unique_id_in_definition: -1, + }).upcast() } else { // Not a builtin literal, but also not a known type. So we // assume it is a variable expression. Although if we do, diff --git a/src/protocol/parser/pass_typing.rs b/src/protocol/parser/pass_typing.rs index 32eea7db518c49688743324867849fc39fedb058..55968677a316d9922b8c627888ec84b7a9b22a14 100644 --- a/src/protocol/parser/pass_typing.rs +++ b/src/protocol/parser/pass_typing.rs @@ -1317,6 +1317,18 @@ impl Visitor2 for PassTyping { self.progress_literal_expr(ctx, id) } + fn visit_cast_expr(&mut self, ctx: &mut Ctx, id: CastExpressionId) -> VisitorResult { + let upcast_id = id.upcast(); + self.insert_initial_expr_inference_type(ctx, upcast_id)?; + + let cast_expr = &ctx.heap[id]; + let subject_expr_id = cast_expr.subject; + + self.visit_expr(ctx, subject_expr_id)?; + + self.progress_cast_expr(ctx, id) + } + fn visit_call_expr(&mut self, ctx: &mut Ctx, id: CallExpressionId) -> VisitorResult { let upcast_id = id.upcast(); self.insert_initial_expr_inference_type(ctx, upcast_id)?; @@ -1329,7 +1341,7 @@ impl Visitor2 for PassTyping { self.expr_types[call_expr.unique_id_in_definition as usize].field_or_monomorph_idx = 0; // Visit all arguments - for arg_expr_id in call_expr.arguments.clone() { + for arg_expr_id in call_expr.arguments.clone() { // TODO: @Performance self.visit_expr(ctx, arg_expr_id)?; } @@ -1540,6 +1552,10 @@ impl PassTyping { let id = expr.this; self.progress_literal_expr(ctx, id) }, + Expression::Cast(expr) => { + let id = expr.this; + self.progress_cast_expr(ctx, id) + }, Expression::Call(expr) => { let id = expr.this; self.progress_call_expr(ctx, id) @@ -2289,6 +2305,17 @@ impl PassTyping { Ok(()) } + fn progress_cast_expr(&mut self, ctx: &mut Ctx, id: CastExpressionId) -> Result<(), ParseError> { + let upcast_id = id.upcast(); + let expr = &ctx.heap[id]; + let expr_idx = expr.unique_id_in_definition; + + // The cast expression acts like a blocker for two-way inference. The + // only thing we can do is wait until both the input and the output + // TODO: Continue here + Ok(()) + } + // TODO: @cleanup, see how this can be cleaned up once I implement // polymorphic struct/enum/union literals. These likely follow the same // pattern as here. diff --git a/src/protocol/parser/pass_validation_linking.rs b/src/protocol/parser/pass_validation_linking.rs index 0efaf3dc8d3194cb3b925fe7812d49682fa271dc..957317d97b323bfb8157f81184743752e4703143 100644 --- a/src/protocol/parser/pass_validation_linking.rs +++ b/src/protocol/parser/pass_validation_linking.rs @@ -727,6 +727,30 @@ impl Visitor2 for PassValidationLinking { Ok(()) } + fn visit_cast_expr(&mut self, ctx: &mut Ctx, id: CastExpressionId) -> VisitorResult { + let cast_expr = &mut ctx.heap[id]; + + if let Some(span) = self.must_be_assignable { + return Err(ParseError::new_error_str_at_span( + &ctx.module.source, span, "cannot assign to the result from a cast expression" + )) + } + + let upcast_id = id.upcast(); + let old_expr_parent = self.expr_parent; + cast_expr.parent = old_expr_parent; + cast_expr.unique_id_in_definition = self.next_expr_index; + self.next_expr_index += 1; + + // Recurse into the thing that we're casting + self.expr_parent = ExpressionParent::Expression(upcast_id, 0); + let subject_id = cast_expr.subject; + self.visit_expr(ctx, subject_id)?; + self.expr_parent = old_expr_parent; + + Ok(()) + } + fn visit_call_expr(&mut self, ctx: &mut Ctx, id: CallExpressionId) -> VisitorResult { let call_expr = &mut ctx.heap[id]; diff --git a/src/protocol/parser/token_parsing.rs b/src/protocol/parser/token_parsing.rs index 419f3cb0d86860d563777aa9b4bbb9743bac367a..670cde33320a149ac89187fad3b5533e8cee9d5e 100644 --- a/src/protocol/parser/token_parsing.rs +++ b/src/protocol/parser/token_parsing.rs @@ -26,7 +26,8 @@ pub(crate) const KW_LIT_TRUE: &'static [u8] = b"true"; pub(crate) const KW_LIT_FALSE: &'static [u8] = b"false"; pub(crate) const KW_LIT_NULL: &'static [u8] = b"null"; -// Keywords - functions +// Keywords - function(like)s +pub(crate) const KW_CAST: &'static [u8] = b"cast"; pub(crate) const KW_FUNC_GET: &'static [u8] = b"get"; pub(crate) const KW_FUNC_PUT: &'static [u8] = b"put"; pub(crate) const KW_FUNC_FIRES: &'static [u8] = b"fires"; @@ -521,7 +522,7 @@ fn is_reserved_statement_keyword(text: &[u8]) -> bool { fn is_reserved_expression_keyword(text: &[u8]) -> bool { match text { - KW_LET | + KW_LET | KW_CAST | KW_LIT_TRUE | KW_LIT_FALSE | KW_LIT_NULL | KW_FUNC_GET | KW_FUNC_PUT | KW_FUNC_FIRES | KW_FUNC_CREATE | KW_FUNC_ASSERT | KW_FUNC_LENGTH => true, _ => false, diff --git a/src/protocol/parser/visitor.rs b/src/protocol/parser/visitor.rs index f6cfec81afb14332309fb4d0cb46e686a8c806ea..b758cf493f60aeeee202fbc7ffd569960c54b0ca 100644 --- a/src/protocol/parser/visitor.rs +++ b/src/protocol/parser/visitor.rs @@ -205,6 +205,10 @@ pub(crate) trait Visitor2 { let this = expr.this; self.visit_literal_expr(ctx, this) } + Expression::Cast(expr) => { + let this = expr.this; + self.visit_cast_expr(ctx, this) + } Expression::Call(expr) => { let this = expr.this; self.visit_call_expr(ctx, this) @@ -225,6 +229,7 @@ pub(crate) trait Visitor2 { fn visit_slicing_expr(&mut self, _ctx: &mut Ctx, _id: SlicingExpressionId) -> VisitorResult { Ok(()) } fn visit_select_expr(&mut self, _ctx: &mut Ctx, _id: SelectExpressionId) -> VisitorResult { Ok(()) } fn visit_literal_expr(&mut self, _ctx: &mut Ctx, _id: LiteralExpressionId) -> VisitorResult { Ok(()) } + fn visit_cast_expr(&mut self, _ctx: &mut Ctx, _id: CastExpressionId) -> VisitorResult { Ok(()) } fn visit_call_expr(&mut self, _ctx: &mut Ctx, _id: CallExpressionId) -> VisitorResult { Ok(()) } fn visit_variable_expr(&mut self, _ctx: &mut Ctx, _id: VariableExpressionId) -> VisitorResult { Ok(()) } } \ No newline at end of file diff --git a/src/protocol/tests/eval_silly.rs b/src/protocol/tests/eval_silly.rs index 004afb78eb4b7d1568271f57a7d2e7c4754716d9..ef7478190e07155df12dfaa841bb57bfe359848f 100644 --- a/src/protocol/tests/eval_silly.rs +++ b/src/protocol/tests/eval_silly.rs @@ -184,31 +184,32 @@ fn test_struct_fields() { fn test_field_selection_polymorphism() { // Bit silly, but just to be sure Tester::new_single_source_expect_ok("struct field shuffles", " -struct VecXYZ { T x, T y, T z } -struct VecYZX { T y, T z, T x } -struct VecZXY { T z, T x, T y } -func modify_x(T input) -> T { - input.x = 1337; - return input; -} + struct VecXYZ { T x, T y, T z } + struct VecYZX { T y, T z, T x } + struct VecZXY { T z, T x, T y } -func foo() -> bool { - auto xyz = VecXYZ{ x: 1, y: 2, z: 3 }; - auto yzx = VecYZX{ y: 2, z: 3, x: 1 }; - auto zxy = VecZXY{ x: 1, y: 2, z: 3 }; - - auto mod_xyz = modify_x(xyz); - auto mod_yzx = modify_x(yzx); - auto mod_zxy = modify_x(zxy); - - return - xyz.x == 1 && xyz.y == 2 && xyz.z == 3 && - yzx.x == 1 && yzx.y == 2 && yzx.z == 3 && - zxy.x == 1 && zxy.y == 2 && zxy.z == 3 && - mod_xyz.x == 1337 && mod_xyz.y == 2 && mod_xyz.z == 3 && - mod_yzx.x == 1337 && mod_yzx.y == 2 && mod_yzx.z == 3 && - mod_zxy.x == 1337 && mod_zxy.y == 2 && mod_zxy.z == 3; -} + func modify_x(T input) -> T { + input.x = 1337; + return input; + } + + func foo() -> bool { + auto xyz = VecXYZ{ x: 1, y: 2, z: 3 }; + auto yzx = VecYZX{ y: 2, z: 3, x: 1 }; + auto zxy = VecZXY{ x: 1, y: 2, z: 3 }; // different initialization order + + auto mod_xyz = modify_x(xyz); + auto mod_yzx = modify_x(yzx); + auto mod_zxy = modify_x(zxy); + + return + xyz.x == 1 && xyz.y == 2 && xyz.z == 3 && + yzx.x == 1 && yzx.y == 2 && yzx.z == 3 && + zxy.x == 1 && zxy.y == 2 && zxy.z == 3 && + mod_xyz.x == 1337 && mod_xyz.y == 2 && mod_xyz.z == 3 && + mod_yzx.x == 1337 && mod_yzx.y == 2 && mod_yzx.z == 3 && + mod_zxy.x == 1337 && mod_zxy.y == 2 && mod_zxy.z == 3; + } ").for_function("foo", |f| { f.call_ok(Some(Value::Bool(true))); }); diff --git a/src/protocol/tests/parser_inference.rs b/src/protocol/tests/parser_inference.rs index f64a38865efca7c03e626aa2a1e55d8fc973b030..bc374563dba2201bc6152c215cf42f1c536fef6b 100644 --- a/src/protocol/tests/parser_inference.rs +++ b/src/protocol/tests/parser_inference.rs @@ -330,48 +330,48 @@ fn test_failed_variable_inference() { #[test] fn test_failed_polymorph_inference() { - // Tester::new_single_source_expect_err( - // "function call inference mismatch", - // " - // func poly(T a, T b) -> s32 { return 0; } - // func call() -> s32 { - // s8 first_arg = 5; - // s64 second_arg = 2; - // return poly(first_arg, second_arg); - // } - // " - // ).error(|e| { e - // .assert_num(3) - // .assert_ctx_has(0, "poly(first_arg, second_arg)") - // .assert_occurs_at(0, "poly") - // .assert_msg_has(0, "Conflicting type for polymorphic variable 'T'") - // .assert_occurs_at(1, "second_arg") - // .assert_msg_has(1, "inferred it to 's64'") - // .assert_occurs_at(2, "first_arg") - // .assert_msg_has(2, "inferred it to 's8'"); - // }); - // - // Tester::new_single_source_expect_err( - // "struct literal inference mismatch", - // " - // struct Pair{ T first, T second } - // func call() -> s32 { - // s8 first_arg = 5; - // s64 second_arg = 2; - // auto pair = Pair{ first: first_arg, second: second_arg }; - // return 3; - // } - // " - // ).error(|e| { e - // .assert_num(3) - // .assert_ctx_has(0, "Pair{ first: first_arg, second: second_arg }") - // .assert_occurs_at(0, "Pair{") - // .assert_msg_has(0, "Conflicting type for polymorphic variable 'T'") - // .assert_occurs_at(1, "second_arg") - // .assert_msg_has(1, "inferred it to 's64'") - // .assert_occurs_at(2, "first_arg") - // .assert_msg_has(2, "inferred it to 's8'"); - // }); + Tester::new_single_source_expect_err( + "function call inference mismatch", + " + func poly(T a, T b) -> s32 { return 0; } + func call() -> s32 { + s8 first_arg = 5; + s64 second_arg = 2; + return poly(first_arg, second_arg); + } + " + ).error(|e| { e + .assert_num(3) + .assert_ctx_has(0, "poly(first_arg, second_arg)") + .assert_occurs_at(0, "poly") + .assert_msg_has(0, "Conflicting type for polymorphic variable 'T'") + .assert_occurs_at(1, "second_arg") + .assert_msg_has(1, "inferred it to 's64'") + .assert_occurs_at(2, "first_arg") + .assert_msg_has(2, "inferred it to 's8'"); + }); + + Tester::new_single_source_expect_err( + "struct literal inference mismatch", + " + struct Pair{ T first, T second } + func call() -> s32 { + s8 first_arg = 5; + s64 second_arg = 2; + auto pair = Pair{ first: first_arg, second: second_arg }; + return 3; + } + " + ).error(|e| { e + .assert_num(3) + .assert_ctx_has(0, "Pair{ first: first_arg, second: second_arg }") + .assert_occurs_at(0, "Pair{") + .assert_msg_has(0, "Conflicting type for polymorphic variable 'T'") + .assert_occurs_at(1, "second_arg") + .assert_msg_has(1, "inferred it to 's64'") + .assert_occurs_at(2, "first_arg") + .assert_msg_has(2, "inferred it to 's8'"); + }); // Cannot really test literal inference error, but this comes close Tester::new_single_source_expect_err( diff --git a/src/protocol/tests/utils.rs b/src/protocol/tests/utils.rs index 4fcdf8f1ccf5bed0b8548b4590d393395a6194af..6bed679520f0ac6a2d6f63cedaa64e30e99164b2 100644 --- a/src/protocol/tests/utils.rs +++ b/src/protocol/tests/utils.rs @@ -1197,6 +1197,9 @@ fn seek_expr_in_expr bool>(heap: &Heap, start: ExpressionI } None }, + Expression::Cast(expr) => { + seek_expr_in_expr(heap, expr.subject, f) + } Expression::Call(expr) => { for arg in &expr.arguments { if let Some(id) = seek_expr_in_expr(heap, *arg, f) {