package scratch;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.Function;

import org.junit.jupiter.api.Test;

public class Continuations {

	// Here's a simple AST for expressions with constants, +, and *
	/**
	 * Root of the expression AST. Concrete nodes pass the visitor the results of
	 * visiting their children, so a visitor describes how to combine results rather
	 * than how to walk the tree.
	 *
	 * @param <T> type of the constants stored in the tree
	 */
	static abstract class Expression<T> {
		/**
		 * Folds this expression with the given visitor.
		 *
		 * @param <R>     result type produced by the visitor
		 * @param visitor algebra used to combine the results of child nodes
		 * @return the visitor's result for this node
		 */
		public abstract <R> R visit(SimpleExpressionVisitor<T, R> visitor);
	}

	static class Constant<T> extends Expression<T> {
		T value;

		public Constant(T value) {
			this.value = value;
		}

		public <V> V visit(SimpleExpressionVisitor<T, V> visitor) {
			return visitor.visit(this, value);
		}
	}

	static class Plus<T> extends Expression<T> {
		Expression<T> left, right;

		public Plus(Expression<T> left, Expression<T> right) {
			this.left = left;
			this.right = right;
		}

		public <V> V visit(SimpleExpressionVisitor<T, V> visitor) {
			return visitor.visit(this, left.visit(visitor), right.visit(visitor));
		}
	}

	static class Times<T> extends Expression<T> {
		Expression<T> left, right;

		public Times(Expression<T> left, Expression<T> right) {
			this.left = left;
			this.right = right;
		}

		public <V> V visit(SimpleExpressionVisitor<T, V> visitor) {
			return visitor.visit(this, left.visit(visitor), right.visit(visitor));
		}
	}

	// We create a simple visitor for them
	private static interface SimpleExpressionVisitor<T, V> {
		V visit(Constant<T> constant, T value);

		V visit(Plus<T> plus, V left, V right);

		V visit(Times<T> times, V left, V right);
	}

	// We can evaluate them as integers
	/** Evaluates an integer expression directly: "+" adds and "*" multiplies. */
	static class IntegerAlgebra implements SimpleExpressionVisitor<Integer, Integer> {
		@Override
		public Integer visit(Constant<Integer> constant, Integer value) {
			// Tracing hook: System.out.println("visit constant");
			return value;
		}

		@Override
		public Integer visit(Plus<Integer> plus, Integer left, Integer right) {
			// Tracing hook: System.out.println("visit plus");
			return Integer.sum(left, right);
		}

		@Override
		public Integer visit(Times<Integer> times, Integer left, Integer right) {
			// Tracing hook: System.out.println("visit times");
			int product = left * right;
			return product;
		}
	}

	// Or evaluate them as booleans
	private class BooleanAlgebra implements SimpleExpressionVisitor<Boolean, Boolean> {
		@Override
		public Boolean visit(Constant<Boolean> constant, Boolean value) {
			return value;
		}

		@Override
		public Boolean visit(Plus<Boolean> plus, Boolean left, Boolean right) {
			return left || right;
		}

		@Override
		public Boolean visit(Times<Boolean> times, Boolean left, Boolean right) {
			return left && right;
		}
	}

	@Test
	void testSimple() {
		// We can create and traverse expressions of different types
		Expression<Integer> integerExpression = new Plus<>(new Times<>(new Constant<>(5), new Constant<>(7)),
				new Constant<>(3));
		Integer integerResult = integerExpression.visit(new IntegerAlgebra());
		System.out.println(integerResult);
		assertEquals(38, integerResult); // (5 * 7) + 3

		Expression<Boolean> booleanExpression = new Plus<>(new Times<>(new Constant<>(true), new Constant<>(false)),
				new Constant<>(true));
		Boolean booleanResult = booleanExpression.visit(new BooleanAlgebra());
		System.out.println(booleanResult);
		assertEquals(Boolean.TRUE, booleanResult); // (true && false) || true
	}

	// we can also do other things using the generic traversal
	/**
	 * Renders an expression as a fully parenthesised infix string by concatenating
	 * the strings produced for the children.
	 */
	static class SimpleToString<T> implements SimpleExpressionVisitor<T, String> {
		/** Formats a binary node as "(left op right)"; {@code operator} includes its spaces. */
		private static String infix(String left, String operator, String right) {
			return "(" + left + operator + right + ")";
		}

		@Override
		public String visit(Constant<T> constant, T value) {
			return String.valueOf(value);
		}

		@Override
		public String visit(Plus<T> plus, String left, String right) {
			return infix(left, " + ", right);
		}

		@Override
		public String visit(Times<T> times, String left, String right) {
			return infix(left, " * ", right);
		}
	}

	@Test
	void testSimpleToString() {
		// We can traverse expressions in different ways
		Expression<Integer> integerExpression = new Plus<>(new Times<>(new Constant<>(5), new Constant<>(7)),
				new Constant<>(3));
		String integerString = integerExpression.visit(new SimpleToString<>());
		System.out.println(integerString);
		assertEquals("((5 * 7) + 3)", integerString);

		Expression<Boolean> booleanExpression = new Plus<>(new Times<>(new Constant<>(true), new Constant<>(false)),
				new Constant<>(true));
		String booleanString = booleanExpression.visit(new SimpleToString<>());
		System.out.println(booleanString);
		assertEquals("((true * false) + true)", booleanString);
	}

	// We can also create a different implementation of ToString using a
	// continuation
	/**
	 * Renders an expression by composing StringBuilder transformers. Visiting only
	 * builds the chain of functions. Text is produced when the result is applied to
	 * a builder, and every node appends into that single builder instead of
	 * concatenating intermediate strings.
	 */
	static class FancyToString<T>
			implements SimpleExpressionVisitor<T, Function<StringBuilder, StringBuilder>> {

		/** Builds the transformer that writes "(left op right)"; {@code operator} includes its spaces. */
		private static Function<StringBuilder, StringBuilder> infix(Function<StringBuilder, StringBuilder> left,
				String operator, Function<StringBuilder, StringBuilder> right) {
			return builder -> {
				builder.append('(');
				StringBuilder afterLeft = left.apply(builder);
				afterLeft.append(operator);
				StringBuilder afterRight = right.apply(afterLeft);
				afterRight.append(')');
				return afterRight;
			};
		}

		@Override
		public Function<StringBuilder, StringBuilder> visit(Constant<T> constant, T value) {
			return builder -> builder.append(value);
		}

		@Override
		public Function<StringBuilder, StringBuilder> visit(Plus<T> plus, Function<StringBuilder, StringBuilder> left,
				Function<StringBuilder, StringBuilder> right) {
			return infix(left, " + ", right);
		}

		@Override
		public Function<StringBuilder, StringBuilder> visit(Times<T> times, Function<StringBuilder, StringBuilder> left,
				Function<StringBuilder, StringBuilder> right) {
			return infix(left, " * ", right);
		}
	}

	@Test
	void testFancyToString() {
		// What's the point of the fancy version? (see testCompareToStrings)
		Expression<Integer> integerExpression = new Plus<>(new Times<>(new Constant<>(5), new Constant<>(7)),
				new Constant<>(3));
		String integerString = integerExpression.visit(new FancyToString<>()).apply(new StringBuilder()).toString();
		System.out.println(integerString);
		assertEquals("((5 * 7) + 3)", integerString);

		Expression<Boolean> booleanExpression = new Plus<>(new Times<>(new Constant<>(true), new Constant<>(false)),
				new Constant<>(true));
		String booleanString = booleanExpression.visit(new FancyToString<>()).apply(new StringBuilder()).toString();
		System.out.println(booleanString);
		assertEquals("((true * false) + true)", booleanString);
	}

	/**
	 * Builds the right-nested sum {@code (size-1) + ((size-2) + (... + (1 + 0)))}.
	 * Useful for producing deep trees. For {@code size <= 1} the result is the
	 * single constant 0.
	 */
	Expression<Integer> build(int size) {
		Expression<Integer> chain = new Constant<>(0);
		int next = 1;
		while (next < size) {
			chain = new Plus<>(new Constant<>(next), chain);
			next++;
		}
		return chain;
	}

	@Test
	void testCompareToStrings() {
		// What's the point of the fancy version? SimpleToString concatenates the
		// children's strings at every node, so each level copies everything below it
		// again. FancyToString appends every piece exactly once into one shared
		// StringBuilder. (You probably cannot increase the parameter to build without
		// also increasing the Java stack size, since both visitors recurse over the
		// tree.)
		Expression<Integer> integerExpression = build(4000);

		long before = System.nanoTime();
		String simple = integerExpression.visit(new SimpleToString<>());
		System.out.println("simple: " + (System.nanoTime() - before) / 1_000_000);

		before = System.nanoTime();
		String fancy = integerExpression.visit(new FancyToString<>()).apply(new StringBuilder()).toString();
		System.out.println("fancy: " + (System.nanoTime() - before) / 1_000_000);

		// Both renderers must agree; using the results also keeps the work from being
		// optimised away.
		assertEquals(simple, fancy);
	}

