diff --git a/futur-api/src/main/java/dev/tommyjs/futur/promise/AbstractPromise.java b/futur-api/src/main/java/dev/tommyjs/futur/promise/AbstractPromise.java index aa64c85..149df3c 100644 --- a/futur-api/src/main/java/dev/tommyjs/futur/promise/AbstractPromise.java +++ b/futur-api/src/main/java/dev/tommyjs/futur/promise/AbstractPromise.java @@ -12,6 +12,8 @@ import org.slf4j.Logger; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; public abstract class AbstractPromise implements Promise { @@ -31,17 +33,40 @@ public abstract class AbstractPromise implements Promise { } } - protected void runCompleter(@NotNull CompletablePromise promise, @NotNull ExceptionalRunnable completer) { + protected V supplySafe(@NotNull ExceptionalSupplier supplier, @NotNull Function handler) { try { - completer.run(); - } catch (Error e) { - promise.completeExceptionally(e); - throw e; + return supplier.get(); + } catch (Error error) { + // Rethrow error so the Thread can shut down + throw error; } catch (Throwable e) { - promise.completeExceptionally(e); + return handler.apply(e); } } + protected void runSafe(@NotNull ExceptionalRunnable runnable, @NotNull Consumer handler) { + try { + runnable.run(); + } catch (Error error) { + handler.accept(error); + // Rethrow error so the Thread can shut down + throw error; + } catch (Throwable e) { + handler.accept(e); + } + } + + protected void runCompleter(@NotNull CompletablePromise promise, @NotNull ExceptionalRunnable completer) { + runSafe(completer, promise::completeExceptionally); + } + + protected V useCompletion(Supplier unresolved, Function completed, Function failed) { + PromiseCompletion completion = getCompletion(); + if (completion == null) return unresolved.get(); + else if (completion.isSuccess()) return completed.apply(completion.getResult()); + else return failed.apply(completion.getException()); + } + protected @NotNull Runnable createCompleter(T result, @NotNull CompletablePromise promise, @NotNull ExceptionalFunction completer) { return () -> { @@ -67,14 +92,7 @@ public abstract class AbstractPromise implements Promise { } protected void callListenerNow(PromiseListener listener, PromiseCompletion res) { - try { - listener.handle(res); - } catch (Error e) { - getLogger().error("Error caught in promise listener", e); - throw e; - } catch (Throwable e) { - getLogger().error("Exception caught in promise listener", e); - } + runSafe(() -> listener.handle(res), e -> getLogger().error("Exception caught in promise listener", e)); } protected void callListenerAsyncLastResort(PromiseListener listener, PromiseCompletion completion) { @@ -84,16 +102,27 @@ public abstract class AbstractPromise implements Promise { } } + protected T joinCompletionChecked() throws ExecutionException { + PromiseCompletion completion = getCompletion(); + assert completion != null; + if (completion.isSuccess()) return completion.getResult(); + throw new ExecutionException(completion.getException()); + } + + protected T joinCompletionUnchecked() { + PromiseCompletion completion = getCompletion(); + assert completion != null; + if (completion.isSuccess()) return completion.getResult(); + throw new CompletionException(completion.getException()); + } + @Override public @NotNull Promise fork() { - PromiseCompletion completion = getCompletion(); - if (completion == null) { - CompletablePromise fork = getFactory().unresolved(); - PromiseUtil.propagateCompletion(this, fork); - return fork; - } else { - return this; - } + if (isCompleted()) return this; + + CompletablePromise fork = getFactory().unresolved(); + PromiseUtil.propagateCompletion(this, fork); + return fork; } @Override @@ -119,68 +148,61 @@ public abstract class AbstractPromise implements Promise { @Override public @NotNull Promise thenApply(@NotNull ExceptionalFunction task) { - PromiseCompletion completion = getCompletion(); - if (completion == null) { - CompletablePromise promise = createLinked(); - addDirectListener( - res -> createCompleter(res, promise, task).run(), - promise::completeExceptionally - ); + return useCompletion( + () -> { + CompletablePromise promise = createLinked(); + addDirectListener( + res -> createCompleter(res, promise, task).run(), + promise::completeExceptionally + ); - return promise; - } else if (completion.isSuccess()) { - try { - V result = task.apply(completion.getResult()); - return getFactory().resolve(result); - } catch (Exception e) { - return getFactory().error(e); - } - } else { - Throwable ex = completion.getException(); - assert ex != null; - return getFactory().error(ex); - } + return promise; + }, + result -> supplySafe( + () -> getFactory().resolve(task.apply(result)), + getFactory()::error + ), + getFactory()::error + ); } @Override public @NotNull Promise thenCompose(@NotNull ExceptionalFunction> task) { - PromiseCompletion completion = getCompletion(); - if (completion == null) { - CompletablePromise promise = createLinked(); - thenApply(task).addDirectListener( - result -> { - if (result == null) { - promise.complete(null); + return useCompletion( + () -> { + CompletablePromise promise = createLinked(); + thenApply(task).addDirectListener( + result -> { + if (result == null) { + promise.complete(null); + } else { + PromiseUtil.propagateCompletion(result, promise); + PromiseUtil.propagateCancel(promise, result); + } + }, + promise::completeExceptionally + ); + + return promise; + }, + result -> supplySafe( + () -> { + Promise nested = task.apply(result); + if (nested == null) { + return getFactory().resolve(null); + } else if (nested.isCompleted()) { + return nested; } else { - PromiseUtil.propagateCompletion(result, promise); - PromiseUtil.propagateCancel(promise, result); + CompletablePromise promise = createLinked(); + PromiseUtil.propagateCompletion(nested, promise); + PromiseUtil.propagateCancel(promise, nested); + return promise; } }, - promise::completeExceptionally - ); - - return promise; - } else if (completion.isSuccess()) { - try { - Promise result = task.apply(completion.getResult()); - if (result == null) { - return getFactory().resolve(null); - } else if (result.isCompleted()) { - return result; - } else { - CompletablePromise promise = createLinked(); - PromiseUtil.propagateCompletion(result, promise); - PromiseUtil.propagateCancel(promise, result); - return promise; - } - } catch (Exception e) { - return getFactory().error(e); - } - } else { - Throwable ex = completion.getException(); - assert ex != null; - return getFactory().error(ex); - } + getFactory()::error + ), + getFactory()::error + ); } @Override @@ -451,36 +473,38 @@ public abstract class AbstractPromise implements Promise { @Override public @NotNull Promise orDefault(@NotNull ExceptionalFunction function) { - PromiseCompletion completion = getCompletion(); - if (completion == null) { - CompletablePromise promise = createLinked(); - addDirectListener(promise::complete, e -> runCompleter(promise, () -> promise.complete(function.apply(e)))); - return promise; - } else if (completion.isSuccess()) { - return getFactory().resolve(completion.getResult()); - } else { - try { - return getFactory().resolve(function.apply(completion.getException())); - } catch (Exception e) { - return getFactory().error(e); - } - } + return useCompletion( + () -> { + CompletablePromise promise = createLinked(); + addDirectListener(promise::complete, e -> runCompleter(promise, () -> promise.complete(function.apply(e)))); + return promise; + }, + getFactory()::resolve, + getFactory()::error + ); } @Override public @NotNull CompletableFuture toFuture() { - CompletableFuture future = new CompletableFuture<>(); - addDirectListener(future::complete, future::completeExceptionally); - future.whenComplete((_, e) -> { - if (e instanceof CancellationException) { - cancel(); - } - }); + return useCompletion( + () -> { + CompletableFuture future = new CompletableFuture<>(); + addDirectListener(future::complete, future::completeExceptionally); + future.whenComplete((_, e) -> { + if (e instanceof CancellationException) { + cancel(); + } + }); - return future; + return future; + }, + CompletableFuture::completedFuture, + CompletableFuture::failedFuture + ); } private static class DeferredExecutionException extends ExecutionException { + } } diff --git a/futur-api/src/main/java/dev/tommyjs/futur/promise/BasePromise.java b/futur-api/src/main/java/dev/tommyjs/futur/promise/BasePromise.java index f2a1a67..45c8c48 100644 --- a/futur-api/src/main/java/dev/tommyjs/futur/promise/BasePromise.java +++ b/futur-api/src/main/java/dev/tommyjs/futur/promise/BasePromise.java @@ -42,12 +42,6 @@ public abstract class BasePromise extends AbstractPromise this.listeners = Collections.EMPTY_LIST; } - protected T joinCompletion() throws ExecutionException { - PromiseCompletion completion = Objects.requireNonNull(getCompletion()); - if (completion.isSuccess()) return completion.getResult(); - throw new ExecutionException(completion.getException()); - } - protected void handleCompletion(@NotNull PromiseCompletion cmp) { if (!COMPLETION_HANDLE.compareAndSet(this, null, cmp)) return; sync.releaseShared(1); @@ -99,31 +93,36 @@ public abstract class BasePromise extends AbstractPromise @Override public T get() throws InterruptedException, ExecutionException { - sync.acquireSharedInterruptibly(1); - return joinCompletion(); + if (!isCompleted()) { + sync.acquireSharedInterruptibly(1); + } + + return joinCompletionChecked(); } @Override public T get(long time, @NotNull TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { - boolean success = sync.tryAcquireSharedNanos(1, unit.toNanos(time)); - if (!success) { - throw new TimeoutException("Promise stopped waiting after " + time + " " + unit); + if (!isCompleted()) { + boolean success = sync.tryAcquireSharedNanos(1, unit.toNanos(time)); + if (!success) { + throw new TimeoutException("Promise stopped waiting after " + time + " " + unit); + } } - return joinCompletion(); + return joinCompletionChecked(); } @Override public T await() { - try { - sync.acquireSharedInterruptibly(1); - } catch (InterruptedException e) { - throw new RuntimeException(e); + if (!isCompleted()) { + try { + sync.acquireSharedInterruptibly(1); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } } - PromiseCompletion completion = Objects.requireNonNull(getCompletion()); - if (completion.isSuccess()) return completion.getResult(); - throw new CompletionException(completion.getException()); + return joinCompletionUnchecked(); } @Override diff --git a/futur-api/src/main/java/dev/tommyjs/futur/promise/CompletedPromise.java b/futur-api/src/main/java/dev/tommyjs/futur/promise/CompletedPromise.java index ba91627..d13fa2c 100644 --- a/futur-api/src/main/java/dev/tommyjs/futur/promise/CompletedPromise.java +++ b/futur-api/src/main/java/dev/tommyjs/futur/promise/CompletedPromise.java @@ -3,6 +3,7 @@ package dev.tommyjs.futur.promise; import org.jetbrains.annotations.NotNull; import java.util.concurrent.CancellationException; +import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; public abstract class CompletedPromise extends AbstractPromise { @@ -28,36 +29,34 @@ public abstract class CompletedPromise extends AbstractPromise timeout(long time, @NotNull TimeUnit unit) { + // Promise is already completed so can't time out return this; } @Override public @NotNull Promise maxWaitTime(long time, @NotNull TimeUnit unit) { + // Promise is already completed so can't time out return this; } @Override public void cancel(@NotNull CancellationException exception) { + // Promise is already completed so can't be cancelled } @Override - public T get() { - return null; + public T get() throws ExecutionException { + return joinCompletionChecked(); } @Override - public T get(long timeout, @NotNull TimeUnit unit) { - return null; + public T get(long timeout, @NotNull TimeUnit unit) throws ExecutionException { + return joinCompletionChecked(); } @Override public T await() { - return null; - } - - @Override - public @NotNull Promise fork() { - return this; + return joinCompletionUnchecked(); } @Override