diff --git a/ast/ast.go b/ast/ast.go
index ac2be69..b4b3374 100644
--- a/ast/ast.go
+++ b/ast/ast.go
@@ -64,6 +64,13 @@ type PrefixExpression struct {
 	Right Expression
 }
 
+type InfixExpression struct {
+	Token    token.Token // operator token, e.g. +, -, *, /
+	Left     Expression
+	Operator string
+	Right    Expression
+}
+
 // Let Statements
 func (ls *LetStatement) statement_node() {}
 
@@ -161,3 +168,17 @@ func (pe *PrefixExpression) String() string {
 
 	return out.String()
 }
+
+// Infix Expression
+func (ie *InfixExpression) expression_node()     {}
+func (ie *InfixExpression) TokenLiteral() string { return ie.Token.Literal }
+func (ie *InfixExpression) String() string {
+	var out bytes.Buffer
+	out.WriteString("(")
+	out.WriteString(ie.Left.String())
+	out.WriteString(" " + ie.Operator + " ")
+	out.WriteString(ie.Right.String())
+	out.WriteString(")")
+
+	return out.String()
+}
diff --git a/parser/parser.go b/parser/parser.go
index 1609751..4ab4d13 100644
--- a/parser/parser.go
+++ b/parser/parser.go
@@ -8,7 +8,7 @@ import (
 	"strconv"
 )
 