	// We can build functions to achieve various goals, e.g., a map:
	/**
	 * A persistent map represented purely as a function from keys to values. Every
	 * update returns a new function that wraps the previous one, so older versions
	 * remain valid and unchanged. A lookup walks back through the wrapped updates,
	 * so its cost grows with the number of put/remove calls.
	 */
	static class FunctionMap {
		/**
		 * Returns a map with no entries.
		 *
		 * @return a function that throws {@link IllegalArgumentException} for every key
		 */
		public static <K, V> Function<K, V> empty() {
			return k -> {
				throw new IllegalArgumentException("No value bound for key: " + k);
			};
		}

		/**
		 * Returns a map that yields {@code value} for {@code key} and otherwise
		 * delegates to {@code map}. {@code map} itself is not modified.
		 */
		public static <K, V> Function<K, V> put(Function<K, V> map, K key, V value) {
			return k -> Objects.equals(key, k) ? value : map.apply(k);
		}

		/**
		 * Returns a map in which {@code key} is unbound; all other keys are delegated
		 * to {@code map}. Looking up {@code key} throws {@link IllegalArgumentException}.
		 */
		public static <K, V> Function<K, V> remove(Function<K, V> map, K key) {
			return k -> {
				if (Objects.equals(key, k)) {
					throw new IllegalArgumentException("Key has been removed: " + k);
				} else {
					return map.apply(k);
				}
			};
		}
	}

	@Test
	void testMap() {
		// A FunctionMap behaves like a Map such as HashMap (just less efficiently for large maps)
		Function<String, Integer> empty = FunctionMap.empty();
		assertThrows(IllegalArgumentException.class, () -> empty.apply("hello"));

		Function<String, Integer> withHello = FunctionMap.put(empty, "hello", 7);
		assertEquals(7, withHello.apply("hello"));

		Function<String, Integer> withWorld = FunctionMap.put(withHello, "world", 5);
		assertEquals(7, withWorld.apply("hello"));
		assertEquals(5, withWorld.apply("world"));

		// Something a HashMap cannot do: after an update, the old version is still intact
		Function<String, Integer> rebound = FunctionMap.put(withWorld, "hello", 3);
		assertEquals(3, rebound.apply("hello"));
		assertEquals(5, rebound.apply("world"));
		assertEquals(7, withWorld.apply("hello"));
		assertEquals(5, withWorld.apply("world"));

		// Cost trade-off: a HashMap inserts and looks up in constant time but copies
		// in linear time, whereas a FunctionMap inserts and "copies" in constant time
		// but looks up in time linear in the number of updates.
	}

	/** Pair of two values whose references cannot be reassigned after construction. */
	static class Pair<A, B> {
		private final A a;
		private final B b;

		public Pair(A first, B second) {
			a = first;
			b = second;
		}

		/** @return the first component */
		public A getA() {
			return this.a;
		}

		/** @return the second component */
		public B getB() {
			return this.b;
		}
	}

	/**
	 * Node of a persistent singly linked list: {@code getA()} is the rest of the
	 * list and {@code getB()} is the element stored in this node.
	 */
	static class FunctionalList<E> extends Pair<FunctionalList<E>, E> {
		public FunctionalList(FunctionalList<E> rest, E element) {
			super(rest, element);
		}
	}

	static class FunctionList {
		public static <E> FunctionalList<E> empty() {
			return new FunctionalList<E>(null, null) {
				public FunctionalList<E> getA() {
					throw new IllegalArgumentException();
				}

				public E getB() {
					throw new IllegalArgumentException();
				}
			};
		}

		public static <E> FunctionalList<E> add(FunctionalList<E> list, E element) {
			return new FunctionalList<E>(list, element);
		}

		public static <E> Pair<FunctionalList<E>, E> remove(FunctionalList<E> list) {
			return new Pair<FunctionalList<E>, E>(list.getA(), list.getB());
		}
	}

	@Test
	void testList() {
		FunctionalList<Integer> empty = FunctionList.empty();

		// A single element comes straight back out
		FunctionalList<Integer> justThree = FunctionList.add(empty, 3);
		assertEquals(3, FunctionList.remove(justThree).getB());

		// Removing yields the top element and the rest of the list
		FunctionalList<Integer> five = FunctionList.add(empty, 5);
		FunctionalList<Integer> fiveSeven = FunctionList.add(five, 7);
		Pair<FunctionalList<Integer>, Integer> popped = FunctionList.remove(fiveSeven);
		assertEquals(7, popped.getB());
		assertEquals(5, FunctionList.remove(popped.getA()).getB());
		assertEquals(5, FunctionList.remove(five).getB());

		// Branching from a shared tail leaves every other version untouched
		FunctionalList<Integer> fiveEleven = FunctionList.add(five, 11);
		assertEquals(11, FunctionList.remove(fiveEleven).getB());
		assertEquals(7, FunctionList.remove(fiveSeven).getB());
		assertEquals(5, FunctionList.remove(five).getB());
	}

	// Now, let's extend our expressions with variables. We allow simple variable
	// assignments and uses of variables
	/**
	 * Base class for nodes (variables, assignments) that only an
	 * {@link ExpressionVisitor} knows how to handle.
	 */
	private static abstract class FancyExpression<T> extends Expression<T> {
		public abstract <V> V visit(ExpressionVisitor<T, V> visitor);

		/**
		 * Dispatches to {@link #visit(ExpressionVisitor)} when possible.
		 *
		 * @throws IllegalArgumentException if {@code visitor} is not an
		 *                                  {@link ExpressionVisitor}
		 */
		@Override
		public <V> V visit(SimpleExpressionVisitor<T, V> visitor) {
			if (visitor instanceof ExpressionVisitor<?, ?>) {
				return visit((ExpressionVisitor<T, V>) visitor);
			}
			// String concatenation tolerates a null visitor, as the instanceof check does.
			throw new IllegalArgumentException("Visitor " + visitor + " cannot handle variables or assignments");
		}
	}

	/** Use of a variable; only its name is stored, the value comes from the visitor's environment. */
	static class Variable<T> extends FancyExpression<T> {
		String name;

		public Variable(String variableName) {
			name = variableName;
		}

		@Override
		public <V> V visit(ExpressionVisitor<T, V> visitor) {
			final String variableName = name;
			return visitor.visit(this, variableName);
		}
	}

	/**
	 * Binding node {@code name := value; next}. The {@code value} child is visited
	 * before {@code next}.
	 */
	static class Assignment<T> extends FancyExpression<T> {
		String name;
		Expression<T> value;
		Expression<T> next;

		public Assignment(String variableName, Expression<T> boundValue, Expression<T> rest) {
			name = variableName;
			value = boundValue;
			next = rest;
		}

		@Override
		public <V> V visit(ExpressionVisitor<T, V> visitor) {
			V valueResult = value.visit(visitor);
			V nextResult = next.visit(visitor);
			return visitor.visit(this, name, valueResult, nextResult);
		}
	}

	// We need a fancier visitor for these new types
	private static interface ExpressionVisitor<T, V> extends SimpleExpressionVisitor<T, V> {
		V visit(Variable<T> variable, String name);

		V visit(Assignment<T> assignment, String name, V value, V next);
	}

	// And implement an algebra
	/**
	 * Evaluates integer expressions with variables in two stages. Visiting the tree
	 * yields a function. Applying that function to an environment (variable name to
	 * value) yields the number. The tree is walked once, and the resulting function
	 * can be applied to as many environments as needed.
	 * <p>
	 * Uncomment the println calls to trace when each node is visited versus when
	 * it is evaluated.
	 */
	static class FancyIntegerAlgebra
			implements ExpressionVisitor<Integer, Function<Function<String, Integer>, Integer>> {

		@Override
		public Function<Function<String, Integer>, Integer> visit(Constant<Integer> constant, Integer value) {
//			System.out.println("visit constant");
			return environment -> {
//				System.out.println("evaluate constant");
				return value;
			};
		}

		@Override
		public Function<Function<String, Integer>, Integer> visit(Plus<Integer> plus,
				Function<Function<String, Integer>, Integer> left, Function<Function<String, Integer>, Integer> right) {
//			System.out.println("visit plus");
			return environment -> {
//				System.out.println("evaluate plus");
				int leftValue = left.apply(environment);
				int rightValue = right.apply(environment);
				return leftValue + rightValue;
			};
		}

		@Override
		public Function<Function<String, Integer>, Integer> visit(Times<Integer> times,
				Function<Function<String, Integer>, Integer> left, Function<Function<String, Integer>, Integer> right) {
//			System.out.println("visit times");
			return environment -> {
//				System.out.println("evaluate times");
				int leftValue = left.apply(environment);
				int rightValue = right.apply(environment);
				return leftValue * rightValue;
			};
		}

		@Override
		public Function<Function<String, Integer>, Integer> visit(Variable<Integer> variable, String name) {
//			System.out.println("visit variable");
			return environment -> {
//				System.out.println("evaluate variable");
				// The variable is resolved only now, in whatever environment we are given
				return environment.apply(name);
			};
		}

		@Override
		public Function<Function<String, Integer>, Integer> visit(Assignment<Integer> assignment, String name,
				Function<Function<String, Integer>, Integer> value, Function<Function<String, Integer>, Integer> next) {
//			System.out.println("visit assignment");
			// The interesting case: evaluate the value in the environment we are given,
			// then evaluate the rest in an extended environment in which the variable
			// is bound to that result.
			return environment -> {
//				System.out.println("evaluate assignment");
				Integer boundValue = value.apply(environment);
				Function<String, Integer> extended = FunctionMap.put(environment, name, boundValue);
				return next.apply(extended);
			};
		}
	}

