diff --git a/ast/ast.go b/ast/ast.go index 69d37af..612849f 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -26,11 +26,16 @@ type Identifier struct { } type LetStatement struct { - Token token.Token // the token.LET token + Token token.Token // token.LET token Name *Identifier Value Expression } +type ReturnStatement struct { + Token token.Token // token.RETURN token + ReturnValue Expression +} + func (ls *LetStatement) statement_node() {} func (ls *LetStatement) TokenLiteral() string { @@ -50,3 +55,9 @@ func (p *Program) TokenLiteral() string { return "" } } + +func (rs *ReturnStatement) statement_node() {} + +func (rs *ReturnStatement) TokenLiteral() string { + return rs.Token.Literal +} diff --git a/parser/parser.go b/parser/parser.go index 38cfee3..a9f1c98 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -33,7 +33,7 @@ func (l_parser *Parser) ParseProgram() *ast.Program { program := &ast.Program{} program.Statements = []ast.Statement{} - for !l_parser.current_token_is(token.EOF){ + for !l_parser.current_token_is(token.EOF) { statement := l_parser.parse_statement() if statement != nil { program.Statements = append(program.Statements, statement) @@ -52,6 +52,8 @@ func (l_parser *Parser) parse_statement() ast.Statement { switch l_parser.current_token.Type { case token.LET: return l_parser.parse_let_statement() + case token.RETURN: + return l_parser.parse_return_statement() default: return nil } @@ -75,6 +77,17 @@ func (l_parser *Parser) parse_let_statement() *ast.LetStatement { return statement } +func (l_parser *Parser) parse_return_statement() *ast.ReturnStatement { + statement := &ast.ReturnStatement{Token: l_parser.current_token} + l_parser.next_token() + + // TODO(tijani): Skipping the expression until there is semicolon + for !l_parser.current_token_is(token.SEMICOLON) { + l_parser.next_token() + } + return statement +} + func (l_parser *Parser) expect_peek(l_token token.TokenType) bool { if l_parser.peek_token_is(l_token) { l_parser.next_token() diff --git a/parser/parser_test.go b/parser/parser_test.go index 5cbc3de..0e81f0c 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -6,6 +6,20 @@ import ( "testing" ) +func check_parser_errors(l_test *testing.T, l_parser *Parser) { + errors := l_parser.Errors() + if len(errors) == 0 { + return + } + + l_test.Errorf("parser has %d errors", len(errors)) + + for _, message := range errors { + l_test.Errorf("parser error: %q", message) + } + l_test.FailNow() +} + func TestLetStatement(l_test *testing.T) { input := ` let x = 4; @@ -68,16 +82,33 @@ func testLetStatement(l_test *testing.T, statement ast.Statement, name string) b return true } -func check_parser_errors(l_test *testing.T, l_parser *Parser) { - errors := l_parser.Errors() - if len(errors) == 0 { - return +func TestReturnStatement(l_test *testing.T) { + input := ` + return 6; + return 10; + return 8419849; + ` + + l_lexer := lexer.New(input) + l_parser := New(l_lexer) + + program := l_parser.ParseProgram() + check_parser_errors(l_test, l_parser) + + if len(program.Statements) != 3 { + l_test.Fatalf("program.Statements does not contain 3 statements, got=%d", len(program.Statements)) } - l_test.Errorf("parser has %d errors", len(errors)) - - for _, message := range errors { - l_test.Errorf("parser error: %q", message) + for _, statement := range program.Statements { + return_statement, ok := statement.(*ast.ReturnStatement) + + if !ok { + l_test.Errorf("statment not *ast.ReturnStatement, got =%T", statement) + continue + } + + if return_statement.TokenLiteral() != "return" { + l_test.Errorf("return_statement.TokenLiteral() not 'return', got %q", return_statement.TokenLiteral()) + } } - l_test.FailNow() }