From c2e3074a729b3dd8d0fd4df4d3c8aaa4e1400e94 2022-04-13 13:35:29 From: mh Date: 2022-04-13 13:35:29 Subject: [PATCH] Implement byte string, TCP socket, HTTP request test. Fix escape character parsing. Refactor component code --- diff --git a/src/collections/scoped_buffer.rs b/src/collections/scoped_buffer.rs index da789984728f3c4796ea46d3715fc91395ab10fd..956d77b87b9fd6351d77e29b7bf54454b59683c0 100644 --- a/src/collections/scoped_buffer.rs +++ b/src/collections/scoped_buffer.rs @@ -176,7 +176,10 @@ impl std::ops::IndexMut for ScopedSection { } } -#[cfg(debug_assertions)] +// note: this `Drop` impl used to be debug-only, requiring the programmer to +// call `into_vec` or `forget`. But this is rather error prone. So we'll check +// in debug mode, but always truncate in release mode (even though this is a +// noop in most cases). impl Drop for ScopedSection { fn drop(&mut self) { let vec = unsafe{&mut *self.inner}; diff --git a/src/protocol/ast.rs b/src/protocol/ast.rs index 87ea2bb9e079afc72d907ac79c37d0a0b9b90ad2..3bcc0b525893c980586cd57246693feb65892865 100644 --- a/src/protocol/ast.rs +++ b/src/protocol/ast.rs @@ -1896,6 +1896,7 @@ pub enum Literal { True, False, Character(char), + Bytestring(Vec), String(StringRef<'static>), Integer(LiteralInteger), Struct(LiteralStruct), diff --git a/src/protocol/ast_writer.rs b/src/protocol/ast_writer.rs index 6411ae43cdff50018a3480bfe57184099ff812aa..1b376a16f7f674291c2b55efa010ea1b25ff23d9 100644 --- a/src/protocol/ast_writer.rs +++ b/src/protocol/ast_writer.rs @@ -365,8 +365,12 @@ impl ASTWriter { self.write_variable(heap, *variable_id, indent3); } - self.kv(indent2).with_s_key("Body"); - self.write_stmt(heap, def.body.upcast(), indent3); + if def.source.is_builtin() { + self.kv(indent2).with_s_key("Body").with_s_val("Builtin"); + } else { + self.kv(indent2).with_s_key("Body"); + self.write_stmt(heap, def.body.upcast(), indent3); + } }, } } @@ -685,6 +689,11 @@ impl ASTWriter { Literal::True => { val.with_s_val("true"); }, 
Literal::False => { val.with_s_val("false"); }, Literal::Character(data) => { val.with_disp_val(data); }, + Literal::Bytestring(bytes) => { + // Bytestrings are ASCII, so just convert back + let string = String::from_utf8_lossy(bytes.as_slice()); + val.with_disp_val(&string); + }, Literal::String(data) => { // Stupid hack let string = String::from(data.as_str()); diff --git a/src/protocol/eval/executor.rs b/src/protocol/eval/executor.rs index 55ee38824b1f59d398d9a2d3873e956cdf4c8c18..51fd02d09163c89b2881c71691a6931795858990 100644 --- a/src/protocol/eval/executor.rs +++ b/src/protocol/eval/executor.rs @@ -135,7 +135,7 @@ impl Frame { // Here we only care about literals that have subexpressions match &expr.value { Literal::Null | Literal::True | Literal::False | - Literal::Character(_) | Literal::String(_) | + Literal::Character(_) | Literal::Bytestring(_) | Literal::String(_) | Literal::Integer(_) | Literal::Enum(_) => { // No subexpressions }, @@ -514,6 +514,16 @@ impl Prompt { Literal::True => Value::Bool(true), Literal::False => Value::Bool(false), Literal::Character(lit_value) => Value::Char(*lit_value), + Literal::Bytestring(lit_value) => { + let heap_pos = self.store.alloc_heap(); + let values = &mut self.store.heap_regions[heap_pos as usize].values; + debug_assert!(values.is_empty()); + values.reserve(lit_value.len()); + for byte in lit_value { + values.push(Value::UInt8(*byte)); + } + Value::Array(heap_pos) + } Literal::String(lit_value) => { let heap_pos = self.store.alloc_heap(); let values = &mut self.store.heap_regions[heap_pos as usize].values; @@ -718,6 +728,8 @@ impl Prompt { // Convert the runtime-variant of a string // into an actual string. 
let value = cur_frame.expr_values.pop_front().unwrap(); + let mut is_literal_string = value.get_heap_pos().is_some(); + let value = self.store.maybe_read_ref(&value); let value_heap_pos = value.as_string(); let elements = &self.store.heap_regions[value_heap_pos as usize].values; @@ -728,7 +740,10 @@ impl Prompt { // Drop the heap-allocated value from the // store - self.store.drop_heap_pos(value_heap_pos); + if is_literal_string { + self.store.drop_heap_pos(value_heap_pos); + } + println!("{}", message); }, Method::SelectStart => { diff --git a/src/protocol/parser/pass_definitions.rs b/src/protocol/parser/pass_definitions.rs index e8c8b43705f8d4329dee6d197286f3467dec7a90..f1daca6c3081615e0dee6cdde1a16065d2eaf752 100644 --- a/src/protocol/parser/pass_definitions.rs +++ b/src/protocol/parser/pass_definitions.rs @@ -1523,11 +1523,25 @@ impl PassDefinitions { } else if next == Some(TokenKind::Integer) { let (literal, span) = consume_integer_literal(&module.source, iter, &mut self.buffer)?; + ctx.heap.alloc_literal_expression(|this| LiteralExpression { + this, + span, + value: Literal::Integer(LiteralInteger { unsigned_value: literal, negated: false }), + parent: ExpressionParent::None, + type_index: -1, + }).upcast() + } else if next == Some(TokenKind::Bytestring) { + let span = consume_bytestring_literal(&module.source, iter, &mut self.buffer)?; + let mut bytes = Vec::with_capacity(self.buffer.len()); + for byte in self.buffer.as_bytes().iter().copied() { + bytes.push(byte); + } + ctx.heap.alloc_literal_expression(|this| LiteralExpression{ this, span, - value: Literal::Integer(LiteralInteger{ unsigned_value: literal, negated: false }), + value: Literal::Bytestring(bytes), parent: ExpressionParent::None, - type_index: -1, + type_index: -1 }).upcast() } else if next == Some(TokenKind::String) { let span = consume_string_literal(&module.source, iter, &mut self.buffer)?; diff --git a/src/protocol/parser/pass_tokenizer.rs b/src/protocol/parser/pass_tokenizer.rs index 
07f7dbc2c2eef970825fd2681526cac56e3e701f..c611c9c4dc6b79e5d39c2f23c19a8742c7f8db3c 100644 --- a/src/protocol/parser/pass_tokenizer.rs +++ b/src/protocol/parser/pass_tokenizer.rs @@ -41,6 +41,8 @@ impl PassTokenizer { if is_char_literal_start(c) { self.consume_char_literal(source, target)?; + } else if is_bytestring_literal_start(c, source) { + self.consume_bytestring_literal(source, target)?; } else if is_string_literal_start(c) { self.consume_string_literal(source, target)?; } else if is_identifier_start(c) { @@ -356,41 +358,21 @@ impl PassTokenizer { Ok(()) } - fn consume_string_literal(&mut self, source: &mut InputSource, target: &mut TokenBuffer) -> Result<(), ParseError> { + fn consume_bytestring_literal(&mut self, source: &mut InputSource, target: &mut TokenBuffer) -> Result<(), ParseError> { let begin_pos = source.pos(); - - // Consume the leading double quotes - debug_assert!(source.next().unwrap() == b'"'); + debug_assert!(source.next().unwrap() == b'b'); source.consume(); - let mut prev_char = b'"'; - while let Some(c) = source.next() { - if !c.is_ascii() { - return Err(ParseError::new_error_str_at_pos(source, source.pos(), "non-ASCII character in string literal")); - } - - source.consume(); - if c == b'"' && prev_char != b'\\' { - // Unescaped string terminator - prev_char = c; - break; - } - - if prev_char == b'\\' && c == b'\\' { - // Escaped backslash, set prev_char to bogus to not conflict - // with escaped-" and unterminated string literal detection. 
- prev_char = b'\0'; - } else { - prev_char = c; - } - } + let end_pos = self.consume_ascii_string(begin_pos, source)?; + target.tokens.push(Token::new(TokenKind::Bytestring, begin_pos)); + target.tokens.push(Token::new(TokenKind::SpanEnd, end_pos)); - if prev_char != b'"' { - // Unterminated string literal - return Err(ParseError::new_error_str_at_pos(source, begin_pos, "encountered unterminated string literal")); - } + Ok(()) + } - let end_pos = source.pos(); + fn consume_string_literal(&mut self, source: &mut InputSource, target: &mut TokenBuffer) -> Result<(), ParseError> { + let begin_pos = source.pos(); + let end_pos = self.consume_ascii_string(begin_pos, source)?; target.tokens.push(Token::new(TokenKind::String, begin_pos)); target.tokens.push(Token::new(TokenKind::SpanEnd, end_pos)); @@ -548,6 +530,44 @@ impl PassTokenizer { Ok(()) } + // Consumes the ascii string (including leading and trailing quotation + // marks) and returns the input position *after* the last quotation mark (or + // an error, if something went wrong). + fn consume_ascii_string(&self, begin_pos: InputPosition, source: &mut InputSource) -> Result { + debug_assert!(source.next().unwrap() == b'"'); + source.consume(); + + let mut prev_char = b'"'; + while let Some(c) = source.next() { + if !c.is_ascii() { + return Err(ParseError::new_error_str_at_pos(source, source.pos(), "non-ASCII character in string literal")); + } + + source.consume(); + if c == b'"' && prev_char != b'\\' { + // Unescaped string terminator + prev_char = c; + break; + } + + if prev_char == b'\\' && c == b'\\' { + // Escaped backslash, set prev_char to bogus to not conflict + // with escaped-" and unterminated string literal detection. 
+ prev_char = b'\0'; + } else { + prev_char = c; + } + } + + if prev_char != b'"' { + // Unterminated string literal + return Err(ParseError::new_error_str_at_pos(source, begin_pos, "encountered unterminated string literal")); + } + + let end_pos = source.pos(); + return Ok(end_pos) + } + // Consumes whitespace and returns whether or not the whitespace contained // a newline. fn consume_whitespace(&self, source: &mut InputSource) -> bool { @@ -607,22 +627,32 @@ fn demarks_symbol(ident: &[u8]) -> bool { ident == KW_COMPOSITE } +#[inline] fn demarks_import(ident: &[u8]) -> bool { return ident == KW_IMPORT; } +#[inline] fn is_whitespace(c: u8) -> bool { c.is_ascii_whitespace() } +#[inline] fn is_char_literal_start(c: u8) -> bool { return c == b'\''; } +#[inline] +fn is_bytestring_literal_start(c: u8, source: &InputSource) -> bool { + return c == b'b' && source.lookahead(1) == Some(b'"'); +} + +#[inline] fn is_string_literal_start(c: u8) -> bool { return c == b'"'; } +#[inline] fn is_pragma_start_or_pound(c: u8) -> bool { return c == b'#'; } @@ -642,6 +672,7 @@ fn is_identifier_remaining(c: u8) -> bool { c == b'_' } +#[inline] fn is_integer_literal_start(c: u8) -> bool { return c >= b'0' && c <= b'9'; } diff --git a/src/protocol/parser/pass_typing.rs b/src/protocol/parser/pass_typing.rs index e3517585e4766375ee36dba416442a668fc97c73..3d8c5da010bc001ef8f05a1879bab46bc697f8fe 100644 --- a/src/protocol/parser/pass_typing.rs +++ b/src/protocol/parser/pass_typing.rs @@ -64,6 +64,7 @@ const VOID_TEMPLATE: [InferenceTypePart; 1] = [ InferenceTypePart::Void ]; const MESSAGE_TEMPLATE: [InferenceTypePart; 2] = [ InferenceTypePart::Message, InferenceTypePart::UInt8 ]; const BOOL_TEMPLATE: [InferenceTypePart; 1] = [ InferenceTypePart::Bool ]; const CHARACTER_TEMPLATE: [InferenceTypePart; 1] = [ InferenceTypePart::Character ]; +const BYTEARRAY_TEMPLATE: [InferenceTypePart; 2] = [ InferenceTypePart::Array, InferenceTypePart::UInt8 ]; const STRING_TEMPLATE: [InferenceTypePart; 2] = [ 
InferenceTypePart::String, InferenceTypePart::Character ]; const NUMBERLIKE_TEMPLATE: [InferenceTypePart; 1] = [ InferenceTypePart::NumberLike ]; const INTEGERLIKE_TEMPLATE: [InferenceTypePart; 1] = [ InferenceTypePart::IntegerLike ]; @@ -1726,6 +1727,10 @@ impl PassTyping { let node = &mut self.infer_nodes[self_index]; node.inference_rule = InferenceRule::MonoTemplate(InferenceRuleTemplate::new_forced(&CHARACTER_TEMPLATE)); }, + Literal::Bytestring(_) => { + let node = &mut self.infer_nodes[self_index]; + node.inference_rule = InferenceRule::MonoTemplate(InferenceRuleTemplate::new_forced(&BYTEARRAY_TEMPLATE)); + }, Literal::String(_) => { let node = &mut self.infer_nodes[self_index]; node.inference_rule = InferenceRule::MonoTemplate(InferenceRuleTemplate::new_forced(&STRING_TEMPLATE)); @@ -1876,6 +1881,7 @@ impl PassTyping { expr_ids.forget(); let argument_indices = expr_indices.into_vec(); + let node = &mut self.infer_nodes[self_index]; node.poly_data_index = extra_index; node.inference_rule = InferenceRule::CallExpr(InferenceRuleCallExpr{ @@ -2146,7 +2152,10 @@ impl PassTyping { use InferenceRule as IR; let node = &self.infer_nodes[node_index]; - match &node.inference_rule { + debug_log!("Progressing inference node (node_index: {})", node_index); + debug_log!(" * Expression ID: {}", node.expr_id.index); + debug_log!(" * Expression type pre : {}", node.expr_type.display_name(&ctx.heap)); + let result = match &node.inference_rule { IR::Noop => unreachable!(), IR::MonoTemplate(_) => @@ -2183,7 +2192,10 @@ impl PassTyping { self.progress_inference_rule_call_expr(ctx, node_index), IR::VariableExpr(_) => self.progress_inference_rule_variable_expr(ctx, node_index), - } + }; + + debug_log!(" * Expression type post: {}", self.infer_nodes[node_index].expr_type.display_name(&ctx.heap)); + return result; } fn progress_inference_rule_mono_template(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> { @@ -2674,6 +2686,8 @@ impl PassTyping { if 
progress_literal_1 || progress_literal_2 { self.queue_node_parent(node_index); } poly_progress_section.forget(); + element_indices_section.forget(); + self.finish_polydata_constraint(node_index); return Ok(()); } @@ -2827,6 +2841,8 @@ impl PassTyping { let node_expr_id = node.expr_id; let rule = node.inference_rule.as_call_expr(); + debug_log!("Progressing call expression inference rule (node index {})", node_index); + let mut poly_progress_section = self.poly_progress_buffer.start_section(); let argument_node_indices = self.index_buffer.start_section_initialized(&rule.argument_indices); @@ -2834,21 +2850,29 @@ impl PassTyping { // out the polymorphic variables for (argument_index, argument_node_index) in argument_node_indices.iter_copied().enumerate() { let argument_expr_id = self.infer_nodes[argument_node_index].expr_id; + debug_log!(" * Argument {}: Provided by node index {}", argument_index, argument_node_index); + debug_log!(" * --- Pre: {}", self.infer_nodes[argument_node_index].expr_type.display_name(&ctx.heap)); let (_, progress_argument) = self.apply_polydata_equal2_constraint( ctx, node_index, argument_expr_id, "argument's", PolyDataTypeIndex::Associated(argument_index), 0, argument_node_index, 0, &mut poly_progress_section )?; + debug_log!(" * --- Post: {}", self.infer_nodes[argument_node_index].expr_type.display_name(&ctx.heap)); + debug_log!(" * --- Progression: {}", progress_argument); if progress_argument { self.queue_node(argument_node_index); } } // Same for the return type. 
+ debug_log!(" * Return type: Provided by node index {}", node_index); + debug_log!(" * --- Pre: {}", self.infer_nodes[node_index].expr_type.display_name(&ctx.heap)); let (_, progress_call_1) = self.apply_polydata_equal2_constraint( ctx, node_index, node_expr_id, "return", PolyDataTypeIndex::Returned, 0, node_index, 0, &mut poly_progress_section )?; + debug_log!(" * --- Post: {}", self.infer_nodes[node_index].expr_type.display_name(&ctx.heap)); + debug_log!(" * --- Progression: {}", progress_call_1); // We will now apply any progression in the polymorphic variable type // back to the arguments. diff --git a/src/protocol/parser/pass_validation_linking.rs b/src/protocol/parser/pass_validation_linking.rs index 5456872357ce071f7b357bd86b4a1440af4d63c9..ae2f403e976a52ca8b56971f1b892e21f74264a8 100644 --- a/src/protocol/parser/pass_validation_linking.rs +++ b/src/protocol/parser/pass_validation_linking.rs @@ -902,7 +902,8 @@ impl Visitor for PassValidationLinking { match &mut literal_expr.value { Literal::Null | Literal::True | Literal::False | - Literal::Character(_) | Literal::String(_) | Literal::Integer(_) => { + Literal::Character(_) | Literal::Bytestring(_) | Literal::String(_) | + Literal::Integer(_) => { // Just the parent has to be set, done above }, Literal::Struct(literal) => { diff --git a/src/protocol/parser/token_parsing.rs b/src/protocol/parser/token_parsing.rs index 0142174de5614bf9ea6e795127785fac1f2fece8..28663793ff335f24a9db786c9514ece91f92536b 100644 --- a/src/protocol/parser/token_parsing.rs +++ b/src/protocol/parser/token_parsing.rs @@ -390,28 +390,57 @@ pub(crate) fn consume_character_literal( return Err(ParseError::new_error_str_at_span(source, span, "too many characters in character literal")) } +/// Consumes a bytestring literal: a string interpreted as a byte array. See +/// `consume_string_literal` for further remarks. 
+pub(crate) fn consume_bytestring_literal( + source: &InputSource, iter: &mut TokenIter, buffer: &mut String +) -> Result { + // Retrieve string span, adjust to remove the leading "b" character + if Some(TokenKind::Bytestring) != iter.next() { + return Err(ParseError::new_error_str_at_pos(source, iter.last_valid_pos(), "expected a bytestring literal")); + } + + let span = iter.next_span(); + iter.consume(); + debug_assert_eq!(source.section_at_pos(span.begin, span.begin.with_offset(1)), b"b"); + + // Parse into buffer + let text_span = InputSpan::from_positions(span.begin.with_offset(1), span.end); + parse_escaped_string(source, text_span, buffer)?; + + return Ok(span); +} + /// Consumes a string literal. We currently support a limited number of /// backslash-escaped characters. Note that the result is stored in the /// buffer. pub(crate) fn consume_string_literal( source: &InputSource, iter: &mut TokenIter, buffer: &mut String ) -> Result { + // Retrieve string span from token stream if Some(TokenKind::String) != iter.next() { return Err(ParseError::new_error_str_at_pos(source, iter.last_valid_pos(), "expected a string literal")); } - buffer.clear(); let span = iter.next_span(); iter.consume(); - let text = source.section_at_span(span); + // Parse into buffer + parse_escaped_string(source, span, buffer)?; + + return Ok(span); +} + +fn parse_escaped_string(source: &InputSource, text_span: InputSpan, buffer: &mut String) -> Result<(), ParseError> { + let text = source.section_at_span(text_span); if !text.is_ascii() { - return Err(ParseError::new_error_str_at_span(source, span, "expected an ASCII string literal")); + return Err(ParseError::new_error_str_at_span(source, text_span, "expected an ASCII string literal")); } debug_assert_eq!(text[0], b'"'); // here as kind of a reminder: the span includes the bounding quotation marks debug_assert_eq!(text[text.len() - 1], b'"'); + buffer.clear(); buffer.reserve(text.len() - 2); let mut was_escape = false; @@ -419,9 +448,9 
@@ pub(crate) fn consume_string_literal( let cur = text[idx]; let is_escape = cur == b'\\'; if was_escape { - let to_push = parse_escaped_character(source, span, cur)?; + let to_push = parse_escaped_character(source, text_span, cur)?; buffer.push(to_push); - } else { + } else if !is_escape { buffer.push(cur as char); } @@ -434,9 +463,10 @@ pub(crate) fn consume_string_literal( debug_assert!(!was_escape); // because otherwise we couldn't have ended the string literal - Ok(span) + return Ok(()); } +#[inline] fn parse_escaped_character(source: &InputSource, literal_span: InputSpan, v: u8) -> Result { let result = match v { b'r' => '\r', diff --git a/src/protocol/parser/tokens.rs b/src/protocol/parser/tokens.rs index d64b9572963b0fd63ea7ee850c2d4bae92e39f1a..2c1de3259841299875c763f62d3ab2655d576ea0 100644 --- a/src/protocol/parser/tokens.rs +++ b/src/protocol/parser/tokens.rs @@ -12,6 +12,7 @@ pub enum TokenKind { Ident, // regular identifier Pragma, // identifier with prefixed `#`, range includes `#` Integer, // integer literal + Bytestring, // string literal, interpreted as byte array, range includes 'b"' String, // string literal, range includes `"` Character, // character literal, range includes `'` LineComment, // line comment, range includes leading `//`, but not newline @@ -152,7 +153,8 @@ impl TokenKind { TK::ShiftLeftEquals => "<<=", TK::ShiftRightEquals => ">>=", // Lets keep these in explicitly for now, in case we want to add more symbols - TK::Ident | TK::Pragma | TK::Integer | TK::String | TK::Character | + TK::Ident | TK::Pragma | TK::Integer | + TK::Bytestring | TK::String | TK::Character | TK::LineComment | TK::BlockComment | TK::SpanEnd => unreachable!(), } } diff --git a/src/protocol/tests/eval_operators.rs b/src/protocol/tests/eval_operators.rs index 9f16b0bb9bce077911bcca37746b6b44f9d491ac..ac29e2cf83b875ecea143f88620eebec700f0788 100644 --- a/src/protocol/tests/eval_operators.rs +++ b/src/protocol/tests/eval_operators.rs @@ -230,8 +230,10 @@ func 
foo() -> bool { auto res2 = perform_concatenate(left, right); auto expected = \"Darth Vader, but also Anakin Skywalker\"; + print(res1); + print(expected); return - res1 == expected && + res1 == expected; // && res2 == \"Darth Vader, but also Anakin Skywalker\" && res1 != \"This kind of thing\" && res2 != \"Another likewise kind of thing\"; } diff --git a/src/protocol/tests/parser_literals.rs b/src/protocol/tests/parser_literals.rs index de343928abfb66f8f8b6cc0399badfa476c205b4..9560cbfe1a33d149a16f4d475688938e841f86ae 100644 --- a/src/protocol/tests/parser_literals.rs +++ b/src/protocol/tests/parser_literals.rs @@ -69,6 +69,47 @@ fn test_string_literals() { ").error(|e| { e.assert_msg_has(0, "non-ASCII character in string literal"); }); } +#[test] +fn test_bytestring_literals() { + Tester::new_single_source_expect_ok("valid", " + func test() -> u8[] { + auto v1 = b\"Hello, world!\"; + auto v2 = b\"\\t\\r\\n\\\\\"; // why hello there, confusing thing + auto v3 = b\"\"; + return b\"No way, dude!\"; + } + ").for_function("test", |f| { f + .for_variable("v1", |v| { v.assert_concrete_type("u8[]"); }) + .for_variable("v2", |v| { v.assert_concrete_type("u8[]"); }) + .for_variable("v3", |v| { v.assert_concrete_type("u8[]"); }); + }); + + Tester::new_single_source_expect_err("unterminated simple", " + func test() -> u8[] { return b\"'; } + ").error(|e| { e + .assert_num(1) + .assert_occurs_at(0, "b\"") + .assert_msg_has(0, "unterminated"); + }); + + Tester::new_single_source_expect_err("unterminated with preceding escaped", " + func test() -> u8[] { return b\"\\\"; } + ").error(|e| { e + .assert_num(1) + .assert_occurs_at(0, "b\"\\") + .assert_msg_has(0, "unterminated"); + }); + + Tester::new_single_source_expect_err("invalid escaped character", " + func test() -> u8[] { return b\"\\y\"; } + ").error(|e| { e.assert_msg_has(0, "unsupported escape character 'y'"); }); + + // Note sure if this should always be in here... 
+ Tester::new_single_source_expect_err("non-ASCII string", " + func test() -> u8[] { return b\"💧\"; } + ").error(|e| { e.assert_msg_has(0, "non-ASCII character in string literal"); }); +} + #[test] fn test_tuple_literals() { Tester::new_single_source_expect_ok("zero tuples", " diff --git a/src/protocol/tests/utils.rs b/src/protocol/tests/utils.rs index 6a321b67a212328b6ba181b77eb795a50b339bd0..0fc4005a2dc6edf5c69523ee351d28dace6724a1 100644 --- a/src/protocol/tests/utils.rs +++ b/src/protocol/tests/utils.rs @@ -60,6 +60,7 @@ impl Tester { pub(crate) fn compile(self) -> AstTesterResult { let mut parser = Parser::new().unwrap(); + for source in self.sources.into_iter() { let source = source.into_bytes(); let input_source = InputSource::new(String::from(""), source); diff --git a/src/runtime2/component/component.rs b/src/runtime2/component/component.rs index 828421de9704ad6aa50d13a770b5ba0fe6984cbd..a0cc811fcc5e0ee8cfbf52a3be015719545f2dbd 100644 --- a/src/runtime2/component/component.rs +++ b/src/runtime2/component/component.rs @@ -19,9 +19,16 @@ pub enum CompScheduling { /// Generic representation of a component (as viewed by a scheduler). pub(crate) trait Component { - /// Called upon the creation of the component. + /// Called upon the creation of the component. Note that the scheduler + /// context is officially running another component (the component that is + /// creating the new component). fn on_creation(&mut self, comp_id: CompId, sched_ctx: &SchedulerCtx); + /// Called when a component crashes or wishes to exit. So is not called + /// right before destruction, other components may still hold a handle to + /// the component and send it messages! + fn on_shutdown(&mut self, sched_ctx: &SchedulerCtx); + /// Called if the component is created by another component and the messages /// are being transferred between the two. 
fn adopt_message(&mut self, comp_ctx: &mut CompCtx, message: DataMessage); @@ -229,6 +236,41 @@ pub(crate) fn default_handle_incoming_data_message( } } +/// Default handling that has been received through a `get`. Will check if any +/// more messages are waiting, and if the corresponding port was blocked because +/// of full buffers (hence, will use the control layer to make sure the peer +/// will become unblocked). +pub(crate) fn default_handle_received_data_message( + targeted_port: PortId, slot: &mut Option, inbox_backup: &mut Vec, + comp_ctx: &mut CompCtx, sched_ctx: &SchedulerCtx, control: &mut ControlLayer +) { + debug_assert!(slot.is_none()); // because we've just received from it + + // Check if there are any more messages in the backup buffer + let port_handle = comp_ctx.get_port_handle(targeted_port); + let port_info = comp_ctx.get_port(port_handle); + for message_index in 0..inbox_backup.len() { + let message = &inbox_backup[message_index]; + if message.data_header.target_port == targeted_port { + // One more message, place it in the slot + let message = inbox_backup.remove(message_index); + debug_assert!(port_info.state.is_blocked()); // since we're removing another message from the backup + *slot = Some(message); + + return; + } + } + + // Did not have any more messages, so if we were blocked, then we need to + // unblock the port now (and inform the peer of this unblocking) + if port_info.state == PortState::BlockedDueToFullBuffers { + comp_ctx.set_port_state(port_handle, PortState::Open); + let (peer_handle, message) = control.cancel_port_blocking(comp_ctx, port_handle); + let peer_info = comp_ctx.get_peer(peer_handle); + peer_info.handle.send_message(&sched_ctx.runtime, Message::Control(message), true); + } +} + /// Handles control messages in the default way. Note that this function may /// take a lot of actions in the name of the caller: pending messages may be /// sent, ports may become blocked/unblocked, etc. 
So the execution diff --git a/src/runtime2/component/component_internet.rs b/src/runtime2/component/component_internet.rs index 710de32baa0fc6e0f1c11e39b7b6bf67bf87b200..3b9c9ca77e42b135a396a2965dbfcba32f8b81cb 100644 --- a/src/runtime2/component/component_internet.rs +++ b/src/runtime2/component/component_internet.rs @@ -33,6 +33,7 @@ enum SyncState { Getting, Putting, FinishSync, + FinishSyncThenQuit, } pub struct ComponentTcpClient { @@ -81,6 +82,13 @@ impl Component for ComponentTcpClient { } } + fn on_shutdown(&mut self, sched_ctx: &SchedulerCtx) { + if let Some(poll_ticket) = self.poll_ticket.take() { + sched_ctx.polling.unregister(poll_ticket) + .expect("unregistering tcp component"); + } + } + fn adopt_message(&mut self, _comp_ctx: &mut CompCtx, message: DataMessage) { if self.inbox_main.is_none() { self.inbox_main = Some(message); @@ -111,7 +119,7 @@ impl Component for ComponentTcpClient { } fn run(&mut self, sched_ctx: &mut SchedulerCtx, comp_ctx: &mut CompCtx) -> Result { - sched_ctx.log(&format!("Running component ComponentTcpClient (mode: {:?}", self.exec_state.mode)); + sched_ctx.log(&format!("Running component ComponentTcpClient (mode: {:?}, sync state: {:?})", self.exec_state.mode, self.sync_state)); match self.exec_state.mode { CompMode::BlockedSelect => { @@ -122,10 +130,16 @@ impl Component for ComponentTcpClient { // When in non-sync mode match &mut self.socket_state { SocketState::Connected(_socket) => { - // Always move into the sync-state - self.sync_state = SyncState::AwaitingCmd; - self.consensus.notify_sync_start(comp_ctx); - self.exec_state.mode = CompMode::Sync; + if self.sync_state == SyncState::FinishSyncThenQuit { + // Previous request was to let the component shut down + self.exec_state.mode = CompMode::StartExit; + } else { + // Reset for a new request + self.sync_state = SyncState::AwaitingCmd; + self.consensus.notify_sync_start(comp_ctx); + self.exec_state.mode = CompMode::Sync; + } + return Ok(CompScheduling::Immediate); }, 
SocketState::Error => { // Could potentially send an error message to the @@ -139,9 +153,17 @@ impl Component for ComponentTcpClient { // When in sync mode: wait for a command to come in match self.sync_state { SyncState::AwaitingCmd => { - if let Some(message) = self.inbox_backup.pop() { + if let Some(message) = &self.inbox_main { + self.consensus.handle_incoming_data_message(comp_ctx, &message); if self.consensus.try_receive_data_message(sched_ctx, comp_ctx, &message) { // Check which command we're supposed to execute. + let message = self.inbox_main.take().unwrap(); + let target_port_id = message.data_header.target_port; + component::default_handle_received_data_message( + target_port_id, &mut self.inbox_main, &mut self.inbox_backup, + comp_ctx, sched_ctx, &mut self.control + ); + let (tag_value, embedded_heap_pos) = message.content.values[0].as_union(); if tag_value == self.input_union_send_tag_value { // Retrieve bytes from the message @@ -163,11 +185,14 @@ impl Component for ComponentTcpClient { return Ok(CompScheduling::Immediate); } else if tag_value == self.input_union_finish_tag_value { // Component requires us to end the sync round - let decision = self.consensus.notify_sync_end(sched_ctx, comp_ctx); - component::default_handle_sync_decision(&mut self.exec_state, decision, &mut self.consensus); + self.sync_state = SyncState::FinishSync; + return Ok(CompScheduling::Immediate); } else if tag_value == self.input_union_shutdown_tag_value { // Component wants to close the connection - todo!("implement clean shutdown, don't forget to unregister to poll ticket"); + self.sync_state = SyncState::FinishSyncThenQuit; + return Ok(CompScheduling::Immediate); + } else { + unreachable!("got tag_value {}", tag_value) } } else { todo!("handle sync failure due to message deadlock"); @@ -202,7 +227,9 @@ impl Component for ComponentTcpClient { // If here then we're done putting the data, we can // finish the sync round let decision = 
self.consensus.notify_sync_end(sched_ctx, comp_ctx); + self.exec_state.mode = CompMode::SyncEnd; component::default_handle_sync_decision(&mut self.exec_state, decision, &mut self.consensus); + return Ok(CompScheduling::Immediate); }, SyncState::Getting => { // We're going to try and receive a single message. If @@ -211,31 +238,30 @@ impl Component for ComponentTcpClient { const BUFFER_SIZE: usize = 1024; // TODO: Move to config let socket = self.socket_state.get_socket(); - debug_assert!(self.byte_buffer.is_empty()); self.byte_buffer.resize(BUFFER_SIZE, 0); match socket.receive(&mut self.byte_buffer) { Ok(num_received) => { self.byte_buffer.resize(num_received, 0); let message_content = self.bytes_to_data_message_content(&self.byte_buffer); let scheduling = component::default_send_data_message(&mut self.exec_state, self.pdl_output_port_id, message_content, sched_ctx, &mut self.consensus, comp_ctx); - self.sync_state = SyncState::FinishSync; + self.sync_state = SyncState::AwaitingCmd; return Ok(scheduling); }, Err(err) => { if err.kind() == IoErrorKind::WouldBlock { return Ok(CompScheduling::Sleep); // wait until polled } else { - todo!("handle socket.receive error {:?}", err); + todo!("handle socket.receive error {:?}", err) } } } }, - SyncState::FinishSync => { + SyncState::FinishSync | SyncState::FinishSyncThenQuit => { let decision = self.consensus.notify_sync_end(sched_ctx, comp_ctx); self.exec_state.mode = CompMode::SyncEnd; component::default_handle_sync_decision(&mut self.exec_state, decision, &mut self.consensus); return Ok(CompScheduling::Requeue); - } + }, } }, CompMode::BlockedGet => { @@ -252,8 +278,6 @@ impl Component for ComponentTcpClient { CompMode::Exit => return Ok(component::default_handle_exit(&self.exec_state)), } - - return Ok(CompScheduling::Immediate); } } @@ -280,7 +304,7 @@ impl ComponentTcpClient { let socket = SocketTcpClient::new(ip_address, port); if let Err(socket) = socket { - todo!("friendly error reporting: failed to open socket 
{:?}", socket); + todo!("friendly error reporting: failed to open socket (reason: {:?})", socket); } return Self{ diff --git a/src/runtime2/component/component_pdl.rs b/src/runtime2/component/component_pdl.rs index 338d05bde302f39e2d796e59f12173102be2b09d..dd3081ee1d12512335b33c706a690cfa21fc7612 100644 --- a/src/runtime2/component/component_pdl.rs +++ b/src/runtime2/component/component_pdl.rs @@ -227,6 +227,10 @@ impl Component for CompPDL { // Intentionally empty } + fn on_shutdown(&mut self, _sched_ctx: &SchedulerCtx) { + // Intentionally empty + } + fn adopt_message(&mut self, comp_ctx: &mut CompCtx, message: DataMessage) { let port_handle = comp_ctx.get_port_handle(message.data_header.target_port); let port_index = comp_ctx.get_port_index(port_handle); @@ -238,7 +242,7 @@ impl Component for CompPDL { } fn handle_message(&mut self, sched_ctx: &mut SchedulerCtx, comp_ctx: &mut CompCtx, mut message: Message) { - sched_ctx.log(&format!("handling message: {:#?}", message)); + // sched_ctx.log(&format!("handling message: {:?}", message)); if let Some(new_target) = self.control.should_reroute(&mut message) { let mut target = sched_ctx.runtime.get_component_public(new_target); // TODO: @NoDirectHandle target.send_message(&sched_ctx.runtime, message, false); // not waking up: we schedule once we've received all PortPeerChanged Acks @@ -313,7 +317,10 @@ impl Component for CompPDL { // Message was received. Make sure any blocked peers and // pending messages are handled. 
let message = self.inbox_main[port_index].take().unwrap(); - self.handle_received_data_message(sched_ctx, comp_ctx, port_handle); + component::default_handle_received_data_message( + port_id, &mut self.inbox_main[port_index], &mut self.inbox_backup, + comp_ctx, sched_ctx, &mut self.control + ); self.exec_ctx.stmt = ExecStmt::PerformedGet(message.content); return Ok(CompScheduling::Immediate); @@ -551,38 +558,6 @@ impl CompPDL { } } - /// Handles when a message has been handed off from the inbox to the PDL - /// code. We check to see if there are more messages waiting and, if not, - /// then we handle the case where the port might have been blocked - /// previously. - fn handle_received_data_message(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &mut CompCtx, port_handle: LocalPortHandle) { - let port_index = comp_ctx.get_port_index(port_handle); - debug_assert!(self.inbox_main[port_index].is_none()); // this function should be called after the message is taken out - - // Check for any more messages - let port_info = comp_ctx.get_port(port_handle); - for message_index in 0..self.inbox_backup.len() { - let message = &self.inbox_backup[message_index]; - if message.data_header.target_port == port_info.self_id { - // One more message for this port - let message = self.inbox_backup.remove(message_index); - debug_assert!(comp_ctx.get_port(port_handle).state.is_blocked()); // since we had >1 message on the port - self.inbox_main[port_index] = Some(message); - - return; - } - } - - // Did not have any more messages. So if we were blocked, then we need - // to send the "unblock" message. 
- if port_info.state == PortState::BlockedDueToFullBuffers { - comp_ctx.set_port_state(port_handle, PortState::Open); - let (peer_handle, message) = self.control.cancel_port_blocking(comp_ctx, port_handle); - let peer_info = comp_ctx.get_peer(peer_handle); - peer_info.handle.send_message(&sched_ctx.runtime, Message::Control(message), true); - } - } - fn handle_incoming_sync_message(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &mut CompCtx, message: SyncMessage) { let decision = self.consensus.receive_sync_message(sched_ctx, comp_ctx, message); self.handle_sync_decision(sched_ctx, comp_ctx, decision); @@ -618,13 +593,13 @@ impl CompPDL { let other_proc = &sched_ctx.runtime.protocol.heap[definition_id]; let self_proc = &sched_ctx.runtime.protocol.heap[self.prompt.frames[0].definition]; - dbg_code!({ - sched_ctx.log(&format!( - "DEBUG: Comp '{}' (ID {:?}) is creating comp '{}' (ID {:?})", - self_proc.identifier.value.as_str(), creator_ctx.id, - other_proc.identifier.value.as_str(), reservation.id() - )); - }); + // dbg_code!({ + // sched_ctx.log(&format!( + // "DEBUG: Comp '{}' (ID {:?}) is creating comp '{}' (ID {:?})", + // self_proc.identifier.value.as_str(), creator_ctx.id, + // other_proc.identifier.value.as_str(), reservation.id() + // )); + // }); // Take all the ports ID that are in the `args` (and currently belong to // the creator component) and translate them into new IDs that are diff --git a/src/runtime2/component/component_random.rs b/src/runtime2/component/component_random.rs index 1e68695ba53e3d240857ae240639883b216ea05b..dce5d709a86cf6b0e73d63bef292113aa38f691c 100644 --- a/src/runtime2/component/component_random.rs +++ b/src/runtime2/component/component_random.rs @@ -27,8 +27,9 @@ pub struct ComponentRandomU32 { } impl Component for ComponentRandomU32 { - fn on_creation(&mut self, _id: CompId, _sched_ctx: &SchedulerCtx) { - } + fn on_creation(&mut self, _id: CompId, _sched_ctx: &SchedulerCtx) {} + + fn on_shutdown(&mut self, sched_ctx: 
&SchedulerCtx) {} fn adopt_message(&mut self, _comp_ctx: &mut CompCtx, _message: DataMessage) { // Impossible since this component does not have any input ports in its diff --git a/src/runtime2/runtime.rs b/src/runtime2/runtime.rs index 80865946a82159b688217f614d29ba7a73311f6c..9da208a48a371a354cd6ec2d7aece2ff4a5fdf21 100644 --- a/src/runtime2/runtime.rs +++ b/src/runtime2/runtime.rs @@ -236,7 +236,7 @@ impl Drop for Runtime { pub(crate) struct RuntimeInner { pub protocol: ProtocolDescription, components: ComponentStore, - work_queue: Mutex>, + work_queue: Mutex>, // TODO: should be MPMC queue work_condvar: Condvar, active_elements: AtomicU32, // active components and APIs (i.e. component creators) } diff --git a/src/runtime2/scheduler.rs b/src/runtime2/scheduler.rs index 708d4c9ec6571ce57cc104abb50d84ae80f71920..d3159c82aec89f2d63ea6b11a71917be3a285417 100644 --- a/src/runtime2/scheduler.rs +++ b/src/runtime2/scheduler.rs @@ -75,7 +75,10 @@ impl Scheduler { CompScheduling::Immediate => unreachable!(), CompScheduling::Requeue => { self.runtime.enqueue_work(comp_key); }, CompScheduling::Sleep => { self.mark_component_as_sleeping(comp_key, component); }, - CompScheduling::Exit => { self.mark_component_as_exiting(&scheduler_ctx, component); } + CompScheduling::Exit => { + component.component.on_shutdown(&scheduler_ctx); + self.mark_component_as_exiting(&scheduler_ctx, component); + } } } } diff --git a/src/runtime2/tests/mod.rs b/src/runtime2/tests/mod.rs index b795bc4b6387425f1b393d2029784e0561618d45..fb47364e1adef4283a1a5f4961be67cd7bead285 100644 --- a/src/runtime2/tests/mod.rs +++ b/src/runtime2/tests/mod.rs @@ -247,3 +247,133 @@ fn test_random_u32_temporary_thingo() { let rt = Runtime::new(1, true, pd).unwrap(); create_component(&rt, "", "constructor", no_args()); } + +#[test] +fn test_tcp_socket_http_request() { + let _pd = ProtocolDescription::parse(b" + import std.internet::*; + + primitive requester(out cmd_tx, in data_rx) { + print(\"*** TCPSocket: Sending 
request\"); + sync { + put(cmd_tx, Cmd::Send(b\"GET / HTTP/1.1\\r\\n\\r\\n\")); + } + + print(\"*** TCPSocket: Receiving response\"); + auto buffer = {}; + auto done_receiving = false; + sync while (!done_receiving) { + put(cmd_tx, Cmd::Receive); + auto data = get(data_rx); + buffer @= data; + + // Completely crap detection of end-of-document. But here we go, we + // try to detect the trailing </html>. Proper way would be to parse + // for 'content-length' or 'content-encoding' + s32 index = 0; + s32 partial_length = cast(length(data) - 7); + while (index < partial_length) { + // No string conversion yet, so check byte buffer one byte at + // a time. + auto c1 = data[index]; + if (c1 == cast('<')) { + auto c2 = data[index + 1]; + auto c3 = data[index + 2]; + auto c4 = data[index + 3]; + auto c5 = data[index + 4]; + auto c6 = data[index + 5]; + auto c7 = data[index + 6]; + if ( // i.e. if (data[index..] == '</html>' + c2 == cast('/') && c3 == cast('h') && c4 == cast('t') && + c5 == cast('m') && c6 == cast('l') && c7 == cast('>') + ) { + print(\"*** TCPSocket: Detected </html>\"); + put(cmd_tx, Cmd::Finish); + done_receiving = true; + } + } + index += 1; + } + } + + print(\"*** TCPSocket: Requesting shutdown\"); + sync { + put(cmd_tx, Cmd::Shutdown); + } + } + + composite main() { + channel cmd_tx -> cmd_rx; + channel data_tx -> data_rx; + new tcp_client({142, 250, 179, 163}, 80, cmd_rx, data_tx); // port 80 of google + new requester(cmd_tx, data_rx); + } + ").expect("compilation"); + + // This test is disabled because it performs a HTTP request to google. 
+ // let rt = Runtime::new(1, true, pd).unwrap(); + // create_component(&rt, "", "main", no_args()); +} + +#[test] +fn test_sending_receiving_union() { + let pd = ProtocolDescription::parse(b" + union Cmd { + Set(u8[]), + Get, + Shutdown, + } + + primitive database(in rx, out tx) { + auto stored = {}; + auto done = false; + while (!done) { + sync { + auto command = get(rx); + if (let Cmd::Set(bytes) = command) { + print(\"database: storing value\"); + stored = bytes; + } else if (let Cmd::Get = command) { + print(\"database: returning value\"); + put(tx, stored); + } else if (let Cmd::Shutdown = command) { + print(\"database: shutting down\"); + done = true; + } else while (true) print(\"impossible\"); // no other case possible + } + } + } + + primitive client(out tx, in rx, u32 num_rounds) { + auto round = 0; + while (round < num_rounds) { + auto set_value = b\"hello there\"; + print(\"client: putting a value\"); + sync put(tx, Cmd::Set(set_value)); + + auto retrieved = {}; + print(\"client: retrieving what was sent\"); + sync { + put(tx, Cmd::Get); + retrieved = get(rx); + } + + if (set_value != retrieved) while (true) print(\"wrong!\"); + + round += 1; + } + + sync put(tx, Cmd::Shutdown); + } + + composite main() { + auto num_rounds = 5; + channel cmd_tx -> cmd_rx; + channel data_tx -> data_rx; + new database(cmd_rx, data_tx); + new client(cmd_tx, data_rx, num_rounds); + } + ").expect("compilation"); + let rt = Runtime::new(1, false, pd).unwrap(); + create_component(&rt, "", "main", no_args()); +} \ No newline at end of file