	/**
	 * Extends {@link FancyToString} to render variables by name and assignments as
	 * {@code (name := value; next)}.
	 */
	static class FancyExtendedToString<T> extends FancyToString<T>
			implements ExpressionVisitor<T, Function<StringBuilder, StringBuilder>> {
		@Override
		public Function<StringBuilder, StringBuilder> visit(Variable<T> variable, String name) {
			return builder -> builder.append(name);
		}

		@Override
		public Function<StringBuilder, StringBuilder> visit(Assignment<T> assignment, String name,
				Function<StringBuilder, StringBuilder> value, Function<StringBuilder, StringBuilder> next) {
			return builder -> {
				StringBuilder head = builder.append('(').append(name).append(" := ");
				StringBuilder afterValue = value.apply(head).append("; ");
				return next.apply(afterValue).append(')');
			};
		}
	}

	@Test
	void testFancy() {
		// For simple expressions this works as before
		Expression<Integer> integerExpression = new Plus<>(new Times<>(new Constant<>(5), new Constant<>(7)),
				new Constant<>(3));
		String integerString = integerExpression.visit(new FancyExtendedToString<>()).apply(new StringBuilder())
				.toString();
		System.out.println(integerString);
		assertEquals("((5 * 7) + 3)", integerString);
		Integer integerValue = integerExpression.visit(new FancyIntegerAlgebra()).apply(FunctionMap.empty());
		System.out.println(integerValue);
		assertEquals(38, integerValue);

		// We can now create variables
		Expression<Integer> variableExpression = new Assignment<>("a", new Constant<>(5),
				new Plus<>(new Variable<>("a"), new Constant<>(7)));
		String variableString = variableExpression.visit(new FancyExtendedToString<>()).apply(new StringBuilder())
				.toString();
		System.out.println(variableString);
		assertEquals("(a := 5; (a + 7))", variableString);
		Integer variableValue = variableExpression.visit(new FancyIntegerAlgebra()).apply(FunctionMap.empty());
		System.out.println(variableValue);
		assertEquals(12, variableValue);

		// Variables can be pretty fancy
		Expression<Integer> fancyVariableExpression = new Assignment<>("a",
				new Plus<>(new Times<>(new Constant<>(5), new Constant<>(3)), new Constant<>(11)),
				new Plus<>(new Variable<>("a"), new Constant<>(7)));
		String fancyVariableString = fancyVariableExpression.visit(new FancyExtendedToString<>())
				.apply(new StringBuilder()).toString();
		System.out.println(fancyVariableString);
		assertEquals("(a := ((5 * 3) + 11); (a + 7))", fancyVariableString);
		Integer fancyVariableValue = fancyVariableExpression.visit(new FancyIntegerAlgebra())
				.apply(FunctionMap.empty());
		System.out.println(fancyVariableValue);
		assertEquals(33, fancyVariableValue);

		// We can have more variables and reuse variable names in nested scopes, and it
		// Just Works(tm)
		Expression<Integer> crazyVariableExpression = new Assignment<>("b",
				new Assignment<>("a",
						new Plus<>(
								new Times<>(new Constant<>(5),
										new Assignment<>("a", new Plus<>(new Constant<>(3), new Constant<>(17)),
												new Plus<>(new Variable<>("a"), new Constant<>(19)))),
								new Constant<>(11)),
						new Plus<>(new Variable<>("a"), new Constant<>(7))),
				new Plus<>(new Variable<>("b"), new Constant<>(13)));
		String crazyVariableString = crazyVariableExpression.visit(new FancyExtendedToString<>())
				.apply(new StringBuilder()).toString();
		System.out.println(crazyVariableString);
		assertEquals("(b := (a := ((5 * (a := (3 + 17); (a + 19))) + 11); (a + 7)); (b + 13))",
				crazyVariableString);
		Integer crazyVariableValue = crazyVariableExpression.visit(new FancyIntegerAlgebra())
				.apply(FunctionMap.empty());
		System.out.println(crazyVariableValue);
		assertEquals(226, crazyVariableValue); // a = 20 -> 39 -> a = 206 -> b = 213 -> 226

		// Consider the difference between these two evaluations and compare it to
		// interpretation vs. JIT compilation
		System.out.println("Starting simple evaluation");
		Integer simpleEvaluation = integerExpression.visit(new IntegerAlgebra());
		System.out.println("Simple evaluation done");
		System.out.println("Simple result: " + simpleEvaluation);
		assertEquals(38, simpleEvaluation);

		System.out.println("Starting fancy evaluation");
		Function<Function<String, Integer>, Integer> fancyEvaluation = integerExpression
				.visit(new FancyIntegerAlgebra());
		System.out.println("Fancy evaluation done");
		Integer fancyResult = fancyEvaluation.apply(FunctionMap.empty());
		System.out.println("Fancy result: " + fancyResult);
		Integer fancyResultAgain = fancyEvaluation.apply(FunctionMap.empty());
		System.out.println("Fancy result again: " + fancyResultAgain);
		assertEquals(38, fancyResult);
		assertEquals(38, fancyResultAgain);

		// How about this one?
		Expression<Integer> squareExpression = new Times<>(new Variable<>("a"), new Variable<>("a"));
		System.out.println("Starting square evaluation");
		Function<Function<String, Integer>, Integer> squareEvaluation = squareExpression
				.visit(new FancyIntegerAlgebra());
		System.out.println("Square evaluation done");
		assertThrows(IllegalArgumentException.class, () -> squareEvaluation.apply(FunctionMap.empty()));
		Integer squareOfFive = squareEvaluation.apply(FunctionMap.put(FunctionMap.empty(), "a", 5));
		System.out.println("Square result of 5: " + squareOfFive);
		assertEquals(25, squareOfFive);
		Integer squareOfSeven = squareEvaluation.apply(FunctionMap.put(FunctionMap.empty(), "a", 7));
		System.out.println("Square result of 7: " + squareOfSeven);
		assertEquals(49, squareOfSeven);
	}

	@Test
	void testSpeedJIT() {
		Expression<Integer> integerExpression = new Plus<>(new Times<>(new Constant<>(5), new Constant<>(7)),
				new Constant<>(3));
		final int iterations = 1_000_000_000;

		// "Interpreted": walk the tree on every evaluation
		IntegerAlgebra algebra = new IntegerAlgebra();
		long interpretedSum = 0;
		long before = System.nanoTime();
		for (int i = 0; i < iterations; i++) {
			interpretedSum += integerExpression.visit(algebra);
		}
		System.out.println("interpreted: " + (System.nanoTime() - before) / 1_000_000);

		// "JIT compiled": walk the tree once into a function, then only run that function
		before = System.nanoTime();
		Function<Function<String, Integer>, Integer> fancyEvaluation = integerExpression
				.visit(new FancyIntegerAlgebra());
		Function<String, Integer> environment = FunctionMap.empty();
		long compiledSum = 0;
		for (int i = 0; i < iterations; i++) {
			compiledSum += fancyEvaluation.apply(environment);
		}
		System.out.println("JIT compiled: " + (System.nanoTime() - before) / 1_000_000);

		// Consuming the results keeps them live, so the loops cannot simply be dropped
		// as dead code. It also checks that both strategies agree.
		assertEquals(interpretedSum, compiledSum);
	}

	// We can make this extension:
	/**
	 * Squaring node, added without changing the existing visitors. Only a
	 * {@link FancyExpressionVisitor} can handle it.
	 */
	static class Square<T> extends Expression<T> {
		Expression<T> expression;

		public Square(Expression<T> expression) {
			this.expression = expression;
		}

		/** Visits the operand, then passes its result to the visitor. */
		public <V> V visit(FancyExpressionVisitor<T, V> visitor) {
			return visitor.visit(this, expression.visit(visitor));
		}

