diff --git a/core/src/main/java/com/garciat/typeclasses/TypeClasses.java b/core/src/main/java/com/garciat/typeclasses/TypeClasses.java index fd4894f..c0da10e 100644 --- a/core/src/main/java/com/garciat/typeclasses/TypeClasses.java +++ b/core/src/main/java/com/garciat/typeclasses/TypeClasses.java @@ -1,32 +1,65 @@ package com.garciat.typeclasses; +import com.garciat.typeclasses.api.Lazy; import com.garciat.typeclasses.api.Ty; import com.garciat.typeclasses.impl.Match; +import com.garciat.typeclasses.impl.ParsedType; +import com.garciat.typeclasses.impl.Resolution; import com.garciat.typeclasses.impl.utils.Either; import com.garciat.typeclasses.runtime.Runtime; import com.garciat.typeclasses.runtime.RuntimeWitnessSystem; import java.lang.reflect.InvocationTargetException; -import java.util.List; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; public final class TypeClasses { private TypeClasses() {} public static T witness(Ty ty) { - Object instance = - switch (RuntimeWitnessSystem.resolve(ty.type(), TypeClasses::invoke)) { + Resolution.Result methodTree = + switch (RuntimeWitnessSystem.resolve(ty.type())) { case Either.Right(var r) -> r; case Either.Left(var error) -> throw new WitnessResolutionException(error.format()); }; + Object instance = walk(new HashMap<>(), methodTree); + @SuppressWarnings("unchecked") T typedInstance = (T) instance; return typedInstance; } + private static Object walk( + Map, Object> cache, + Resolution.Result tree) { + return switch (tree) { + case Resolution.Result.Node(var match, var dependencies) -> { + Object[] args = dependencies.stream().map(dep -> walk(cache, dep)).toArray(); + + Object instance = invoke(match, args); + + cache.put(match.witnessType(), instance); + + yield instance; + } + case Resolution.Result.LazyLookup(var target) -> + (Lazy) + () -> + Optional.ofNullable(cache.get(target)) + .orElseThrow( + () -> + new WitnessResolutionException( + "BUG: expected cached instance for %s" + .formatted(target.format()))); + case Resolution.Result.LazyWrap(var under) -> (Lazy) () -> walk(cache, under); + }; + } + private static Object invoke( - Match match, List args) { + Match match, Object[] args) { try { - return match.ctor().method().java().invoke(null, args.toArray()); + return match.ctor().method().java().invoke(null, args); } catch (IllegalAccessException e) { throw new IllegalStateException("BUG: expected witness constructor method to be public", e); } catch (InvocationTargetException e) { diff --git a/core/src/main/java/com/garciat/typeclasses/api/Lazy.java b/core/src/main/java/com/garciat/typeclasses/api/Lazy.java new file mode 100644 index 0000000..3f5dac2 --- /dev/null +++ b/core/src/main/java/com/garciat/typeclasses/api/Lazy.java @@ -0,0 +1,5 @@ +package com.garciat.typeclasses.api; + +public interface Lazy { + A get(); +} diff --git a/core/src/main/java/com/garciat/typeclasses/api/Out.java b/core/src/main/java/com/garciat/typeclasses/api/Out.java new file mode 100644 index 0000000..690b829 --- /dev/null +++ b/core/src/main/java/com/garciat/typeclasses/api/Out.java @@ -0,0 +1,10 @@ +package com.garciat.typeclasses.api; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.TYPE_PARAMETER}) +public @interface Out {} diff --git a/core/src/main/java/com/garciat/typeclasses/impl/ParsedType.java b/core/src/main/java/com/garciat/typeclasses/impl/ParsedType.java index 8e0dfe2..e1e755e 100644 --- a/core/src/main/java/com/garciat/typeclasses/impl/ParsedType.java +++ b/core/src/main/java/com/garciat/typeclasses/impl/ParsedType.java @@ -3,26 +3,42 @@ import java.util.List; public sealed interface ParsedType { - record Var(V repr) implements ParsedType {} + record Var(TyParam ref) implements ParsedType {} + + record Out(ParsedType under) implements ParsedType {} record App(ParsedType fun, ParsedType arg) implements ParsedType {} record ArrayOf(ParsedType elementType) implements ParsedType {} - record Const(C repr, List> typeParams) implements ParsedType {} + record Const(C repr, List> typeParams) implements ParsedType {} record Primitive(P repr) implements ParsedType {} record Wildcard() implements ParsedType {} + record Lazy(ParsedType under) implements ParsedType {} + + record TyParam(V repr, boolean isOut) { + public ParsedType wrapOut(ParsedType under) { + return isOut ? new Out<>(under) : under; + } + + @Override + public String toString() { + return (isOut ? "&" : "") + repr; + } + } + default String format() { return switch (this) { case Var(var repr) -> repr.toString(); + case Out(var under) -> "Out<" + under.format() + ">"; case Const(var repr, var typeParams) -> repr.toString() + typeParams.stream() - .map(ParsedType::format) + .map(TyParam::toString) .reduce((a, b) -> a + ", " + b) .map(s -> "[" + s + "]") .orElse(""); @@ -30,6 +46,7 @@ case Const(var repr, var typeParams) -> case ArrayOf(var elem) -> elem.format() + "[]"; case Primitive(var repr) -> repr.toString(); case Wildcard() -> "?"; + case Lazy(var repr) -> "Lazy<" + repr.format() + ">"; }; } } diff --git a/core/src/main/java/com/garciat/typeclasses/impl/Resolution.java b/core/src/main/java/com/garciat/typeclasses/impl/Resolution.java index 6158b01..5839d15 100644 --- a/core/src/main/java/com/garciat/typeclasses/impl/Resolution.java +++ b/core/src/main/java/com/garciat/typeclasses/impl/Resolution.java @@ -1,59 +1,276 @@ package com.garciat.typeclasses.impl; +import static com.garciat.typeclasses.impl.utils.AutoCloseables.around; +import static java.util.stream.Collectors.groupingBy; +import static java.util.stream.Collectors.toUnmodifiableList; +import static java.util.stream.Collectors.toUnmodifiableSet; + +import com.garciat.typeclasses.impl.ParsedType.Var; import com.garciat.typeclasses.impl.utils.Either; import com.garciat.typeclasses.impl.utils.Lists; import com.garciat.typeclasses.impl.utils.Maybe; +import com.garciat.typeclasses.impl.utils.Sets; import com.garciat.typeclasses.impl.utils.ZeroOneMore; +import java.util.HashMap; +import java.util.HashSet; import java.util.List; -import java.util.function.BiFunction; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; import java.util.function.Function; import java.util.stream.Collectors; public final class Resolution { private Resolution() {} - public static Either, R> resolve( - Function, List>> constructors, - BiFunction, List, R> build, + public sealed interface Result { + record Node(Match match, List> children) + implements Result {} + + record LazyLookup(ParsedType target) implements Result {} + + record LazyWrap(Result under) implements Result {} + } + + public static Either, Result> resolve( + Function>> constructors, ParsedType target) { + return resolveRec(new HashSet<>(), constructors, target); + } + + private static Either, Result> resolveRec( + Set> seen, + Function>> constructors, ParsedType target) { - var candidates = - OverlappingInstances.reduce( - Maybe.mapMaybe(findWitnesses(target, constructors), ctor -> match(ctor, target))); + if (target instanceof ParsedType.Lazy(var under)) { + if (seen.contains(under)) { + return Either.right(new Result.LazyLookup<>(under)); + } else { + try (var _ = around(() -> seen.add(under), () -> seen.remove(under))) { + return resolveRec(seen, constructors, under).map(Result.LazyWrap::new); + } + } + } + + var attempts = + Either.partition( + Lists.map( + Witnesses.findWitnesses(constructors, target), + ctor -> match(seen, constructors, ctor, target))); + + var candidates = OverlappingInstances.reduce(attempts.snd()); return switch (ZeroOneMore.of(candidates)) { - case ZeroOneMore.Zero() -> Either.left(new Failure.NotFound<>(target)); + case ZeroOneMore.Zero() -> + Either.left(new Failure.NoMatch<>(target, attempts.fst(), attempts.snd())); case ZeroOneMore.More(var matches) -> Either.left(new Failure.Ambiguous<>(target, matches)); - case ZeroOneMore.One(var c) -> - Either.traverse(c.dependencies(), t -> resolve(constructors, build, t)) - .map(children -> build.apply(c, children)) + case ZeroOneMore.One(var match) -> + Either.traverse(match.dependencies(), t -> resolveRec(seen, constructors, t)) + .>map(children -> new Result.Node<>(match, children)) .mapLeft(f -> new Failure.Nested<>(target, f)); }; } - private static List> findWitnesses( - ParsedType target, - Function, List>> constructors) { - return switch (target) { - case ParsedType.App(var fun, var arg) -> - Lists.concat(findWitnesses(fun, constructors), findWitnesses(arg, constructors)); - case ParsedType.Const c -> constructors.apply(c); - case ParsedType.Var(_), - ParsedType.ArrayOf(_), - ParsedType.Primitive(_), - ParsedType.Wildcard() -> - List.of(); + private static Either, Match> match( + Set> seen, + Function>> constructors, + WitnessConstructor ctor, + ParsedType target) { + return switch (Unification.unify(ctor.returnType(), target)) { + case Maybe.Nothing() -> Either.left(new MatchFailure.HeadMismatch<>(ctor)); + case Maybe.Just(var returnSubst) -> { + List> dependencies = + Unification.substituteAll(returnSubst, ctor.paramTypes()); + + TreeMap>> nodesByInDegree = + dependencies.stream() + .map(Resolution::parseNode) + .collect(groupingBy(n -> n.in().size(), TreeMap::new, toUnmodifiableList())); + + if (!nodesByInDegree.isEmpty() && nodesByInDegree.firstKey() != 0) { + // There is a cycle in the dependency graph + yield Either.left(new MatchFailure.Cycle<>(ctor, dependencies)); + } + + Map, ParsedType> substitution = new HashMap<>(returnSubst); + + for (List> stratum : nodesByInDegree.sequencedValues()) { + for (Node node : stratum) { + { + var missing = Sets.difference(node.in(), substitution.keySet()); + if (!missing.isEmpty()) { + // Some input variable has not been satisfied yet + yield Either.left(new MatchFailure.UnboundVariables<>(ctor, missing)); + } + } + + switch (flatten( + resolveRec( + seen, constructors, Unification.substitute(substitution, node.type())))) { + case Either.Right(Result.Node(var possible, _)) -> { + switch (Unification.unify( + Types.unwrapOut1(node.type()), Types.unwrapOut1(possible.witnessType()))) { + case Maybe.Just(var childSubst) -> { + { + var missing = Sets.difference(node.out(), childSubst.keySet()); + if (!missing.isEmpty()) { + // Some output variable has not been satisfied + yield Either.left(new MatchFailure.UnproductiveConstraint<>(ctor, missing)); + } + } + + for (Var out : childSubst.keySet()) { + ParsedType outT = childSubst.get(out); + + if (substitution.get(out) instanceof ParsedType existingT + && !existingT.equals(outT)) { + // Conflicting substitutions + yield Either.left( + new MatchFailure.ConflictingSubstitution<>(ctor, out, existingT, outT)); + } + + substitution.put(out, outT); + } + } + case Maybe.Nothing() -> { + // Child witness does not match expected type + yield Either.left( + new MatchFailure.ResolvedConstraintMismatch<>( + ctor, + Types.unwrapOut1(node.type()), + Types.unwrapOut1(possible.witnessType()))); + } + } + } + case Either.Right(Result.LazyWrap(_)) -> + throw new IllegalStateException( + "flatten should have eliminated LazyWrap cases here"); + case Either.Right(Result.LazyLookup(_)) -> { + // For now, we just treat them as resolved constraints :shrug: + } + case Either.Left(var error) -> { + // Could not resolve child witness + yield Either.left( + new MatchFailure.UnresolvedConstraint<>(ctor, node.type(), error)); + } + } + } + } + + yield Either.right( + new Match<>( + ctor, + Unification.substituteAll(substitution, ctor.paramTypes()), + Unification.substitute(substitution, ctor.returnType()))); + } }; } - private static Maybe> match( - WitnessConstructor ctor, ParsedType target) { - return Unification.unify(ctor.returnType(), target) - .map(map -> Unification.substituteAll(map, ctor.paramTypes())) - .map(dependencies -> new Match<>(ctor, dependencies, target)); + private static Node parseNode(ParsedType type) { + return new Node<>( + type, + Types.findOutVars(type).collect(toUnmodifiableSet()), + Types.findVars(type).collect(toUnmodifiableSet())); + } + + private static Either, Result> flatten( + Either, Result> result) { + return switch (result) { + case Either.Right(Result.LazyWrap(var under)) -> flatten(Either.right(under)); + default -> result; + }; + } + + private record Node( + ParsedType type, Set> out, Set> in) {} + + public sealed interface MatchFailure { + record HeadMismatch(WitnessConstructor ctor) + implements MatchFailure {} + + record Cycle(WitnessConstructor ctor, List> types) + implements MatchFailure {} + + record UnresolvedConstraint( + WitnessConstructor ctor, + ParsedType constraint, + Failure cause) + implements MatchFailure {} + + record ResolvedConstraintMismatch( + WitnessConstructor ctor, + ParsedType expected, + ParsedType actual) + implements MatchFailure {} + + record UnboundVariables( + WitnessConstructor ctor, Set> variables) + implements MatchFailure {} + + record UnproductiveConstraint( + WitnessConstructor ctor, Set> variables) + implements MatchFailure {} + + record ConflictingSubstitution( + WitnessConstructor ctor, + Var variable, + ParsedType existing, + ParsedType conflicting) + implements MatchFailure {} + + default String format() { + return switch (this) { + case HeadMismatch(var ctor) -> + "Witness constructor " + ctor.format() + " does not match the target type."; + case Cycle(var ctor, var types) -> + "Witness constructor " + + ctor.format() + + " has cyclic dependencies: " + + types.stream().map(ParsedType::format).collect(Collectors.joining(", ")); + case UnresolvedConstraint(var ctor, var constraint, var cause) -> + "Could not resolve constraint " + + constraint.format() + + " for witness constructor " + + ctor.format() + + ":\nCaused by: " + + cause.format().indent(2); + case ResolvedConstraintMismatch(var ctor, var expected, var actual) -> + "Resolved constraint for witness constructor " + + ctor.format() + + " does not match expected type: expected " + + expected.format() + + ", got " + + actual.format() + + "."; + case UnboundVariables(var ctor, var variables) -> + "Witness constructor " + + ctor.format() + + " has unbound input variables: " + + variables.stream().map(Var::format).collect(Collectors.joining(", ")); + case UnproductiveConstraint(var ctor, var variables) -> + "Witness constructor " + + ctor.format() + + " has unproductive output variables: " + + variables.stream().map(Var::format).collect(Collectors.joining(", ")); + case ConflictingSubstitution(var ctor, var variable, var existing, var conflicting) -> + "Witness constructor " + + ctor.format() + + " has conflicting substitutions for variable " + + variable.format() + + ": existing substitution " + + existing.format() + + ", conflicting substitution " + + conflicting.format() + + "."; + }; + } } public sealed interface Failure { - record NotFound(ParsedType target) implements Failure {} + record NoMatch( + ParsedType target, + List> failures, + List> matches) + implements Failure {} record Ambiguous(ParsedType target, List> candidates) implements Failure {} @@ -63,7 +280,22 @@ record Nested(ParsedType target, Failure cause) default String format() { return switch (this) { - case NotFound(var target) -> "No witness found for type: " + target.format(); + case NoMatch(var target, var failures, var matches) -> + "No witnesses found for type: " + + target.format() + + "\nFailures:\n" + + failures.stream() + .map(MatchFailure::format) + .collect(Collectors.joining("\n")) + .indent(2) + + (matches.isEmpty() + ? "" + : "\nPartial matches:\n" + + matches.stream() + .map(Match::ctor) + .map(WitnessConstructor::format) + .collect(Collectors.joining("\n")) + .indent(2)); case Ambiguous(var target, var candidates) -> "Ambiguous witnesses found for type: " + target.format() diff --git a/core/src/main/java/com/garciat/typeclasses/impl/Types.java b/core/src/main/java/com/garciat/typeclasses/impl/Types.java new file mode 100644 index 0000000..c03eb56 --- /dev/null +++ b/core/src/main/java/com/garciat/typeclasses/impl/Types.java @@ -0,0 +1,45 @@ +package com.garciat.typeclasses.impl; + +import java.util.stream.Stream; + +public final class Types { + private Types() {} + + public static Stream> findOutVars(ParsedType type) { + return switch (type) { + case ParsedType.Var(_) -> Stream.of(); + case ParsedType.Out(ParsedType.Var v) -> Stream.of(v); + case ParsedType.Out(_) -> Stream.of(); + case ParsedType.App(var fun, var arg) -> Stream.concat(findOutVars(fun), findOutVars(arg)); + case ParsedType.ArrayOf(var elem) -> findOutVars(elem); + case ParsedType.Lazy(var under) -> findOutVars(under); + case ParsedType.Const(_, _), ParsedType.Primitive(_), ParsedType.Wildcard() -> Stream.of(); + }; + } + + public static Stream> findVars(ParsedType type) { + return switch (type) { + case ParsedType.Var v -> Stream.of(v); + case ParsedType.Out(_) -> Stream.of(); + case ParsedType.App(var fun, var arg) -> Stream.concat(findVars(fun), findVars(arg)); + case ParsedType.ArrayOf(var elem) -> findVars(elem); + case ParsedType.Lazy(var under) -> findVars(under); + case ParsedType.Const(_, _), ParsedType.Primitive(_), ParsedType.Wildcard() -> Stream.of(); + }; + } + + /** Unwraps one level of Out from the given type. */ + public static ParsedType unwrapOut1(ParsedType type) { + return switch (type) { + case ParsedType.Out(var under) -> under; + case ParsedType.App(var fun, var arg) -> + new ParsedType.App<>(unwrapOut1(fun), unwrapOut1(arg)); + case ParsedType.ArrayOf(var elem) -> new ParsedType.ArrayOf<>(unwrapOut1(elem)); + case ParsedType.Lazy(var t) -> new ParsedType.Lazy<>(unwrapOut1(t)); + case ParsedType.Var v -> v; + case ParsedType.Primitive p -> p; + case ParsedType.Const c -> c; + case ParsedType.Wildcard w -> w; + }; + } +} diff --git a/core/src/main/java/com/garciat/typeclasses/impl/Unification.java b/core/src/main/java/com/garciat/typeclasses/impl/Unification.java index 250c976..e3a88aa 100644 --- a/core/src/main/java/com/garciat/typeclasses/impl/Unification.java +++ b/core/src/main/java/com/garciat/typeclasses/impl/Unification.java @@ -3,6 +3,8 @@ import com.garciat.typeclasses.impl.ParsedType.App; import com.garciat.typeclasses.impl.ParsedType.ArrayOf; import com.garciat.typeclasses.impl.ParsedType.Const; +import com.garciat.typeclasses.impl.ParsedType.Lazy; +import com.garciat.typeclasses.impl.ParsedType.Out; import com.garciat.typeclasses.impl.ParsedType.Primitive; import com.garciat.typeclasses.impl.ParsedType.Var; import com.garciat.typeclasses.impl.ParsedType.Wildcard; @@ -22,6 +24,8 @@ private Unification() {} public static Maybe, ParsedType>> unify( ParsedType t1, ParsedType t2) { return switch (Pair.of(t1, t2)) { + case Pair(Lazy(var x), var t) -> unify(x, t); + case Pair(var t, Lazy(var x)) -> unify(t, x); case Pair(Var(_), Primitive(_)) -> Maybe.nothing(); // no primitives in generics case Pair(Var v, var t) -> Maybe.just(Map.of(v, t)); case Pair(Const(var repr1, _), Const(var repr2, _)) when repr1.equals(repr2) -> @@ -31,8 +35,12 @@ case Pair(App(var fun1, var arg1), App(var fun2, var arg2)) -> case Pair(ArrayOf(var elem1), ArrayOf(var elem2)) -> unify(elem1, elem2); case Pair(Primitive(var prim1), Primitive(var prim2)) when prim1.equals(prim2) -> Maybe.just(Map.of()); + // Wildcards can match anything case Pair(Wildcard(), _) -> Maybe.just(Map.of()); - default -> Maybe.nothing(); + // Out types match each other, regardless of their inner types + // The resolution algorithm will check compatibility later + case Pair(Out(_), Out(_)) -> Maybe.just(Map.of()); + case Pair(_, _) -> Maybe.nothing(); }; } @@ -40,8 +48,10 @@ public static ParsedType substitute( Map, ParsedType> map, ParsedType type) { return switch (type) { case Var var -> map.getOrDefault(var, var); + case Out(var under) -> new Out<>(substitute(map, under)); case App(var fun, var arg) -> new App<>(substitute(map, fun), substitute(map, arg)); case ArrayOf(var elem) -> new ArrayOf<>(substitute(map, elem)); + case Lazy(var t) -> new Lazy<>(substitute(map, t)); case Primitive p -> p; case Const c -> c; case Wildcard w -> w; diff --git a/core/src/main/java/com/garciat/typeclasses/impl/WitnessConstructor.java b/core/src/main/java/com/garciat/typeclasses/impl/WitnessConstructor.java index 0f30d0d..5b43d17 100644 --- a/core/src/main/java/com/garciat/typeclasses/impl/WitnessConstructor.java +++ b/core/src/main/java/com/garciat/typeclasses/impl/WitnessConstructor.java @@ -7,14 +7,15 @@ public record WitnessConstructor( M method, TypeClass.Witness.Overlap overlap, - List> typeParams, + List> typeParams, List> paramTypes, ParsedType returnType) { public String format() { return String.format( - "%s%s -> %s", + "%s = %s%s -> %s", + method.toString(), typeParams().stream() - .map(ParsedType::format) + .map(ParsedType.TyParam::toString) .reduce((a, b) -> a + " " + b) .map("∀ %s. "::formatted) .orElse(""), diff --git a/core/src/main/java/com/garciat/typeclasses/impl/Witnesses.java b/core/src/main/java/com/garciat/typeclasses/impl/Witnesses.java new file mode 100644 index 0000000..f6a8777 --- /dev/null +++ b/core/src/main/java/com/garciat/typeclasses/impl/Witnesses.java @@ -0,0 +1,27 @@ +package com.garciat.typeclasses.impl; + +import com.garciat.typeclasses.impl.utils.Lists; +import java.util.List; +import java.util.function.Function; + +final class Witnesses { + private Witnesses() {} + + static List> findWitnesses( + Function>> constructors, ParsedType target) { + return switch (target) { + case ParsedType.App(var fun1, ParsedType.App(var fun2, _)) -> + Lists.concat(findWitnesses(constructors, fun1), findWitnesses(constructors, fun2)); + case ParsedType.App(var fun, var arg) -> + Lists.concat(findWitnesses(constructors, fun), findWitnesses(constructors, arg)); + case ParsedType.Const c -> constructors.apply(c.repr()); + case ParsedType.Lazy(var under) -> findWitnesses(constructors, under); + case ParsedType.Var(_), + ParsedType.Out(_), + ParsedType.ArrayOf(_), + ParsedType.Primitive(_), + ParsedType.Wildcard() -> + List.of(); + }; + } +} diff --git a/core/src/main/java/com/garciat/typeclasses/impl/utils/AutoCloseables.java b/core/src/main/java/com/garciat/typeclasses/impl/utils/AutoCloseables.java new file mode 100644 index 0000000..c28e628 --- /dev/null +++ b/core/src/main/java/com/garciat/typeclasses/impl/utils/AutoCloseables.java @@ -0,0 +1,23 @@ +package com.garciat.typeclasses.impl.utils; + +public final class AutoCloseables { + private AutoCloseables() {} + + public static SafeAutoCloseable around(Runnable before, Runnable after) { + return new SafeAutoCloseable() { + { + before.run(); + } + + @Override + public void close() { + after.run(); + } + }; + } + + public interface SafeAutoCloseable extends AutoCloseable { + @Override + void close(); + } +} diff --git a/core/src/main/java/com/garciat/typeclasses/impl/utils/Either.java b/core/src/main/java/com/garciat/typeclasses/impl/utils/Either.java index 823dd34..235cfd4 100644 --- a/core/src/main/java/com/garciat/typeclasses/impl/utils/Either.java +++ b/core/src/main/java/com/garciat/typeclasses/impl/utils/Either.java @@ -2,7 +2,6 @@ import java.util.ArrayList; import java.util.List; -import java.util.concurrent.Callable; import java.util.function.Function; public sealed interface Either { @@ -18,14 +17,6 @@ static Either right(R value) { return new Right<>(value); } - static Either call(Callable callable) { - try { - return right(callable.call()); - } catch (Exception e) { - return left(e); - } - } - default Either map(Function f) { return fold(Either::left, f.andThen(Either::right)); } @@ -60,4 +51,20 @@ case Right(R value) -> { } return right(result); } + + static Pair, List> partition(List> eithers) { + List lefts = new ArrayList<>(); + List rights = new ArrayList<>(); + for (Either either : eithers) { + switch (either) { + case Left(L value) -> { + lefts.add(value); + } + case Right(R value) -> { + rights.add(value); + } + } + } + return new Pair<>(lefts, rights); + } } diff --git a/core/src/main/java/com/garciat/typeclasses/impl/utils/Formatter.java b/core/src/main/java/com/garciat/typeclasses/impl/utils/Formatter.java deleted file mode 100644 index 3e97bbd..0000000 --- a/core/src/main/java/com/garciat/typeclasses/impl/utils/Formatter.java +++ /dev/null @@ -1,5 +0,0 @@ -package com.garciat.typeclasses.impl.utils; - -import java.util.function.Function; - -public interface Formatter extends Function {} diff --git a/core/src/main/java/com/garciat/typeclasses/impl/utils/Lists.java b/core/src/main/java/com/garciat/typeclasses/impl/utils/Lists.java index d25373c..e3378aa 100644 --- a/core/src/main/java/com/garciat/typeclasses/impl/utils/Lists.java +++ b/core/src/main/java/com/garciat/typeclasses/impl/utils/Lists.java @@ -2,8 +2,10 @@ import java.util.Arrays; import java.util.List; +import java.util.function.BiFunction; import java.util.function.Function; import java.util.stream.Collectors; +import java.util.stream.IntStream; public final class Lists { private Lists() {} @@ -12,6 +14,17 @@ public static List map(List list, Function return list.stream().map(f).collect(Collectors.toList()); } + public static List zip( + List list1, List list2, BiFunction f) { + if (list1.size() != list2.size()) { + throw new IllegalArgumentException("Lists must have the same size to be zipped."); + } + int size = list1.size(); + return IntStream.range(0, size) + .mapToObj(i -> f.apply(list1.get(i), list2.get(i))) + .collect(Collectors.toList()); + } + @SafeVarargs public static List concat(List... lists) { return Arrays.stream(lists).flatMap(List::stream).toList(); diff --git a/core/src/main/java/com/garciat/typeclasses/impl/utils/Rose.java b/core/src/main/java/com/garciat/typeclasses/impl/utils/Rose.java deleted file mode 100644 index ca6b5cd..0000000 --- a/core/src/main/java/com/garciat/typeclasses/impl/utils/Rose.java +++ /dev/null @@ -1,11 +0,0 @@ -package com.garciat.typeclasses.impl.utils; - -import java.util.List; - -public sealed interface Rose { - record Node(A value, List> children) implements Rose {} - - static Rose of(A value, List> children) { - return new Node<>(value, children); - } -} diff --git a/core/src/main/java/com/garciat/typeclasses/impl/utils/Sets.java b/core/src/main/java/com/garciat/typeclasses/impl/utils/Sets.java new file mode 100644 index 0000000..8a4ba2d --- /dev/null +++ b/core/src/main/java/com/garciat/typeclasses/impl/utils/Sets.java @@ -0,0 +1,13 @@ +package com.garciat.typeclasses.impl.utils; + +import java.util.Set; + +public final class Sets { + private Sets() {} + + public static Set difference(Set a, Set b) { + Set result = new java.util.HashSet<>(a); + result.removeAll(b); + return result; + } +} diff --git a/core/src/main/java/com/garciat/typeclasses/processor/Static.java b/core/src/main/java/com/garciat/typeclasses/processor/Static.java index 5f224d0..18429a4 100644 --- a/core/src/main/java/com/garciat/typeclasses/processor/Static.java +++ b/core/src/main/java/com/garciat/typeclasses/processor/Static.java @@ -8,7 +8,12 @@ public final class Static { private Static() {} - public record Method(ExecutableElement java) {} + public record Method(ExecutableElement java) { + @Override + public String toString() { + return java.getEnclosingElement().getSimpleName() + "." + java.getSimpleName(); + } + } public record Var(TypeVariable java) { @Override diff --git a/core/src/main/java/com/garciat/typeclasses/processor/StaticWitnessSystem.java b/core/src/main/java/com/garciat/typeclasses/processor/StaticWitnessSystem.java index 80379c9..c8d9de7 100644 --- a/core/src/main/java/com/garciat/typeclasses/processor/StaticWitnessSystem.java +++ b/core/src/main/java/com/garciat/typeclasses/processor/StaticWitnessSystem.java @@ -2,30 +2,32 @@ import static com.garciat.typeclasses.impl.utils.Streams.isInstanceOf; +import com.garciat.typeclasses.api.Lazy; +import com.garciat.typeclasses.api.Out; import com.garciat.typeclasses.api.TypeClass; import com.garciat.typeclasses.api.hkt.TApp; import com.garciat.typeclasses.api.hkt.TPar; import com.garciat.typeclasses.api.hkt.TagBase; -import com.garciat.typeclasses.impl.Match; import com.garciat.typeclasses.impl.ParsedType; import com.garciat.typeclasses.impl.ParsedType.App; import com.garciat.typeclasses.impl.ParsedType.ArrayOf; import com.garciat.typeclasses.impl.ParsedType.Const; import com.garciat.typeclasses.impl.ParsedType.Primitive; +import com.garciat.typeclasses.impl.ParsedType.TyParam; import com.garciat.typeclasses.impl.ParsedType.Var; import com.garciat.typeclasses.impl.ParsedType.Wildcard; import com.garciat.typeclasses.impl.Resolution; import com.garciat.typeclasses.impl.WitnessConstructor; import com.garciat.typeclasses.impl.utils.Either; +import com.garciat.typeclasses.impl.utils.Lists; import com.garciat.typeclasses.impl.utils.Maybe; import com.garciat.typeclasses.impl.utils.Pair; -import com.garciat.typeclasses.impl.utils.Rose; import java.util.List; import javax.lang.model.element.ExecutableElement; import javax.lang.model.element.Modifier; import javax.lang.model.element.Parameterizable; import javax.lang.model.element.TypeElement; -import javax.lang.model.element.VariableElement; +import javax.lang.model.element.TypeParameterElement; import javax.lang.model.type.ArrayType; import javax.lang.model.type.DeclaredType; import javax.lang.model.type.PrimitiveType; @@ -38,14 +40,14 @@ private StaticWitnessSystem() {} public static Either< Resolution.Failure, - Rose>> + Resolution.Result> resolve(TypeMirror target) { - return Resolution.resolve(StaticWitnessSystem::findWitnesses, Rose::of, parse(target)); + return Resolution.resolve(StaticWitnessSystem::findWitnesses, parse(target)); } private static List> - findWitnesses(ParsedType.Const target) { - return target.repr().java().getEnclosedElements().stream() + findWitnesses(Static.Const target) { + return target.java().getEnclosedElements().stream() .flatMap(isInstanceOf(ExecutableElement.class)) .flatMap(method -> parseWitnessConstructor(method).stream()) .toList(); @@ -61,10 +63,7 @@ private StaticWitnessSystem() {} new Static.Method(method), witnessAnn.overlap(), typeParams(method), - method.getParameters().stream() - .map(VariableElement::asType) - .map(StaticWitnessSystem::parse) - .toList(), + Lists.map(method.getParameters(), p -> parse(p.asType())), parse(method.getReturnType()))); } else { @@ -74,17 +73,23 @@ private StaticWitnessSystem() {} private static ParsedType parse(TypeMirror type) { return switch (type) { - case TypeVariable tv -> new Var<>(new Static.Var(tv)); + case TypeVariable tv -> new Var<>(typeParam((TypeParameterElement) tv.asElement())); case ArrayType at -> new ArrayOf<>(parse(at.getComponentType())); case PrimitiveType pt -> new Primitive<>(new Static.Prim(pt)); case DeclaredType dt when parseTagType(dt) instanceof Maybe.Just(var realType) -> constType(realType); case DeclaredType dt when parseAppType(dt) instanceof Maybe.Just(Pair(var fun, var arg)) -> new App<>(parse(fun), parse(arg)); - case DeclaredType dt -> - dt.getTypeArguments().stream() - .map(StaticWitnessSystem::parse) - .reduce(constType(erasure(dt)), App::new); + case DeclaredType dt when parseLazyType(dt) instanceof Maybe.Just(var under) -> + new ParsedType.Lazy<>(parse(under)); + case DeclaredType dt -> { + Const decl = constType(erasure(dt)); + + List> args = + dt.getTypeArguments().stream().map(StaticWitnessSystem::parse).toList(); + + yield Lists.zip(decl.typeParams(), args, TyParam::wrapOut).stream().reduce(decl, App::new); + } case WildcardType _ -> new Wildcard<>(); default -> throw new IllegalArgumentException("Unsupported type: " + type); }; @@ -94,13 +99,23 @@ private static Const constType(TypeElemen return new Const<>(new Static.Const(typeElement), typeParams(typeElement)); } - private static List> typeParams(Parameterizable tp) { - return tp.getTypeParameters().stream() - .map( - tpe -> - new Var( - new Static.Var((TypeVariable) tpe.asType()))) - .toList(); + private static List> typeParams(Parameterizable tp) { + return Lists.map(tp.getTypeParameters(), StaticWitnessSystem::typeParam); + } + + private static TyParam typeParam(TypeParameterElement element) { + return new TyParam<>( + new Static.Var((TypeVariable) element.asType()), element.getAnnotation(Out.class) != null); + } + + private static Maybe parseLazyType(DeclaredType t) { + if (t.asElement() instanceof TypeElement te + && te.getQualifiedName().contentEquals(Lazy.class.getName()) + && t.getTypeArguments().size() == 1) { + return Maybe.just(t.getTypeArguments().getFirst()); + } else { + return Maybe.nothing(); + } } private static Maybe parseTagType(DeclaredType t) { diff --git a/core/src/main/java/com/garciat/typeclasses/runtime/Runtime.java b/core/src/main/java/com/garciat/typeclasses/runtime/Runtime.java index eecea57..1d4f8a9 100644 --- a/core/src/main/java/com/garciat/typeclasses/runtime/Runtime.java +++ b/core/src/main/java/com/garciat/typeclasses/runtime/Runtime.java @@ -3,7 +3,12 @@ public final class Runtime { private Runtime() {} - public record Method(java.lang.reflect.Method java) {} + public record Method(java.lang.reflect.Method java) { + @Override + public String toString() { + return java.getDeclaringClass().getSimpleName() + "." + java.getName(); + } + } public record Var(java.lang.reflect.TypeVariable java) { @Override diff --git a/core/src/main/java/com/garciat/typeclasses/runtime/RuntimeWitnessSystem.java b/core/src/main/java/com/garciat/typeclasses/runtime/RuntimeWitnessSystem.java index 436b927..a702f7c 100644 --- a/core/src/main/java/com/garciat/typeclasses/runtime/RuntimeWitnessSystem.java +++ b/core/src/main/java/com/garciat/typeclasses/runtime/RuntimeWitnessSystem.java @@ -1,20 +1,22 @@ package com.garciat.typeclasses.runtime; +import com.garciat.typeclasses.api.Out; import com.garciat.typeclasses.api.TypeClass; import com.garciat.typeclasses.api.hkt.TApp; import com.garciat.typeclasses.api.hkt.TPar; import com.garciat.typeclasses.api.hkt.TagBase; -import com.garciat.typeclasses.impl.Match; import com.garciat.typeclasses.impl.ParsedType; import com.garciat.typeclasses.impl.ParsedType.App; import com.garciat.typeclasses.impl.ParsedType.ArrayOf; import com.garciat.typeclasses.impl.ParsedType.Const; import com.garciat.typeclasses.impl.ParsedType.Primitive; +import com.garciat.typeclasses.impl.ParsedType.TyParam; import com.garciat.typeclasses.impl.ParsedType.Var; import com.garciat.typeclasses.impl.ParsedType.Wildcard; import com.garciat.typeclasses.impl.Resolution; import com.garciat.typeclasses.impl.WitnessConstructor; import com.garciat.typeclasses.impl.utils.Either; +import com.garciat.typeclasses.impl.utils.Lists; import com.garciat.typeclasses.impl.utils.Maybe; import com.garciat.typeclasses.impl.utils.Pair; import java.lang.reflect.GenericArrayType; @@ -27,24 +29,20 @@ import java.lang.reflect.WildcardType; import java.util.Arrays; import java.util.List; -import java.util.function.BiFunction; public final class RuntimeWitnessSystem { private RuntimeWitnessSystem() {} - public static - Either, R> - resolve( - Type type, - BiFunction< - Match, List, R> - build) { - return Resolution.resolve(RuntimeWitnessSystem::findWitnesses, build, parse(type)); + public static Either< + Resolution.Failure, + Resolution.Result> + resolve(Type type) { + return Resolution.resolve(RuntimeWitnessSystem::findWitnesses, parse(type)); } private static List> - findWitnesses(ParsedType.Const target) { - return Arrays.stream(target.repr().java().getDeclaredMethods()) + findWitnesses(Runtime.Const target) { + return Arrays.stream(target.java().getDeclaredMethods()) .flatMap(m -> parseWitnessConstructor(m).stream()) .toList(); } @@ -68,20 +66,26 @@ private RuntimeWitnessSystem() {} } } - private static ParsedType parse(Type java) { + public static ParsedType parse(Type java) { return switch (java) { case Class tag when parseTagType(tag) instanceof Maybe.Just(var tagged) -> constType(tagged); case Class arr when arr.isArray() -> new ArrayOf<>(parse(arr.getComponentType())); case Class prim when prim.isPrimitive() -> new Primitive<>(new Runtime.Prim(prim)); case Class c -> constType(c); - case TypeVariable v -> new Var<>(new Runtime.Var(v)); + case TypeVariable v -> new Var<>(typeParam(v)); case ParameterizedType p when parseAppType(p) instanceof Maybe.Just(Pair(var fun, var arg)) -> new App<>(parse(fun), parse(arg)); - case ParameterizedType p -> - Arrays.stream(p.getActualTypeArguments()) - .map(RuntimeWitnessSystem::parse) - .reduce(parse(p.getRawType()), App::new); + case ParameterizedType p when parseLazyType(p) instanceof Maybe.Just(var under) -> + new ParsedType.Lazy<>(parse(under)); + case ParameterizedType p -> { + Const decl = constType((Class) p.getRawType()); + + List> args = + Arrays.stream(p.getActualTypeArguments()).map(RuntimeWitnessSystem::parse).toList(); + + yield Lists.zip(decl.typeParams(), args, TyParam::wrapOut).stream().reduce(decl, App::new); + } case GenericArrayType a -> new ArrayOf<>(parse(a.getGenericComponentType())); case WildcardType _ -> new Wildcard<>(); default -> throw new IllegalArgumentException("Unsupported type: " + java); @@ -92,16 +96,27 @@ private static Const constType(Class(new Runtime.Const(tagged), typeParams(tagged)); } - private static List> typeParams( - GenericDeclaration cls) { - return Arrays.stream(cls.getTypeParameters()) - .map(t -> new Var(new Runtime.Var(t))) - .toList(); + private static List> typeParams(GenericDeclaration cls) { + return Arrays.stream(cls.getTypeParameters()).map(RuntimeWitnessSystem::typeParam).toList(); + } + + private static TyParam typeParam(TypeVariable t) { + return new TyParam<>(new Runtime.Var(t), t.isAnnotationPresent(Out.class)); + } + + private static Maybe parseLazyType(ParameterizedType t) { + return switch (t.getRawType()) { + case Class raw when raw.equals(com.garciat.typeclasses.api.Lazy.class) -> + Maybe.just(t.getActualTypeArguments()[0]); + default -> Maybe.nothing(); + }; } private static Maybe> parseTagType(Class c) { return switch (c.getEnclosingClass()) { - case Class enclosing when c.getSuperclass().equals(TagBase.class) -> Maybe.just(enclosing); + case Class enclosing + when c.getSuperclass() instanceof Class sup && sup.equals(TagBase.class) -> + Maybe.just(enclosing); case null, default -> Maybe.nothing(); }; } diff --git a/core/src/test/java/com/garciat/typeclasses/examples/Example1.java b/core/src/test/java/com/garciat/typeclasses/examples/Example1.java new file mode 100644 index 0000000..0b54c8f --- /dev/null +++ b/core/src/test/java/com/garciat/typeclasses/examples/Example1.java @@ -0,0 +1,47 @@ +package com.garciat.typeclasses.examples; + +import static com.garciat.typeclasses.TypeClasses.witness; +import static org.assertj.core.api.Assertions.assertThat; + +import com.garciat.typeclasses.api.Ty; +import com.garciat.typeclasses.api.TypeClass; +import java.util.List; +import java.util.stream.Collectors; +import org.junit.jupiter.api.Test; + +public class Example1 { + @Test + void main() { + Pair> value = new Pair<>(1, List.of(2, 3, 4)); + + String s = Show.show(witness(new Ty<>() {}), value); + + assertThat(s).isEqualTo("(1, [2, 3, 4])"); + } + + @TypeClass + public interface Show { + String show(T value); + + static String show(Show showT, T value) { + return showT.show(value); + } + + @TypeClass.Witness + static Show integerShow() { + return i -> Integer.toString(i); + } + + @TypeClass.Witness + static Show> listShow(Show showA) { + return listA -> listA.stream().map(showA::show).collect(Collectors.joining(", ", "[", "]")); + } + } + + public record Pair(A first, B second) { + @TypeClass.Witness + public static Show> pairShow(Show showA, Show showB) { + return pair -> "(" + showA.show(pair.first()) + ", " + showB.show(pair.second()) + ")"; + } + } +} diff --git a/core/src/test/java/com/garciat/typeclasses/examples/Example2.java b/core/src/test/java/com/garciat/typeclasses/examples/Example2.java new file mode 100644 index 0000000..2b86ed1 --- /dev/null +++ b/core/src/test/java/com/garciat/typeclasses/examples/Example2.java @@ -0,0 +1,50 @@ +package com.garciat.typeclasses.examples; + +import static com.garciat.typeclasses.TypeClasses.witness; +import static org.assertj.core.api.Assertions.assertThat; + +import com.garciat.typeclasses.api.Ty; +import com.garciat.typeclasses.api.TypeClass; +import com.garciat.typeclasses.examples.Example2.TList.TCons; +import com.garciat.typeclasses.examples.Example2.TList.TNil; +import org.junit.jupiter.api.Test; + +/// Based on: +/// +/// ```haskell +/// class In (xs :: [k]) (x :: k) +/// +/// instance In (x ': xs) x +/// +/// instance {-# OVERLAPPABLE #-} In xs y => In (x ': xs) y +/// +/// example :: In '[Int, Bool, Char] Bool => () +/// example = () +/// ``` +public class Example2 { + @Test + void main() { + In>>, Short> w = witness(new Ty<>() {}); + + assertThat(w).isNotNull(); + } + + public interface TList> { + record TNil() implements TList {} + + record TCons>() implements TList> {} + } + + @TypeClass + public interface In, Y> { + @TypeClass.Witness + static > In, X> here() { + return new In<>() {}; + } + + @TypeClass.Witness(overlap = TypeClass.Witness.Overlap.OVERLAPPABLE) + static , Y> In, Y> there(In there) { + return new In<>() {}; + } + } +} diff --git a/core/src/test/java/com/garciat/typeclasses/examples/Example3.java b/core/src/test/java/com/garciat/typeclasses/examples/Example3.java new file mode 100644 index 0000000..5f65826 --- /dev/null +++ b/core/src/test/java/com/garciat/typeclasses/examples/Example3.java @@ -0,0 +1,145 @@ +package com.garciat.typeclasses.examples; + +import static com.garciat.typeclasses.TypeClasses.witness; +import static org.assertj.core.api.Assertions.assertThat; + +import com.garciat.typeclasses.api.Lazy; +import com.garciat.typeclasses.api.Out; +import com.garciat.typeclasses.api.Ty; +import com.garciat.typeclasses.api.TypeClass; +import com.garciat.typeclasses.examples.Example3.TyRep.K1; +import com.garciat.typeclasses.examples.Example3.TyRep.Prod; +import com.garciat.typeclasses.examples.Example3.TyRep.Sum; +import com.garciat.typeclasses.examples.Example3.TyRep.Sum.L1; +import com.garciat.typeclasses.examples.Example3.TyRep.Sum.R1; +import java.util.List; +import org.junit.jupiter.api.Test; + +public final class Example3 { + @Test + void example() { + Tree.Node tree = + new Tree.Node<>( + new Tree.Leaf<>(1), new Tree.Node<>(new Tree.Leaf<>(2), new Tree.Leaf<>(3))); + + ToJson> toJsonTree = witness(new Ty<>() {}); + + assertThat(toJsonTree.toJson(tree)).isEqualTo(array(value(1), array(value(2), value(3)))); + } + + private static JsonValue array(JsonValue... values) { + return new JsonValue.JsonArray(List.of(values)); + } + + private static JsonValue value(int value) { + return new JsonValue.JsonInteger(value); + } + + @TypeClass + public interface Generic { + Rep from(A a); + + A to(Rep rep); + } + + public interface TyRep { + record K1(A value) {} + + sealed interface Sum { + record L1(A left) implements Sum {} + + record R1(B right) implements Sum {} + } + + record Prod(A first, B second) {} + } + + public interface JsonValue { + record JsonString(String value) implements JsonValue {} + + record JsonInteger(int value) implements JsonValue {} + + record JsonObject(List props) implements JsonValue {} + + record JsonArray(List values) implements JsonValue {} + + record Prop(String key, JsonValue value) {} + } + + @TypeClass + public interface ToJson { + JsonValue toJson(A a); + + @TypeClass.Witness + static ToJson toJsonInteger() { + return JsonValue.JsonInteger::new; + } + } + + @TypeClass + public interface ToJsonGeneric { + JsonValue toJson(Rep rep); + + static ToJsonGeneric toJsonGeneric(Generic generic, ToJson toJsonA) { + return rep -> toJsonA.toJson(generic.to(rep)); + } + + @TypeClass.Witness + static ToJsonGeneric> k1(Lazy> toJsonA) { + return rep -> toJsonA.get().toJson(rep.value()); + } + + @TypeClass.Witness + static ToJsonGeneric> prod( + ToJsonGeneric toJsonA, ToJsonGeneric toJsonB) { + return rep -> + new JsonValue.JsonArray( + List.of(toJsonA.toJson(rep.first()), toJsonB.toJson(rep.second()))); + } + + @TypeClass.Witness + static ToJsonGeneric> sum(ToJsonGeneric toJsonA, ToJsonGeneric toJsonB) { + return rep -> + switch (rep) { + case L1(var value) -> toJsonA.toJson(value); + case R1(var value) -> toJsonB.toJson(value); + }; + } + } + + public sealed interface Tree { + record Leaf(A value) implements Tree {} + + record Node(Tree left, Tree right) implements Tree {} + + @TypeClass.Witness + static ToJson> toJson( + Generic, Rep> generic, ToJsonGeneric toJsonGeneric) { + return tree -> toJsonGeneric.toJson(generic.from(tree)); + } + + @TypeClass.Witness + static Generic, Sum, Prod>, K1>>>> generic() { + return new Generic<>() { + @Override + public Sum, Prod>, K1>>> from(Tree tree) { + return switch (tree) { + case Leaf leaf -> + new L1, Prod>, K1>>>(new K1<>(leaf.value)); + case Node node -> + new R1, Prod>, K1>>>( + new Prod<>(new K1<>(node.left), new K1<>(node.right))); + }; + } + + @Override + public Tree to(Sum, Prod>, K1>>> rep) { + return switch (rep) { + case L1(K1(var value)) -> new Leaf<>(value); + case R1(Prod(K1(var left), K1(var right))) -> new Node<>(left, right); + }; + } + }; + } + } +} diff --git a/core/src/test/java/com/garciat/typeclasses/examples/Example4.java b/core/src/test/java/com/garciat/typeclasses/examples/Example4.java new file mode 100644 index 0000000..8547c74 --- /dev/null +++ b/core/src/test/java/com/garciat/typeclasses/examples/Example4.java @@ -0,0 +1,55 @@ +package com.garciat.typeclasses.examples; + +import static com.garciat.typeclasses.TypeClasses.witness; +import static org.assertj.core.api.Assertions.assertThat; + +import com.garciat.typeclasses.api.Out; +import com.garciat.typeclasses.api.Ty; +import com.garciat.typeclasses.api.TypeClass; +import java.util.List; +import org.junit.jupiter.api.Test; + +/// Based on: +/// +/// ```haskell +/// type family ElementOf a where +/// ElementOf [[a]] = ElementOf [a] +/// ElementOf [a] = a +/// +/// class Flatten a where +/// flatten :: a -> [ElementOf a] +/// +/// instance Flatten [a] where +/// flatten x = x +/// +/// instance {-# OVERLAPPING #-} Flatten [a] => Flatten [[a]] where +/// flatten x = flatten (concat x) +/// ``` +/// +/// From: "An introduction to typeclass metaprogramming" by Alexis King +public class Example4 { + @Test + void main() { + Flatten>, String> e1 = witness(new Ty<>() {}); + Flatten, String> e2 = witness(new Ty<>() {}); + + assertThat(e1.flatten(List.of(List.of("a", "b"), List.of("c")))) + .isEqualTo(List.of("a", "b", "c")); + assertThat(e2.flatten(List.of("a", "b", "c"))).isEqualTo(List.of("a", "b", "c")); + } + + @TypeClass + public interface Flatten { + List flatten(A list); + + @TypeClass.Witness + static Flatten, A> here() { + return list -> list; + } + + @TypeClass.Witness(overlap = TypeClass.Witness.Overlap.OVERLAPPING) + static Flatten>, R> there(Flatten, R> e) { + return list -> list.stream().flatMap(innerList -> e.flatten(innerList).stream()).toList(); + } + } +} diff --git a/core/src/test/java/com/garciat/typeclasses/examples/Example5.java b/core/src/test/java/com/garciat/typeclasses/examples/Example5.java new file mode 100644 index 0000000..c11aea3 --- /dev/null +++ b/core/src/test/java/com/garciat/typeclasses/examples/Example5.java @@ -0,0 +1,98 @@ +package com.garciat.typeclasses.examples; + +import static com.garciat.typeclasses.TypeClasses.witness; +import static org.assertj.core.api.Assertions.assertThat; + +import com.garciat.typeclasses.api.Out; +import com.garciat.typeclasses.api.Ty; +import com.garciat.typeclasses.api.TypeClass; +import com.garciat.typeclasses.examples.Example5.Nat.S; +import com.garciat.typeclasses.examples.Example5.Nat.Z; +import com.garciat.typeclasses.impl.utils.Unit; +import org.junit.jupiter.api.Test; + +/// Based on: +/// +/// ```haskell +/// data Nat = Z | S Nat +/// +/// class ReifyNat (a :: Nat) where +/// reifyNat :: Natural +/// +/// instance ReifyNat 'Z where +/// reifyNat = 0 +/// +/// instance ReifyNat a => ReifyNat ('S a) where +/// reifyNat = 1 + reifyNat @a +/// ``` +public class Example5 { + @Test + void test1() { + ReifyNat>>> reifier = witness(new Ty<>() {}); + + assertThat(reifier.reify()).isEqualTo(3); + } + + @Test + void test2() { + NatAdd>, S>>, S>>>>> adder = witness(new Ty<>() {}); + + assertThat(adder).isNotNull(); + } + + @Test + void test3() { + ReifyNatAdd>, S>>> reifyAdd = witness(new Ty<>() {}); + + assertThat(reifyAdd.reify()).isEqualTo(5); + } + + public sealed interface Nat> { + record Z() implements Nat {} + + // Note that we don't store the predecessor! + record S>() implements Nat> {} + } + + @TypeClass + public interface ReifyNat> { + int reify(); + + @TypeClass.Witness + static ReifyNat reifyZ() { + return () -> 0; + } + + @TypeClass.Witness + static > ReifyNat> reifyS(ReifyNat rn) { + return () -> 1 + rn.reify(); + } + } + + @TypeClass + public interface NatAdd { + Unit trivial(); + + @TypeClass.Witness + static > NatAdd addZ() { + return Unit::unit; + } + + @TypeClass.Witness + static , B extends Nat, C extends Nat> NatAdd, B, S> addS( + NatAdd prev) { + return Unit::unit; + } + } + + @TypeClass + public interface ReifyNatAdd { + int reify(); + + @TypeClass.Witness + static , B extends Nat, C extends Nat> ReifyNatAdd reifyAddS( + NatAdd addAB, ReifyNat reifyC) { + return reifyC::reify; + } + } +} diff --git a/core/src/test/java/com/garciat/typeclasses/examples/Example6.java b/core/src/test/java/com/garciat/typeclasses/examples/Example6.java new file mode 100644 index 0000000..519a940 --- /dev/null +++ b/core/src/test/java/com/garciat/typeclasses/examples/Example6.java @@ -0,0 +1,132 @@ +package com.garciat.typeclasses.examples; + +import static com.garciat.typeclasses.TypeClasses.witness; +import static org.assertj.core.api.Assertions.assertThat; + +import com.garciat.typeclasses.api.Out; +import com.garciat.typeclasses.api.Ty; +import com.garciat.typeclasses.api.TypeClass; +import com.garciat.typeclasses.impl.utils.Unit; +import org.junit.jupiter.api.Test; + +public class Example6 { + @Test + void test1() { + var expr = new Expr.Add<>(new Expr.Int(1), new Expr.Add<>(new Expr.Int(2), new Expr.Int(3))); + + boolean result = ReifiedContainsVoid.containsVoid(witness(new Ty<>() {}), expr); + + assertThat(result).isFalse(); + } + + @Test + void test2() { + var expr = new Expr.Add<>(new Expr.Int(1), new Expr.Add<>(new Expr.Void(), new Expr.Int(3))); + + boolean result = ReifiedContainsVoid.containsVoid(witness(new Ty<>() {}), expr); + + assertThat(result).isTrue(); + } + + public sealed interface Fact { + record True() implements Fact {} + + record False() implements Fact {} + } + + @TypeClass + public interface FactOr { + Unit trivial(); + + @TypeClass.Witness(overlap = TypeClass.Witness.Overlap.OVERLAPPING) + static FactOr here() { + return Unit::unit; + } + + @TypeClass.Witness + static FactOr notHere() { + return Unit::unit; + } + } + + @TypeClass + public interface FactNot { + Unit trivial(); + + @TypeClass.Witness + static FactNot factNotTrue() { + return Unit::unit; + } + + @TypeClass.Witness + static FactNot factNotFalse() { + return Unit::unit; + } + } + + @TypeClass + public interface ReifiedFact { + boolean reify(); + + @TypeClass.Witness + static ReifiedFact reifiedTrue() { + return () -> true; + } + + @TypeClass.Witness + static ReifiedFact reifiedFalse() { + return () -> false; + } + } + + public sealed interface Expr> { + record Void() implements Expr {} + + record Int(int value) implements Expr {} + + record Add, T2 extends Expr>(Expr left, Expr right) + implements Expr> {} + } + + @TypeClass + public interface ContainsVoid, @Out R> { + Unit trivial(); + + @TypeClass.Witness + static ContainsVoid here() { + return Unit::unit; + } + + @TypeClass.Witness + static ContainsVoid notHereInt() { + return Unit::unit; + } + + @TypeClass.Witness + static < + T1 extends Expr, + T2 extends Expr, + FL extends Fact, + FR extends Fact, + F extends Fact> + ContainsVoid, F> add( + ContainsVoid left, ContainsVoid right, FactOr factOr) { + return Unit::unit; + } + } + + @TypeClass + public interface ReifiedContainsVoid> { + boolean reify(); + + static > boolean containsVoid(ReifiedContainsVoid containsVoid, E ignore) { + return containsVoid.reify(); + } + + @TypeClass.Witness + static , F> ReifiedContainsVoid reifiedHere( + ContainsVoid here, ReifiedFact fact) { + return fact::reify; + } + } +} diff --git a/core/src/test/java/com/garciat/typeclasses/processor/WitnessResolutionCheckerTest.java b/core/src/test/java/com/garciat/typeclasses/processor/WitnessResolutionCheckerTest.java index 0a04c98..8db23b2 100644 --- a/core/src/test/java/com/garciat/typeclasses/processor/WitnessResolutionCheckerTest.java +++ b/core/src/test/java/com/garciat/typeclasses/processor/WitnessResolutionCheckerTest.java @@ -20,7 +20,15 @@ public class WitnessResolutionCheckerTest { @Nullable @TempDir Path tempDir; @ParameterizedTest - @ValueSource(strings = {"Example1.java", "Example2.java"}) + @ValueSource( + strings = { + "Example1.java", + "Example2.java", + "Example3.java", + "Example4.java", + "Example5.java", + "Example6.java", + }) public void checkExample(String fileName) throws IOException { requireNonNull(tempDir); @@ -34,7 +42,7 @@ public void checkExample(String fileName) throws IOException { fileManager.setLocation(StandardLocation.SOURCE_OUTPUT, List.of(tempDir.toFile())); var files = new java.util.ArrayList(); - files.add(new File("src/test/java/com/garciat/typeclasses/processor/examples/" + fileName)); + files.add(new File("src/test/java/com/garciat/typeclasses/examples/" + fileName)); var compilationUnits = fileManager.getJavaFileObjectsFromFiles(files); diff --git a/core/src/test/java/com/garciat/typeclasses/processor/examples/Example1.java b/core/src/test/java/com/garciat/typeclasses/processor/examples/Example1.java deleted file mode 100644 index 4c29c65..0000000 --- a/core/src/test/java/com/garciat/typeclasses/processor/examples/Example1.java +++ /dev/null @@ -1,44 +0,0 @@ -package com.garciat.typeclasses.processor.examples; - -import static com.garciat.typeclasses.TypeClasses.witness; - -import com.garciat.typeclasses.api.Ty; -import com.garciat.typeclasses.api.TypeClass; -import java.util.List; -import java.util.stream.Collectors; - -@TypeClass -interface Show { - String show(T value); - - static String show(Show showT, T value) { - return showT.show(value); - } - - @TypeClass.Witness - static Show integerShow() { - return i -> Integer.toString(i); - } - - @TypeClass.Witness - static Show> listShow(Show showA) { - return listA -> listA.stream().map(showA::show).collect(Collectors.joining(", ", "[", "]")); - } -} - -record Pair(A first, B second) { - @TypeClass.Witness - public static Show> pairShow(Show showA, Show showB) { - return pair -> "(" + showA.show(pair.first()) + ", " + showB.show(pair.second()) + ")"; - } -} - -public class Example1 { - void main() { - Pair> value = new Pair<>(1, List.of(2, 3, 4)); - - String s = Show.show(witness(new Ty<>() {}), value); - - System.out.println(s); - } -} diff --git a/core/src/test/java/com/garciat/typeclasses/processor/examples/Example2.java b/core/src/test/java/com/garciat/typeclasses/processor/examples/Example2.java deleted file mode 100644 index 2eeda7e..0000000 --- a/core/src/test/java/com/garciat/typeclasses/processor/examples/Example2.java +++ /dev/null @@ -1,45 +0,0 @@ -package com.garciat.typeclasses.processor.examples; - -import static com.garciat.typeclasses.TypeClasses.witness; - -import com.garciat.typeclasses.api.Ty; -import com.garciat.typeclasses.api.TypeClass; -import com.garciat.typeclasses.processor.examples.TList.TCons; -import com.garciat.typeclasses.processor.examples.TList.TNil; - -interface TList> { - record TNil() implements TList {} - - record TCons>() implements TList> {} -} - -@TypeClass -interface In, Y> { - @TypeClass.Witness - static > In, X> here() { - return new In<>() {}; - } - - @TypeClass.Witness(overlap = TypeClass.Witness.Overlap.OVERLAPPABLE) - static , Y> In, Y> there(In there) { - return new In<>() {}; - } -} - -/// Based on: -/// -/// ```haskell -/// class In (xs :: [k]) (x :: k) -/// -/// instance In (x ': xs) x -/// -/// instance {-# OVERLAPPABLE #-} In xs y => In (x ': xs) y -/// -/// example :: In '[Int, Bool, Char] Bool => () -/// example = () -/// ``` -public class Example2 { - void main() { - In>>, Short> _ = witness(new Ty<>() {}); - } -}