diff --git a/evaluator/evaluator.go b/evaluator/evaluator.go index a62348f..42812cd 100644 --- a/evaluator/evaluator.go +++ b/evaluator/evaluator.go @@ -69,6 +69,22 @@ func Eval(node ast.Node, env *object.Environment) object.Object { case *ast.Identifier: return eval_identifier(node, env) + + case *ast.FunctionLiteral: + params := node.Parameters + body := node.Body + return &object.Function{Parameters: params, Env: env, Body: body} + + case *ast.CallExpression: + function := Eval(node.Function, env) + if is_error(function) { + return function + } + args := eval_expression(node.Arguments, env) + if len(args) == 1 && is_error(args[0]) { + return args[0] + } + return apply_function(function, args) } return nil @@ -91,6 +107,46 @@ func eval_program(program *ast.Program, env *object.Environment) object.Object { return result } +func apply_function(fn object.Object, args []object.Object) object.Object { + function, ok := fn.(*object.Function) + if !ok { + return new_error("not a function: %s", fn.Type()) + } + + extended_env := extend_function_env(function, args) + evaluated := Eval(function.Body, extended_env) + return unwrap_return_value(evaluated) +} + +func extend_function_env(fn *object.Function, args []object.Object) *object.Environment { + env := object.NewEnclosedEnvironment(fn.Env) + + for param_index, param := range fn.Parameters { + env.Set(param.Value, args[param_index]) + } + return env +} + +func unwrap_return_value(obj object.Object) object.Object { + if return_value, ok := obj.(*object.ReturnValue); ok { + return return_value.Value + } + return obj +} + +func eval_expression(expressions []ast.Expression, env *object.Environment) []object.Object { + var result []object.Object + + for _, e := range expressions { + evaluated := Eval(e, env) + if is_error(evaluated) { + return []object.Object{evaluated} + } + result = append(result, evaluated) + } + return result +} + func eval_identifier(node *ast.Identifier, env *object.Environment) object.Object { val, ok := env.Get(node.Value) if !ok { diff --git a/evaluator/evaluator_test.go b/evaluator/evaluator_test.go index 80e0f51..a34568f 100644 --- a/evaluator/evaluator_test.go +++ b/evaluator/evaluator_test.go @@ -184,6 +184,59 @@ func TestLetStatements(l_test *testing.T) { } } +func TestFunctionObject(l_test *testing.T) { + input := "fn(x) { x + 2;};" + evaluated := test_eval(input) + + fn, ok := evaluated.(*object.Function) + if !ok { + l_test.Fatalf("object is not Function, got=%T (%+v)", evaluated, evaluated) + } + + if len(fn.Parameters) != 1 { + l_test.Fatalf("function has wrong parameters, Parameters=%+v", fn.Parameters) + } + + if fn.Parameters[0].String() != "x" { + l_test.Fatalf("parameter is not 'x', got=%q", fn.Parameters[0]) + } + + expected_body := "(x + 2)" + + if fn.Body.String() != expected_body { + l_test.Fatalf("body is not %q, got=%q", expected_body, fn.Body.String()) + } +} + +func TestFunctionApplication(l_test *testing.T) { + tests := []struct { + input string + expected int64 + }{ + {"let identity = fn(x) { x; }; identity(5);", 5}, + {"let identity = fn(x) { return x; }; identity(5);", 5}, + {"let double = fn(x) { x * 2; }; double(5);", 10}, + {"let add = fn(x, y) { x + y; }; add(5, 5);", 10}, + {"let add = fn(x, y) { x + y; }; add(5 + 5, add(5, 5));", 20}, + {"fn(x) { x; }(5)", 5}, + } + + for _, tt := range tests { + test_integer_object(l_test, test_eval(tt.input), tt.expected) + } +} + +func TestClosures(l_test *testing.T) { + input := ` + let newAdder = fn(x) { + fn(y) { x + y }; + }; + let addTwo = newAdder(2); + addTwo(2); +` + test_integer_object(l_test, test_eval(input), 4) +} + // Helpers func test_eval(input string) object.Object { l_lexer := lexer.New(input) diff --git a/object/environment.go b/object/environment.go index 9b89409..851ba5b 100644 --- a/object/environment.go +++ b/object/environment.go @@ -9,15 +9,53 @@ package object type Environment struct { store map[string]Object + outer *Environment } func NewEnvironment() *Environment { s := make(map[string]Object) - return &Environment{store: s} + return &Environment{store: s, outer: nil} +} + +/* + Enclosing Environments + + Here is a problem case, lets say in monkey I would want to type this: + + ``` + let i = 5; + let print_num = fn(i) { + puts(i); + } + + print_num(10); + puts(i); + ``` + + The ideal result of the above code in the monkey programming language is for 10 and 5 to be the outputs respectively. + In a situation where enclosed environment does not exists, both outputs will be 10 because the current value of i + would be overwritten. The ideal situation would be to preserve the previous binding to 'i' while also making a a new + one. + + This works be creating a new instance of object.Environment with a pointer to the environment it should extend, doing this + encloses a fresh and empty environment with an existing one. When the Get method is called and it itself doesn't have the value + associated with the given name, it calls the Get of the enclosing environment. That's the environment it's extending. If that + enclosing environment can't find the value, it calls its own enclosing environment and so on until there is no enclosing environment + anymore and it will error out to an unknown identifier. +*/ +func NewEnclosedEnvironment(outer *Environment) *Environment { + env := NewEnvironment() + env.outer = outer + return env } func (l_environment *Environment) Get(name string) (Object, bool) { obj, ok := l_environment.store[name] + + if !ok && l_environment.outer != nil { + obj, ok = l_environment.outer.Get(name) + } + return obj, ok } diff --git a/object/object.go b/object/object.go index d894232..171c83d 100644 --- a/object/object.go +++ b/object/object.go @@ -1,6 +1,11 @@ package object -import "fmt" +import ( + "bytes" + "fmt" + "monkey/ast" + "strings" +) type ObjectType string @@ -10,6 +15,7 @@ const ( NULL_OBJECT = "NULL" RETURN_VALUE_OBJECT = "RETURN_VALUE" ERROR_OBJECT = "ERROR" + FUNCTION_OBJECT = "FUNCTION" ) type Object interface { @@ -69,3 +75,29 @@ type Error struct { func (err *Error) Type() ObjectType { return ERROR_OBJECT } func (err *Error) Inspect() string { return "ERROR: " + err.Message } + +// Function +type Function struct { + Parameters []*ast.Identifier + Body *ast.BlockStatement + Env *Environment +} + +func (f *Function) Type() ObjectType { return FUNCTION_OBJECT } +func (f *Function) Inspect() string { + var out bytes.Buffer + + params := []string{} + for _, p := range f.Parameters { + params = append(params, p.String()) + } + + out.WriteString("fn") + out.WriteString("(") + out.WriteString(strings.Join(params, ", ")) + out.WriteString(") {\n") + out.WriteString(f.Body.String()) + out.WriteString("\n}") + + return out.String() +}