		/**
		 * Dispatches to {@link #visit(FancyExpressionVisitor)} when possible.
		 *
		 * @throws IllegalArgumentException if {@code visitor} is not a
		 *                                  {@link FancyExpressionVisitor}
		 */
		@Override
		public <V> V visit(SimpleExpressionVisitor<T, V> visitor) {
			if (visitor instanceof FancyExpressionVisitor<?, ?>) {
				return visit((FancyExpressionVisitor<T, V>) visitor);
			}
			// String concatenation tolerates a null visitor, as the instanceof check does.
			throw new IllegalArgumentException("Visitor " + visitor + " cannot handle Square nodes");
		}
	}

	private static interface FancyExpressionVisitor<T, V> extends ExpressionVisitor<T, V> {
		V visit(Square<T> square, V expression);
	}

	/** Extends {@link FancyExtendedToString} to render squares as {@code (operand^2)}. */
	static class FancierExtendedToString<T> extends FancyExtendedToString<T>
			implements FancyExpressionVisitor<T, Function<StringBuilder, StringBuilder>> {

		@Override
		public Function<StringBuilder, StringBuilder> visit(Square<T> square,
				Function<StringBuilder, StringBuilder> expression) {
			return builder -> expression.apply(builder.append('(')).append("^2)");
		}
	}

	// Instead of creating a new algebra, we can just do some transformations:
	/**
	 * Rewrites an expression so that it contains no {@link Square} nodes. Every
	 * other node is rebuilt unchanged. A square becomes {@code (s := operand; s * s)},
	 * so the operand appears only once in the result.
	 */
	static class SquareDesugar<T> implements FancyExpressionVisitor<T, Expression<T>> {

		@Override
		public Expression<T> visit(Constant<T> constant, T value) {
			return new Constant<>(value);
		}

		@Override
		public Expression<T> visit(Plus<T> plus, Expression<T> left, Expression<T> right) {
			return new Plus<>(left, right);
		}

		@Override
		public Expression<T> visit(Times<T> times, Expression<T> left, Expression<T> right) {
			return new Times<>(left, right);
		}

		@Override
		public Expression<T> visit(Variable<T> variable, String name) {
			return new Variable<>(name);
		}

		@Override
		public Expression<T> visit(Assignment<T> assignment, String name, Expression<T> value, Expression<T> next) {
			return new Assignment<>(name, value, next);
		}

		@Override
		public Expression<T> visit(Square<T> square, Expression<T> expression) {
			Expression<T> squared = new Times<>(new Variable<>("s"), new Variable<>("s"));
			return new Assignment<>("s", expression, squared);
		}
	}

	@Test
	void testDesugaring() {
		Expression<Integer> integerExpression = new Square<>(
				new Plus<>(new Times<>(new Constant<>(5), new Constant<>(7)), new Constant<>(3)));
		String originalString = integerExpression.visit(new FancierExtendedToString<>()).apply(new StringBuilder())
				.toString();
		System.out.println("Original expression: " + originalString);
		assertEquals("(((5 * 7) + 3)^2)", originalString);

		Expression<Integer> desugaredExpression = integerExpression.visit(new SquareDesugar<>());
		String desugaredString = desugaredExpression.visit(new FancierExtendedToString<>())
				.apply(new StringBuilder()).toString();
		System.out.println("Desugared expression: " + desugaredString);
		assertEquals("(s := ((5 * 7) + 3); (s * s))", desugaredString);

		// If you inspect the trace of the visit, you see we only evaluate the original
		// expression once
		Integer desugaredValue = desugaredExpression.visit(new FancyIntegerAlgebra()).apply(FunctionMap.empty());
		System.out.println(desugaredValue);
		assertEquals(1444, desugaredValue); // ((5 * 7) + 3)^2
		// Why is this?
		assertThrows(IllegalArgumentException.class, () -> integerExpression.visit(new FancyIntegerAlgebra()));

		// Fun task: try extending the expressions with functions. You need a function
		// definition and a function application, like we did for variables. For
		// simplicity, only allow functions with one parameter. Try implementing
		// evaluation both by desugaring, like for the Square node (that's simpler,
		// but duplicates the function code if a function is called more than once
		// and cannot handle recursion), and by a special evaluator. For the latter
		// you probably want the visitor to produce something like
		// BiFunction<Function<String, Integer>, Function<String, Function<String,
		// Integer>>, Integer>, where the new parameter keeps track of registered
		// functions the way the first one keeps track of assigned variables.
	}

	/**
	 * Base class for nodes (function definitions and applications) that only a
	 * {@link FunctionExpressionVisitor} knows how to handle.
	 */
	private static abstract class FunctionExpression<T> extends Expression<T> {
		public abstract <V> V visit(FunctionExpressionVisitor<T, V> visitor);

		/**
		 * Dispatches to {@link #visit(FunctionExpressionVisitor)} when possible.
		 *
		 * @throws IllegalArgumentException if {@code visitor} is not a
		 *                                  {@link FunctionExpressionVisitor}
		 */
		@Override
		public <V> V visit(SimpleExpressionVisitor<T, V> visitor) {
			if (visitor instanceof FunctionExpressionVisitor<?, ?>) {
				return visit((FunctionExpressionVisitor<T, V>) visitor);
			}
			// String concatenation tolerates a null visitor, as the instanceof check does.
			throw new IllegalArgumentException(
					"Visitor " + visitor + " cannot handle function definitions or applications");
		}
	}

	static class FunctionDefinition<T> extends FunctionExpression<T> {
		String name;
		Expression<T> body;
		Expression<T> next;

		public FunctionDefinition(String name, Expression<T> body, Expression<T> next) {
			this.name = name;
			this.body = body;
			this.next = next;
		}

		public <V> V visit(FunctionExpressionVisitor<T, V> visitor) {
			return visitor.visit(this, name, body.visit(visitor), next.visit(visitor));
		}
	}

	static class FunctionApplication<T> extends FunctionExpression<T> {
		String name;
		Expression<T> parameter;

		public FunctionApplication(String name, Expression<T> parameter) {
			this.name = name;
			this.parameter = parameter;
		}

		public <V> V visit(FunctionExpressionVisitor<T, V> visitor) {
			return visitor.visit(this, name, parameter.visit(visitor));
		}
	}

	private static interface FunctionExpressionVisitor<T, V> extends ExpressionVisitor<T, V> {
		V visit(FunctionDefinition<T> functionDefinition, String name, V body, V next);

		V visit(FunctionApplication<T> functionApplication, String name, V parameter);
	}

	/**
	 * Removes function definitions and applications by inlining.
	 * <p>
	 * Visiting yields a function from a function environment (name to desugared
	 * body) to a plain expression. Each application {@code f(arg)} becomes
	 * {@code (p := arg; body-of-f)}. A body is desugared in the environment of its
	 * definition, which does not yet contain the function itself, so a body cannot
	 * refer to its own function.
	 */
	static class FunctionDesugar<T>
			implements FunctionExpressionVisitor<T, Function<Function<String, Expression<T>>, Expression<T>>> {

		@Override
		public Function<Function<String, Expression<T>>, Expression<T>> visit(Constant<T> constant, T value) {
			return functions -> new Constant<>(value);
		}

		@Override
		public Function<Function<String, Expression<T>>, Expression<T>> visit(Plus<T> plus,
				Function<Function<String, Expression<T>>, Expression<T>> left,
				Function<Function<String, Expression<T>>, Expression<T>> right) {
			return functions -> {
				Expression<T> leftExpression = left.apply(functions);
				Expression<T> rightExpression = right.apply(functions);
				return new Plus<>(leftExpression, rightExpression);
			};
		}

		@Override
		public Function<Function<String, Expression<T>>, Expression<T>> visit(Times<T> times,
				Function<Function<String, Expression<T>>, Expression<T>> left,
				Function<Function<String, Expression<T>>, Expression<T>> right) {
			return functions -> {
				Expression<T> leftExpression = left.apply(functions);
				Expression<T> rightExpression = right.apply(functions);
				return new Times<>(leftExpression, rightExpression);
			};
		}

		@Override
		public Function<Function<String, Expression<T>>, Expression<T>> visit(Variable<T> variable, String name) {
			return functions -> new Variable<>(name);
		}

		@Override
		public Function<Function<String, Expression<T>>, Expression<T>> visit(Assignment<T> assignment, String name,
				Function<Function<String, Expression<T>>, Expression<T>> value,
				Function<Function<String, Expression<T>>, Expression<T>> next) {
			return functions -> {
				Expression<T> valueExpression = value.apply(functions);
				Expression<T> nextExpression = next.apply(functions);
				return new Assignment<>(name, valueExpression, nextExpression);
			};
		}

		@Override
		public Function<Function<String, Expression<T>>, Expression<T>> visit(FunctionDefinition<T> functionDefinition,
				String name, Function<Function<String, Expression<T>>, Expression<T>> body,
				Function<Function<String, Expression<T>>, Expression<T>> next) {
			return functions -> {
				Expression<T> inlinedBody = body.apply(functions);
				return next.apply(FunctionMap.put(functions, name, inlinedBody));
			};
		}

		@Override
		public Function<Function<String, Expression<T>>, Expression<T>> visit(
				FunctionApplication<T> functionApplication, String name,
				Function<Function<String, Expression<T>>, Expression<T>> parameter) {
			return functions -> {
				Expression<T> argument = parameter.apply(functions);
				Expression<T> inlinedBody = functions.apply(name);
				return new Assignment<>("p", argument, inlinedBody);
			};
		}
	}