-// PRECEDENCE of operations
+// Precedence of operations
 const (
 	_ int = iota // iota means start from 0, hence _ starts from 0
 	LOWEST
@@ -20,6 +20,32 @@ const (
 	CALL // simple_function(x)
 )
 
+// precedences maps each infix operator token type to its precedence level.
+var precedences = map[token.TokenType]int{
+	token.EQ:       EQUALS,
+	token.NOT_EQ:   EQUALS,
+	token.LT:       LESSGREATER,
+	token.GT:       LESSGREATER,
+	token.PLUS:     SUM,
+	token.MINUS:    SUM,
+	token.SLASH:    PRODUCT,
+	token.ASTERISK: PRODUCT,
+}
+
+func (l_parser *Parser) peek_precedence() int {
+	if precedence, ok := precedences[l_parser.peek_token.Type]; ok {
+		return precedence
+	}
+	return LOWEST
+}
+
+func (l_parser *Parser) current_precedence() int {
+	if precedence, ok := precedences[l_parser.current_token.Type]; ok {
+		return precedence
+	}
+	return LOWEST
+}
+
 // Pratt Parsing
 type (
 	prefix_parse_function func() ast.Expression
@@ -44,12 +70,24 @@ func New(l_lexer *lexer.Lexer) *Parser {
 	l_parser.next_token()
 	l_parser.next_token()
 
+	// Prefix Operations
 	l_parser.prefix_parse_functions = make(map[token.TokenType]prefix_parse_function)
 	l_parser.register_prefix(token.IDENT, l_parser.parse_identifier)
 	l_parser.register_prefix(token.INT, l_parser.parse_integer_literal)
 	l_parser.register_prefix(token.BANG, l_parser.parse_prefix_expression)
 	l_parser.register_prefix(token.MINUS, l_parser.parse_prefix_expression)
 
+	// Infix Operations
+	l_parser.infix_parse_functions = make(map[token.TokenType]infix_parse_function)
+	l_parser.register_infix(token.PLUS, l_parser.parse_infix_expression)
+	l_parser.register_infix(token.MINUS, l_parser.parse_infix_expression)
+	l_parser.register_infix(token.SLASH, l_parser.parse_infix_expression)
+	l_parser.register_infix(token.ASTERISK, l_parser.parse_infix_expression)
+	l_parser.register_infix(token.EQ, l_parser.parse_infix_expression)
+	l_parser.register_infix(token.NOT_EQ, l_parser.parse_infix_expression)
+	l_parser.register_infix(token.LT, l_parser.parse_infix_expression)
+	l_parser.register_infix(token.GT, l_parser.parse_infix_expression)
+
 	return l_parser
 }
 
@@ -172,6 +210,7 @@ func (l_parser *Parser) parse_integer_literal() ast.Expression {
 	return literal
 }
 
+// Here lies the heart of Pratt Parsing
 func (l_parser *Parser) parse_expression(precedence int) ast.Expression {
 	prefix := l_parser.prefix_parse_functions[l_parser.current_token.Type]
 	if prefix == nil {
@@ -179,6 +218,16 @@ func (l_parser *Parser) parse_expression(precedence int) ast.Expression {
 		return nil
 	}
 	left_expression := prefix()
+
+	for !l_parser.peek_token_is(token.SEMICOLON) && precedence < l_parser.peek_precedence() {
+		infix := l_parser.infix_parse_functions[l_parser.peek_token.Type]
+		if infix == nil {
+			return left_expression
+		}
+		l_parser.next_token()
+		left_expression = infix(left_expression)
+	}
+
 	return left_expression
 }
 
@@ -196,3 +245,16 @@ func (l_parser *Parser) parse_prefix_expression() ast.Expression {
 	expression.Right = l_parser.parse_expression(PREFIX)
 	return expression
 }
+
+func (l_parser *Parser) parse_infix_expression(left ast.Expression) ast.Expression {
+	expression := &ast.InfixExpression{
+		Token:    l_parser.current_token,
+		Operator: l_parser.current_token.Literal,
+		Left:     left,
+	}
+	precedence := l_parser.current_precedence()
+	l_parser.next_token()
+	expression.Right = l_parser.parse_expression(precedence)
+
+	return expression
+}
diff --git a/parser/parser_test.go b/parser/parser_test.go
index 337a72b..ce1103e 100644
--- a/parser/parser_test.go
+++ b/parser/parser_test.go
@@ -235,3 +235,121 @@ func testIntegerLiteral(l_test *testing.T, il ast.Expression, value int64) bool
 	}
 	return true
 }
+
+func TestParsingInfixExpressions(l_test *testing.T) {
+	infix_tests := []struct {
+		input       string
+		left_value  int64
+		operator    string
+		right_value int64
+	}{
+		{"5 + 5;", 5, "+", 5},
+		{"5 - 5;", 5, "-", 5},
+		{"5 * 5;", 5, "*", 5},
+		{"5 / 5;", 5, "/", 5},
+		{"5 > 5;", 5, ">", 5},
+		{"5 < 5;", 5, "<", 5},
+		{"5 == 5;", 5, "==", 5},
+		{"5 != 5;", 5, "!=", 5},
+	}
+
+	for _, tt := range infix_tests {
+		l_lexer := lexer.New(tt.input)
+		l_parser := New(l_lexer)
+		program := l_parser.ParseProgram()
+		check_parser_errors(l_test, l_parser)
+
+		if len(program.Statements) != 1 {
+			l_test.Fatalf("program.Statements does not contain %d statements, got=%d\n", 1, len(program.Statements))
+		}
+
+		statement, ok := program.Statements[0].(*ast.ExpressionStatement)
+		if !ok {
+			l_test.Fatalf("program.Statements[0] is not ast.ExpressionStatement, got=%T", program.Statements[0])
+		}
+
+		expression, ok := statement.Expression.(*ast.InfixExpression)
+		if !ok {
+			l_test.Fatalf("expression is not ast.InfixExpression, got=%T", statement.Expression)
+		}
+
+		if !testIntegerLiteral(l_test, expression.Left, tt.left_value) {
+			return
+		}
+
+		if expression.Operator != tt.operator {
+			l_test.Fatalf("expression.Operator is not '%s', got=%s", tt.operator, expression.Operator)
+		}
+
+		if !testIntegerLiteral(l_test, expression.Right, tt.right_value) {
+			return
+		}
+	}
+}
+
+func TestOperatorPrecedenceParsing(l_test *testing.T) {
+	tests := []struct {
+		input    string
+		expected string
+	}{
+		{
+			"-a * b",
+			"((-a) * b)",
+		},
+		{
+			"!-a",
+			"(!(-a))",
+		},
+		{
+			"a + b + c",
+			"((a + b) + c)",
+		},
+		{
+			"a + b - c",
+			"((a + b) - c)",
+		},
+		{
+			"a * b * c",
+			"((a * b) * c)",
+		},
+		{
+			"a * b / c",
+			"((a * b) / c)",
+		},
+		{
+			"a + b / c",
+			"(a + (b / c))",
+		},
+		{
+			"a + b * c + d / e - f",
+			"(((a + (b * c)) + (d / e)) - f)",
+		},
+		{
+			"3 + 4; -5 * 5",
+			"(3 + 4)((-5) * 5)",
+		},
+		{
+			"5 > 4 == 3 < 4",
+			"((5 > 4) == (3 < 4))",
+		},
+		{
+			"5 < 4 != 3 > 4",
+			"((5 < 4) != (3 > 4))",
+		},
+		{
+			"3 + 4 * 5 == 3 * 1 + 4 * 5",
+			"((3 + (4 * 5)) == ((3 * 1) + (4 * 5)))",
+		},
+	}
+	for _, tt := range tests {
+		l_lexer := lexer.New(tt.input)
+		l_parser := New(l_lexer)
+		program := l_parser.ParseProgram()
+		check_parser_errors(l_test, l_parser)
+
+		actual := program.String()
+		if actual != tt.expected {
+			l_test.Errorf("expected=%q, got=%q", tt.expected, actual)
+		}
+	}
+}