	@Test
	void testFunctionDesugaring() {
		FunctionDefinition<Integer> functionExpression = new FunctionDefinition<>("square",
				new Times<>(new Variable<>("p"), new Variable<>("p")),
				new FunctionApplication<>("square", new Constant<>(5)));

		Expression<Integer> desugaredExpression = functionExpression.visit(new FunctionDesugar<>())
				.apply(FunctionMap.empty());
		String desugaredString = desugaredExpression.visit(new FancyExtendedToString<>()).apply(new StringBuilder())
				.toString();
		System.out.println(desugaredString);
		assertEquals("(p := 5; (p * p))", desugaredString);
		Integer value = desugaredExpression.visit(new FancyIntegerAlgebra()).apply(FunctionMap.empty());
		System.out.println(value);
		assertEquals(25, value);

		FunctionDefinition<Integer> nestedFunctionExpression = new FunctionDefinition<>("square",
				new Times<>(new Variable<>("p"), new Variable<>("p")),
				new FunctionApplication<>("square", new FunctionApplication<>("square", new Constant<>(5))));

		desugaredExpression = nestedFunctionExpression.visit(new FunctionDesugar<>()).apply(FunctionMap.empty());
		String nestedString = desugaredExpression.visit(new FancyExtendedToString<>()).apply(new StringBuilder())
				.toString();
		System.out.println(nestedString);
		assertEquals("(p := (p := 5; (p * p)); (p * p))", nestedString);
		Integer nestedValue = desugaredExpression.visit(new FancyIntegerAlgebra()).apply(FunctionMap.empty());
		System.out.println(nestedValue);
		assertEquals(625, nestedValue);
	}

	/**
	 * Evaluator for expressions with functions that threads two environments:
	 * variables (name to value) and functions (name to body expression).
	 * <p>
	 * A definition stores its body as an unevaluated expression. Every application
	 * looks the body up and visits it again with this algebra. The body is
	 * evaluated in the caller's variable environment extended with "p" bound to
	 * the argument, and in the caller's function environment, which still contains
	 * the function itself.
	 */
	static class FunctionIntegerAlgebra implements
			FunctionExpressionVisitor<Integer, BiFunction<Function<String, Integer>, Function<String, Expression<Integer>>, Integer>> {
		@Override
		public BiFunction<Function<String, Integer>, Function<String, Expression<Integer>>, Integer> visit(
				Constant<Integer> constant, Integer value) {
			return (variables, functions) -> value;
		}

		@Override
		public BiFunction<Function<String, Integer>, Function<String, Expression<Integer>>, Integer> visit(
				Plus<Integer> plus,
				BiFunction<Function<String, Integer>, Function<String, Expression<Integer>>, Integer> left,
				BiFunction<Function<String, Integer>, Function<String, Expression<Integer>>, Integer> right) {
			return (variables, functions) -> left.apply(variables, functions) + right.apply(variables, functions);
		}

		@Override
		public BiFunction<Function<String, Integer>, Function<String, Expression<Integer>>, Integer> visit(
				Times<Integer> times,
				BiFunction<Function<String, Integer>, Function<String, Expression<Integer>>, Integer> left,
				BiFunction<Function<String, Integer>, Function<String, Expression<Integer>>, Integer> right) {
			return (variables, functions) -> left.apply(variables, functions) * right.apply(variables, functions);
		}

		@Override
		public BiFunction<Function<String, Integer>, Function<String, Expression<Integer>>, Integer> visit(
				Variable<Integer> variable, String name) {
			return (variables, functions) -> variables.apply(name);
		}

		@Override
		public BiFunction<Function<String, Integer>, Function<String, Expression<Integer>>, Integer> visit(
				Assignment<Integer> assignment, String name,
				BiFunction<Function<String, Integer>, Function<String, Expression<Integer>>, Integer> value,
				BiFunction<Function<String, Integer>, Function<String, Expression<Integer>>, Integer> next) {
			return (variables, functions) -> {
				Integer boundValue = value.apply(variables, functions);
				return next.apply(FunctionMap.put(variables, name, boundValue), functions);
			};
		}

		@Override
		public BiFunction<Function<String, Integer>, Function<String, Expression<Integer>>, Integer> visit(
				FunctionDefinition<Integer> functionDefinition, String name,
				BiFunction<Function<String, Integer>, Function<String, Expression<Integer>>, Integer> body,
				BiFunction<Function<String, Integer>, Function<String, Expression<Integer>>, Integer> next) {
			// The already-visited body is ignored on purpose: the raw expression is
			// stored and re-visited at each call.
			Expression<Integer> rawBody = functionDefinition.body;
			return (variables, functions) -> next.apply(variables, FunctionMap.put(functions, name, rawBody));
		}

		@Override
		public BiFunction<Function<String, Integer>, Function<String, Expression<Integer>>, Integer> visit(
				FunctionApplication<Integer> functionApplication, String name,
				BiFunction<Function<String, Integer>, Function<String, Expression<Integer>>, Integer> parameter) {
			return (variables, functions) -> {
				Expression<Integer> calleeBody = functions.apply(name);
				BiFunction<Function<String, Integer>, Function<String, Expression<Integer>>, Integer> callee = calleeBody
						.visit(this);
				Integer argument = parameter.apply(variables, functions);
				return callee.apply(FunctionMap.put(variables, "p", argument), functions);
			};
		}
	}

	@Test
	void testFunctionEvaluation() {
		FunctionDefinition<Integer> functionExpression = new FunctionDefinition<>("square",
				new Times<>(new Variable<>("p"), new Variable<>("p")),
				new FunctionApplication<>("square", new Constant<>(5)));

		Integer value = functionExpression.visit(new FunctionIntegerAlgebra()).apply(FunctionMap.empty(),
				FunctionMap.empty());
		System.out.println(value);
		assertEquals(25, value);

		FunctionDefinition<Integer> nestedFunctionExpression = new FunctionDefinition<>("square",
				new Times<>(new Variable<>("p"), new Variable<>("p")),
				new FunctionApplication<>("square", new FunctionApplication<>("square", new Constant<>(5))));
		Integer nestedValue = nestedFunctionExpression.visit(new FunctionIntegerAlgebra())
				.apply(FunctionMap.empty(), FunctionMap.empty());
		System.out.println(nestedValue);
		assertEquals(625, nestedValue);
	}

	static class BetterFunctionIntegerAlgebra implements
			FunctionExpressionVisitor<Integer, Function<Function<String, Function<Function<String, Integer>, Integer>>, Function<Function<String, Integer>, Integer>>> {
		@Override
		public Function<Function<String, Function<Function<String, Integer>, Integer>>, Function<Function<String, Integer>, Integer>> visit(
				Constant<Integer> constant, Integer value) {
			return fenv -> env -> {
				return value;
			};
		}

		@Override
		public Function<Function<String, Function<Function<String, Integer>, Integer>>, Function<Function<String, Integer>, Integer>> visit(
				Plus<Integer> plus,
				Function<Function<String, Function<Function<String, Integer>, Integer>>, Function<Function<String, Integer>, Integer>> left,
				Function<Function<String, Function<Function<String, Integer>, Integer>>, Function<Function<String, Integer>, Integer>> right) {
			return fenv -> {
				Function<Function<String, Integer>, Integer> l = left.apply(fenv);
				Function<Function<String, Integer>, Integer> r = right.apply(fenv);
				return env -> {
					return l.apply(env) + r.apply(env);
				};
			};
		}

		@Override
		public Function<Function<String, Function<Function<String, Integer>, Integer>>, Function<Function<String, Integer>, Integer>> visit(
				Times<Integer> times,
				Function<Function<String, Function<Function<String, Integer>, Integer>>, Function<Function<String, Integer>, Integer>> left,
				Function<Function<String, Function<Function<String, Integer>, Integer>>, Function<Function<String, Integer>, Integer>> right) {
			return fenv -> {
				Function<Function<String, Integer>, Integer> l = left.apply(fenv);
				Function<Function<String, Integer>, Integer> r = right.apply(fenv);
				return env -> {
					return l.apply(env) * r.apply(env);
				};
			};
		}

		@Override
		public Function<Function<String, Function<Function<String, Integer>, Integer>>, Function<Function<String, Integer>, Integer>> visit(
				Variable<Integer> variable, String name) {
			return fenv -> env -> {
				return env.apply(name);
			};
		}

		@Override
		public Function<Function<String, Function<Function<String, Integer>, Integer>>, Function<Function<String, Integer>, Integer>> visit(
				Assignment<Integer> assignment, String name,
				Function<Function<String, Function<Function<String, Integer>, Integer>>, Function<Function<String, Integer>, Integer>> value,
				Function<Function<String, Function<Function<String, Integer>, Integer>>, Function<Function<String, Integer>, Integer>> next) {
			return fenv -> {
				Function<Function<String, Integer>, Integer> v = value.apply(fenv);
				Function<Function<String, Integer>, Integer> n = next.apply(fenv);
				return env -> {
					return n.apply(FunctionMap.put(env, name, v.apply(env)));
				};
			};
		}

		/**
		 * Compiles the function body against the enclosing function environment. It then compiles the rest
		 * of the program with {@code name} bound to that compiled body.
		 *
		 * <p>The body is compiled before {@code name} is bound. A call to {@code name} inside its own body
		 * is therefore looked up in the enclosing environment, where it is not defined. The factorial case in
		 * testBetterFunctionEvaluation expects that lookup to fail with an IllegalArgumentException.
		 */
		@Override
		public Function<Function<String, Function<Function<String, Integer>, Integer>>, Function<Function<String, Integer>, Integer>> visit(
				FunctionDefinition<Integer> functionDefinition, String name,
				Function<Function<String, Function<Function<String, Integer>, Integer>>, Function<Function<String, Integer>, Integer>> body,
				Function<Function<String, Function<Function<String, Integer>, Integer>>, Function<Function<String, Integer>, Integer>> next) {
			return functions -> {
				Function<Function<String, Integer>, Integer> compiledBody = body.apply(functions);
				return next.apply(FunctionMap.put(functions, name, compiledBody));
			};
		}

		/**
		 * Resolves the callee in the function environment when the outer continuation is applied, not when
		 * the program runs. At run time it evaluates the argument and runs the callee with the argument bound
		 * to the variable {@code "p"}.
		 */
		@Override
		public Function<Function<String, Function<Function<String, Integer>, Integer>>, Function<Function<String, Integer>, Integer>> visit(
				FunctionApplication<Integer> functionApplication, String name,
				Function<Function<String, Function<Function<String, Integer>, Integer>>, Function<Function<String, Integer>, Integer>> parameter) {
			return functions -> {
				Function<Function<String, Integer>, Integer> callee = functions.apply(name);
				Function<Function<String, Integer>, Integer> argument = parameter.apply(functions);
				return values -> {
					Integer actual = argument.apply(values);
					return callee.apply(FunctionMap.put(values, "p", actual));
				};
			};
		}
	}

	/**
	 * Exercises BetterFunctionIntegerAlgebra, which evaluates in two stages. Applying the outer
	 * continuation to a function environment resolves function names. Applying the resulting inner
	 * continuation to a variable environment computes the value.
	 */
	@Test
	void testBetterFunctionEvaluation() {
		// square(5)
		FunctionDefinition<Integer> functionExpression = new FunctionDefinition<>("square",
				new Times<>(new Variable<>("p"), new Variable<>("p")),
				new FunctionApplication<>("square", new Constant<>(5)));

		Function<Function<String, Function<Function<String, Integer>, Integer>>, Function<Function<String, Integer>, Integer>> outerContinuation = functionExpression
				.visit(new BetterFunctionIntegerAlgebra());
		Function<Function<String, Integer>, Integer> innerContinuation = outerContinuation.apply(FunctionMap.empty());
		assertEquals(25, innerContinuation.apply(FunctionMap.empty()));

		// square(square(5))
		FunctionDefinition<Integer> nestedFunctionExpression = new FunctionDefinition<>("square",
				new Times<>(new Variable<>("p"), new Variable<>("p")),
				new FunctionApplication<>("square", new FunctionApplication<>("square", new Constant<>(5))));
		outerContinuation = nestedFunctionExpression.visit(new BetterFunctionIntegerAlgebra());
		innerContinuation = outerContinuation.apply(FunctionMap.empty());
		assertEquals(625, innerContinuation.apply(FunctionMap.empty()));

		// factorial calls itself, but its body is compiled before "factorial" is bound.
		// Resolving function names (the outer stage) therefore fails.
		FunctionDefinition<Integer> recursiveFunctionExpression = new FunctionDefinition<>("factorial",
				new Times<>(new Variable<>("p"),
						new FunctionApplication<>("factorial", new Plus<>(new Variable<>("p"), new Constant<>(-1)))),
				new FunctionApplication<>("factorial", new Constant<>(5)));
		assertThrows(IllegalArgumentException.class,
				() -> recursiveFunctionExpression.visit(new BetterFunctionIntegerAlgebra()).apply(FunctionMap.empty()));
	}

	/**
	 * A ternary node: {@code test ? yes : no}. The visitor decides what counts as "true". For example,
	 * ConditionalIntegerAlgebra and RecursiveFunctionIntegerAlgebra treat a test value greater than zero as
	 * true.
	 *
	 * <p>Only visitors that implement {@link ConditionalExpressionVisitor} can handle this node.
	 */
	static class ConditionalExpression<T> extends Expression<T> {
		Expression<T> test;
		Expression<T> yes;
		Expression<T> no;

		public ConditionalExpression(Expression<T> test, Expression<T> yes, Expression<T> no) {
			this.test = test;
			this.yes = yes;
			this.no = no;
		}

		/**
		 * Visits the test, yes and no sub-expressions in that order, then combines them with the visitor's
		 * conditional case.
		 */
		public <V> V visit(ConditionalExpressionVisitor<T, V> visitor) {
			return visitor.visit(this, test.visit(visitor), yes.visit(visitor), no.visit(visitor));
		}

		/**
		 * Dispatches to {@link #visit(ConditionalExpressionVisitor)} when the visitor supports conditionals.
		 *
		 * @throws IllegalArgumentException if {@code visitor} is not a ConditionalExpressionVisitor
		 */
		@Override
		public <V> V visit(SimpleExpressionVisitor<T, V> visitor) {
			if (visitor instanceof ConditionalExpressionVisitor<?, ?>) {
				// Unchecked: the instanceof test cannot see type arguments. This is presumably safe because a
				// class cannot implement one generic interface with two different argument lists.
				@SuppressWarnings("unchecked")
				ConditionalExpressionVisitor<T, V> conditionalVisitor = (ConditionalExpressionVisitor<T, V>) visitor;
				return visit(conditionalVisitor);
			}
			throw new IllegalArgumentException("ConditionalExpression requires a ConditionalExpressionVisitor, got "
					+ (visitor == null ? "null" : visitor.getClass().getName()));
		}
	}

	private static interface ConditionalExpressionVisitor<T, V> extends ExpressionVisitor<T, V> {
		V visit(ConditionalExpression<T> conditional, V test, V yes, V no);
	}

	/**
	 * Extends FancyIntegerAlgebra with conditionals. A test is true when it evaluates to a value greater
	 * than zero, and only the selected branch is evaluated.
	 */
	static class ConditionalIntegerAlgebra extends FancyIntegerAlgebra
			implements ConditionalExpressionVisitor<Integer, Function<Function<String, Integer>, Integer>> {
		@Override
		public Function<Function<String, Integer>, Integer> visit(ConditionalExpression<Integer> conditional,
				Function<Function<String, Integer>, Integer> test, Function<Function<String, Integer>, Integer> yes,
				Function<Function<String, Integer>, Integer> no) {
			return bindings -> {
				if (test.apply(bindings) > 0) {
					return yes.apply(bindings);
				}
				return no.apply(bindings);
			};
		}
	}
	
	/**
	 * Plain-Java counterpart of the second conditional expression in testConditional.
	 *
	 * @param p input value
	 * @return {@code p + 3} when {@code p * p > 10}, otherwise {@code p * 2}
	 */
	public static int evaluation(int p) {
		int square = p * p;
		if (square > 10) {
			return p + 3;
		}
		return p * 2;
	}

	/**
	 * Exercises ConditionalIntegerAlgebra on two conditional expressions over the variable {@code p} and
	 * checks the results against hand-computed values and against {@link #evaluation(int)}.
	 */
	@Test
	void testConditional() {
		// (p + 1 > 0) ? ((1 - p > 0) ? 0 : 1) : 1  -- yields 0 exactly when p == 0.
		Expression<Integer> equals = new ConditionalExpression<>(new Plus<>(new Variable<>("p"), new Constant<>(1)),
				new ConditionalExpression<>(
						new Plus<>(new Times<>(new Constant<>(-1), new Variable<>("p")), new Constant<>(1)),
						new Constant<>(0), new Constant<>(1)),
				new Constant<>(1));
		Function<Function<String, Integer>, Integer> zeroTest = equals.visit(new ConditionalIntegerAlgebra());
		assertEquals(1, zeroTest.apply(FunctionMap.put(FunctionMap.empty(), "p", -2)));
		assertEquals(1, zeroTest.apply(FunctionMap.put(FunctionMap.empty(), "p", -1)));
		assertEquals(0, zeroTest.apply(FunctionMap.put(FunctionMap.empty(), "p", 0)));
		assertEquals(1, zeroTest.apply(FunctionMap.put(FunctionMap.empty(), "p", 1)));
		assertEquals(1, zeroTest.apply(FunctionMap.put(FunctionMap.empty(), "p", 2)));

		// (p * p - 10 > 0) ? p + 3 : p * 2  -- the same function as evaluation(int).
		Expression<Integer> expression = new ConditionalExpression<>(
				new Plus<>(new Times<>(new Variable<>("p"), new Variable<>("p")), new Constant<>(-10)),
				new Plus<>(new Variable<>("p"), new Constant<>(3)),
				new Times<>(new Variable<>("p"), new Constant<>(2)));
		Function<Function<String, Integer>, Integer> compiled = expression.visit(new ConditionalIntegerAlgebra());

		// Expected results for p = 1..6; the branch switches between p = 3 and p = 4.
		int[] expected = { 2, 4, 6, 7, 8, 9 };
		for (int p = 1; p <= expected.length; p++) {
			assertEquals(expected[p - 1], compiled.apply(FunctionMap.put(FunctionMap.empty(), "p", p)));
			assertEquals(expected[p - 1], evaluation(p));
		}
	}

	/**
	 * Single-stage evaluator for function definitions and conditionals. Every node becomes an {@code F}
	 * that receives the variable environment and the function environment together.
	 *
	 * <p>A function application passes the caller's function environment on to the callee. That
	 * environment already contains the function's own name, so a function body can call itself.
	 *
	 * <p>{@code F} is F-bounded: the function environment maps names to values of type {@code F} itself.
	 * Lambdas are converted to {@code F} through {@link #cast}.
	 */
	static class RecursiveFunctionIntegerAlgebra<F extends BiFunction<Function<String, Integer>, Function<String, F>, Integer>>
			implements FunctionExpressionVisitor<Integer, F>, ConditionalExpressionVisitor<Integer, F> {
		/**
		 * Unchecked conversion of a lambda to {@code F}. F erases to BiFunction, so the cast inside this
		 * method cannot fail.
		 */
		@SuppressWarnings("unchecked")
		private F cast(BiFunction<Function<String, Integer>, Function<String, F>, Integer> lambda) {
			return (F) lambda;
		}

		@Override
		public F visit(Constant<Integer> constant, Integer value) {
			return cast((vars, funcs) -> value);
		}

		@Override
		public F visit(Plus<Integer> plus, F left, F right) {
			return cast((vars, funcs) -> {
				int sum = left.apply(vars, funcs) + right.apply(vars, funcs);
				return sum;
			});
		}

		@Override
		public F visit(Times<Integer> times, F left, F right) {
			return cast((vars, funcs) -> {
				int product = left.apply(vars, funcs) * right.apply(vars, funcs);
				return product;
			});
		}

		@Override
		public F visit(Variable<Integer> variable, String name) {
			return cast((vars, funcs) -> vars.apply(name));
		}

		/** Evaluates the value, then runs the continuation with {@code name} bound to it. */
		@Override
		public F visit(Assignment<Integer> assignment, String name, F value, F next) {
			return cast((vars, funcs) -> {
				Integer assigned = value.apply(vars, funcs);
				return next.apply(FunctionMap.put(vars, name, assigned), funcs);
			});
		}

		/** Binds {@code name} to the (unevaluated) body in the function environment seen by the continuation. */
		@Override
		public F visit(FunctionDefinition<Integer> functionDefinition, String name, F body, F next) {
			return cast((vars, funcs) -> next.apply(vars, FunctionMap.put(funcs, name, body)));
		}

		/**
		 * Looks the callee up first, then evaluates the argument. It then runs the callee with the argument
		 * bound to {@code "p"} and the caller's function environment.
		 */
		@Override
		public F visit(FunctionApplication<Integer> functionApplication, String name, F parameter) {
			return cast((vars, funcs) -> {
				F callee = funcs.apply(name);
				Integer argument = parameter.apply(vars, funcs);
				return callee.apply(FunctionMap.put(vars, "p", argument), funcs);
			});
		}

		/** Evaluates the test, then only the selected branch; values greater than zero count as true. */
		@Override
		public F visit(ConditionalExpression<Integer> conditional, F test, F yes, F no) {
			return cast((vars, funcs) -> {
				F branch = test.apply(vars, funcs) > 0 ? yes : no;
				return branch.apply(vars, funcs);
			});
		}
	}

	/**
	 * Exercises RecursiveFunctionIntegerAlgebra. Function bodies see the caller's function environment,
	 * so recursion is possible. Without a base case it overflows the stack; with a conditional base case
	 * it computes 5! = 120.
	 */
	@Test
	<F extends BiFunction<Function<String, Integer>, Function<String, F>, Integer>> void testRecursiveFunctionEvaluation() {
		// square(5)
		FunctionDefinition<Integer> functionExpression = new FunctionDefinition<>("square",
				new Times<>(new Variable<>("p"), new Variable<>("p")),
				new FunctionApplication<>("square", new Constant<>(5)));

		F continuation = functionExpression.visit(new RecursiveFunctionIntegerAlgebra<F>());
		assertEquals(25, continuation.apply(FunctionMap.empty(), FunctionMap.empty()));

		// square(square(5))
		FunctionDefinition<Integer> nestedFunctionExpression = new FunctionDefinition<>("square",
				new Times<>(new Variable<>("p"), new Variable<>("p")),
				new FunctionApplication<>("square", new FunctionApplication<>("square", new Constant<>(5))));
		continuation = nestedFunctionExpression.visit(new RecursiveFunctionIntegerAlgebra<F>());
		assertEquals(625, continuation.apply(FunctionMap.empty(), FunctionMap.empty()));

		// factorial without a base case recurses forever.
		FunctionDefinition<Integer> recursiveFunctionExpression = new FunctionDefinition<>("factorial",
				new Times<>(new Variable<>("p"),
						new FunctionApplication<>("factorial", new Plus<>(new Variable<>("p"), new Constant<>(-1)))),
				new FunctionApplication<>("factorial", new Constant<>(5)));
		F failingContinuation = recursiveFunctionExpression.visit(new RecursiveFunctionIntegerAlgebra<F>());
		assertThrows(StackOverflowError.class,
				() -> failingContinuation.apply(FunctionMap.empty(), FunctionMap.empty()));

		// factorial with base case: p > 0 ? p * factorial(p - 1) : 1
		FunctionDefinition<Integer> fixedRecursiveFunctionExpression = new FunctionDefinition<>("factorial",
				new ConditionalExpression<>(new Variable<>("p"),
						new Times<>(new Variable<>("p"),
								new FunctionApplication<>("factorial",
										new Plus<>(new Variable<>("p"), new Constant<>(-1)))),
						new Constant<>(1)),
				new FunctionApplication<>("factorial", new Constant<>(5)));
		continuation = fixedRecursiveFunctionExpression.visit(new RecursiveFunctionIntegerAlgebra<F>());
		assertEquals(120, continuation.apply(FunctionMap.empty(), FunctionMap.empty()));
	}

	/**
	 * Returns a new list with every element of {@code list} doubled, in the same order.
	 *
	 * <p>Not tail-recursive: each element costs one stack frame, so long lists overflow the stack.
	 *
	 * @param list the list to double
	 * @return the doubled list
	 */
	public static FunctionalList<Integer> times2(FunctionalList<Integer> list) {
		Pair<FunctionalList<Integer>, Integer> parts;
		try {
			parts = FunctionList.remove(list);
		} catch (IllegalArgumentException ignored) {
			// FunctionList.remove throws IllegalArgumentException on an empty list: that is the base case.
			// The try covers only remove(), so a failure in add() or in the recursion is not mistaken for
			// end-of-list.
			return FunctionList.empty();
		}
		return FunctionList.add(times2(parts.getA()), 2 * parts.getB());
	}

	/** Doubling [1, 2, 3] yields [2, 4, 6] in the same order. */
	@Test
	void testDouble() {
		// remove() returns the most recently added element, so add 3, 2, 1 to get the order 1, 2, 3.
		FunctionalList<Integer> list = FunctionList.empty();
		for (int value = 3; value >= 1; value--) {
			list = FunctionList.add(list, value);
		}
		Pair<FunctionalList<Integer>, Integer> first = FunctionList.remove(list);
		Pair<FunctionalList<Integer>, Integer> second = FunctionList.remove(first.getA());
		Pair<FunctionalList<Integer>, Integer> third = FunctionList.remove(second.getA());
		assertEquals(1, first.getB());
		assertEquals(2, second.getB());
		assertEquals(3, third.getB());

		FunctionalList<Integer> doubled = times2(list);
		Pair<FunctionalList<Integer>, Integer> firstDoubled = FunctionList.remove(doubled);
		Pair<FunctionalList<Integer>, Integer> secondDoubled = FunctionList.remove(firstDoubled.getA());
		Pair<FunctionalList<Integer>, Integer> thirdDoubled = FunctionList.remove(secondDoubled.getA());
		assertEquals(2, firstDoubled.getB());
		assertEquals(4, secondDoubled.getB());
		assertEquals(6, thirdDoubled.getB());
	}

	/**
	 * Returns the elements of {@code list} in reverse order, followed by the elements of
	 * {@code accumulator}.
	 *
	 * <p>Written tail-recursively. The JVM does not eliminate tail calls, so long lists still overflow the
	 * stack.
	 *
	 * @param list the list to reverse
	 * @param accumulator the partial result built so far (pass an empty list initially)
	 * @return the reversed list
	 */
	public static <T> FunctionalList<T> reverse(FunctionalList<T> list, FunctionalList<T> accumulator) {
		Pair<FunctionalList<T>, T> parts;
		try {
			parts = FunctionList.remove(list);
		} catch (IllegalArgumentException ignored) {
			// FunctionList.remove throws IllegalArgumentException on an empty list: that is the base case.
			// The try covers only remove(), so a failure in add() or in the recursion is not mistaken for
			// end-of-list.
			return accumulator;
		}
		return reverse(parts.getA(), FunctionList.add(accumulator, parts.getB()));
	}

	/** Recursively reversing [1, 2, 3] yields [3, 2, 1]. */
	@Test
	void testReverse() {
		// Elements come back out of remove() in the reverse of insertion order.
		FunctionalList<Integer> list = FunctionList.empty();
		for (int value = 3; value > 0; value--) {
			list = FunctionList.add(list, value);
		}
		Pair<FunctionalList<Integer>, Integer> head = FunctionList.remove(list);
		Pair<FunctionalList<Integer>, Integer> middle = FunctionList.remove(head.getA());
		Pair<FunctionalList<Integer>, Integer> last = FunctionList.remove(middle.getA());
		assertEquals(1, head.getB());
		assertEquals(2, middle.getB());
		assertEquals(3, last.getB());

		FunctionalList<Integer> reversed = reverse(list, FunctionList.empty());
		Pair<FunctionalList<Integer>, Integer> reversedHead = FunctionList.remove(reversed);
		Pair<FunctionalList<Integer>, Integer> reversedMiddle = FunctionList.remove(reversedHead.getA());
		Pair<FunctionalList<Integer>, Integer> reversedLast = FunctionList.remove(reversedMiddle.getA());
		assertEquals(3, reversedHead.getB());
		assertEquals(2, reversedMiddle.getB());
		assertEquals(1, reversedLast.getB());
	}

	/**
	 * Computes {@code n!} with plain (non-tail) recursion.
	 *
	 * @param n the argument; any value below 1 yields 1
	 * @return {@code n!}
	 * @throws ArithmeticException if the result does not fit in an int (n > 12)
	 */
	public static int factorial(int n) {
		if (n < 1) {
			return 1;
		}
		// multiplyExact: a plain '*' would silently wrap around for n > 12.
		return Math.multiplyExact(n, factorial(n - 1));
	}

	/**
	 * Tail-recursive factorial with an accumulator: returns {@code a * n!}.
	 *
	 * @param n the argument; once it drops below 1 the accumulator is returned
	 * @param a the accumulated product so far (pass 1 to get {@code n!})
	 * @return {@code a * n!}
	 * @throws ArithmeticException if the result does not fit in an int
	 */
	public static int factorial(int n, int a) {
		if (n < 1) {
			return a;
		}
		// multiplyExact: a plain '*' would silently wrap around on overflow.
		return factorial(n - 1, Math.multiplyExact(n, a));
	}

	/**
	 * Loop form of the accumulator factorial: returns {@code a * n!}. Each iteration rebinds {@code n} and
	 * {@code a} exactly as the tail call {@code factorial(n - 1, n * a)} would.
	 *
	 * @param n the argument; once it drops below 1 the accumulator is returned
	 * @param a the accumulated product so far (pass 1 to get {@code n!})
	 * @return {@code a * n!}
	 * @throws ArithmeticException if the result does not fit in an int
	 */
	public static int factorialLoop(int n, int a) {
		while (true) {
			if (n < 1) {
				return a;
			}
			int newN = n - 1;
			// multiplyExact: a plain '*' would silently wrap around on overflow.
			int newA = Math.multiplyExact(n, a);
			n = newN;
			a = newA;
		}
	}

	/** The recursive, tail-recursive and loop factorials agree with the known value of 5!. */
	@Test
	void testFactorial() {
		assertEquals(120, factorial(5));
		assertEquals(120, factorial(5, 1));
		assertEquals(120, factorialLoop(5, 1));
		assertEquals(1, factorial(0));
	}

	/**
	 * Loop form of {@code reverse}: returns the elements of {@code list} in reverse order, followed by the
	 * elements of {@code accumulator}. The recursion is replaced by a loop, so stack depth does not grow
	 * with the list length.
	 *
	 * @param list the list to reverse
	 * @param accumulator the partial result built so far (pass an empty list initially)
	 * @return the reversed list
	 */
	public static <T> FunctionalList<T> reverseLoop(FunctionalList<T> list, FunctionalList<T> accumulator) {
		while (true) {
			Pair<FunctionalList<T>, T> parts;
			try {
				parts = FunctionList.remove(list);
			} catch (IllegalArgumentException ignored) {
				// FunctionList.remove throws IllegalArgumentException on an empty list: that is the base case.
				// The try covers only remove(), so a failure in add() is not mistaken for end-of-list.
				return accumulator;
			}
			// Rebind the parameters exactly as the tail call reverse(newList, newAccumulator) would.
			FunctionalList<T> newList = parts.getA();
			FunctionalList<T> newAccumulator = FunctionList.add(accumulator, parts.getB());
			list = newList;
			accumulator = newAccumulator;
		}
	}
	
	/** Reversing [1, 2, 3] with the loop version yields [3, 2, 1]. */
	@Test
	void testReverseLoop() {
		// Insert 3, 2, 1 so that remove() hands back 1, 2, 3.
		FunctionalList<Integer> list = FunctionList.empty();
		for (int value = 3; value >= 1; value--) {
			list = FunctionList.add(list, value);
		}
		Pair<FunctionalList<Integer>, Integer> one = FunctionList.remove(list);
		Pair<FunctionalList<Integer>, Integer> two = FunctionList.remove(one.getA());
		Pair<FunctionalList<Integer>, Integer> three = FunctionList.remove(two.getA());
		assertEquals(1, one.getB());
		assertEquals(2, two.getB());
		assertEquals(3, three.getB());

		FunctionalList<Integer> reversed = reverseLoop(list, FunctionList.empty());
		Pair<FunctionalList<Integer>, Integer> reversedOne = FunctionList.remove(reversed);
		Pair<FunctionalList<Integer>, Integer> reversedTwo = FunctionList.remove(reversedOne.getA());
		Pair<FunctionalList<Integer>, Integer> reversedThree = FunctionList.remove(reversedTwo.getA());
		assertEquals(3, reversedOne.getB());
		assertEquals(2, reversedTwo.getB());
		assertEquals(1, reversedThree.getB());
	}


	/**
	 * On a 10000-element list, the recursive {@code times2} and {@code reverse} overflow the stack. The
	 * loop-based {@code reverseLoop} completes and produces the correct reversal.
	 */
	@Test
	void testLists() {
		FunctionalList<Integer> list = FunctionList.empty();
		for (int i = 0; i < 10000; i++) {
			list = FunctionList.add(list, i);
		}
		FunctionalList<Integer> l = list;

		// These rely on 10000 nested frames exceeding the thread's stack, so they depend on the JVM's
		// -Xss setting.
		assertThrows(StackOverflowError.class, () -> times2(l));
		assertThrows(StackOverflowError.class, () -> reverse(l, FunctionList.empty()));

		// l yields 9999, 9998, ..., 0, so its reversal yields 0, 1, ...
		FunctionalList<Integer> reversed = reverseLoop(l, FunctionList.empty());
		Pair<FunctionalList<Integer>, Integer> first = FunctionList.remove(reversed);
		assertEquals(0, first.getB());
		assertEquals(1, FunctionList.remove(first.getA()).getB());
	}
}
