package dev.tommyjs.futur.promise; import dev.tommyjs.futur.executor.PromiseExecutor; import dev.tommyjs.futur.function.ExceptionalConsumer; import dev.tommyjs.futur.function.ExceptionalFunction; import dev.tommyjs.futur.function.ExceptionalRunnable; import dev.tommyjs.futur.function.ExceptionalSupplier; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; import org.slf4j.Logger; import java.util.Collection; import java.util.LinkedList; import java.util.Objects; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import java.util.function.Consumer; public abstract class AbstractPromise implements Promise { private Collection> listeners; private final AtomicReference> completion; private final CountDownLatch latch; private final Lock lock; public AbstractPromise() { this.completion = new AtomicReference<>(); this.latch = new CountDownLatch(1); this.lock = new ReentrantLock(); } protected static void propagateResult(Promise from, Promise to) { from.addDirectListener(to::complete, to::completeExceptionally); } protected static void propagateCancel(Promise from, Promise to) { from.onCancel(to::completeExceptionally); } private @NotNull Runnable createRunnable(T result, @NotNull Promise promise, @NotNull ExceptionalFunction task) { return () -> { if (promise.isCompleted()) return; try { V nextResult = task.apply(result); promise.complete(nextResult); } catch (Throwable e) { promise.completeExceptionally(e); } }; } public abstract @NotNull AbstractPromiseFactory getFactory(); protected @NotNull PromiseExecutor getExecutor() { return getFactory().getExecutor(); } protected @NotNull Logger getLogger() { return getFactory().getLogger(); } @Override public T awaitInterruptibly() throws InterruptedException { this.latch.await(); return joinCompletion(Objects.requireNonNull(getCompletion())); } @Override public T awaitInterruptibly(long timeoutMillis) throws TimeoutException, InterruptedException { boolean success = this.latch.await(timeoutMillis, TimeUnit.MILLISECONDS); if (!success) { throw new TimeoutException("Promise stopped waiting after " + timeoutMillis + "ms"); } return joinCompletion(Objects.requireNonNull(getCompletion())); } @Override public T await() { try { return awaitInterruptibly(); } catch (InterruptedException e) { throw new RuntimeException(e); } } @Override public T await(long timeoutMillis) throws TimeoutException { try { return awaitInterruptibly(timeoutMillis); } catch (InterruptedException e) { throw new RuntimeException(e); } } private T joinCompletion(PromiseCompletion completion) { if (completion.isError()) throw new RuntimeException(completion.getException()); return completion.getResult(); } @Override public @NotNull Promise thenRunSync(@NotNull ExceptionalRunnable task) { return thenApplySync(result -> { task.run(); return null; }); } @Override public @NotNull Promise thenRunDelayedSync(@NotNull ExceptionalRunnable task, long delay, @NotNull TimeUnit unit) { return thenApplyDelayedSync(result -> { task.run(); return null; }, delay, unit); } @Override public @NotNull Promise thenConsumeSync(@NotNull ExceptionalConsumer task) { return thenApplySync(result -> { task.accept(result); return null; }); } @Override public @NotNull Promise thenConsumeDelayedSync(@NotNull ExceptionalConsumer task, long delay, @NotNull TimeUnit unit) { return thenApplyDelayedSync(result -> { task.accept(result); return null; }, delay, unit); } @Override public @NotNull Promise thenSupplySync(@NotNull ExceptionalSupplier task) { return thenApplySync(result -> task.get()); } @Override public @NotNull Promise thenSupplyDelayedSync(@NotNull ExceptionalSupplier task, long delay, @NotNull TimeUnit unit) { return thenApplyDelayedSync(result -> task.get(), delay, unit); } @Override public @NotNull Promise thenApplySync(@NotNull ExceptionalFunction task) { Promise promise = getFactory().unresolved(); addDirectListener( res -> { try { Runnable runnable = createRunnable(res, promise, task); F future = getExecutor().runSync(runnable); promise.onCancel((e) -> getExecutor().cancel(future)); } catch (RejectedExecutionException e) { promise.completeExceptionally(e); } }, promise::completeExceptionally ); propagateCancel(promise, this); return promise; } @Override public @NotNull Promise thenApplyDelayedSync(@NotNull ExceptionalFunction task, long delay, @NotNull TimeUnit unit) { Promise promise = getFactory().unresolved(); addDirectListener( res -> { try { Runnable runnable = createRunnable(res, promise, task); F future = getExecutor().runSync(runnable, delay, unit); promise.onCancel((e) -> getExecutor().cancel(future)); } catch (RejectedExecutionException e) { promise.completeExceptionally(e); } }, promise::completeExceptionally ); propagateCancel(promise, this); return promise; } @Override public @NotNull Promise thenComposeSync(@NotNull ExceptionalFunction> task) { Promise promise = getFactory().unresolved(); thenApplySync(task).addDirectListener( nestedPromise -> { if (nestedPromise == null) { promise.complete(null); } else { propagateResult(nestedPromise, promise); propagateCancel(promise, nestedPromise); } }, promise::completeExceptionally ); propagateCancel(promise, this); return promise; } @Override public @NotNull Promise thenRunAsync(@NotNull ExceptionalRunnable task) { return thenApplyAsync(result -> { task.run(); return null; }); } @Override public @NotNull Promise thenRunDelayedAsync(@NotNull ExceptionalRunnable task, long delay, @NotNull TimeUnit unit) { return thenApplyDelayedAsync(result -> { task.run(); return null; }, delay, unit); } @Override public @NotNull Promise thenConsumeAsync(@NotNull ExceptionalConsumer task) { return thenApplyAsync(result -> { task.accept(result); return null; }); } @Override public @NotNull Promise thenConsumeDelayedAsync(@NotNull ExceptionalConsumer task, long delay, @NotNull TimeUnit unit) { return thenApplyDelayedAsync(result -> { task.accept(result); return null; }, delay, unit); } @Override public @NotNull Promise thenSupplyAsync(@NotNull ExceptionalSupplier task) { return thenApplyAsync(result -> task.get()); } @Override public @NotNull Promise thenSupplyDelayedAsync(@NotNull ExceptionalSupplier task, long delay, @NotNull TimeUnit unit) { return thenApplyDelayedAsync(result -> task.get(), delay, unit); } @Override public @NotNull Promise thenPopulateReference(@NotNull AtomicReference reference) { return thenApplyAsync((result) -> { reference.set(result); return result; }); } @Override public @NotNull Promise thenApplyAsync(@NotNull ExceptionalFunction task) { Promise promise = getFactory().unresolved(); addDirectListener( (res) -> { try { Runnable runnable = createRunnable(res, promise, task); F future = getExecutor().runAsync(runnable); promise.onCancel((e) -> getExecutor().cancel(future)); } catch (RejectedExecutionException e) { promise.completeExceptionally(e); } }, promise::completeExceptionally ); propagateCancel(promise, this); return promise; } @Override public @NotNull Promise thenApplyDelayedAsync(@NotNull ExceptionalFunction task, long delay, @NotNull TimeUnit unit) { Promise promise = getFactory().unresolved(); addDirectListener( res -> { try { Runnable runnable = createRunnable(res, promise, task); F future = getExecutor().runAsync(runnable, delay, unit); promise.onCancel((e) -> getExecutor().cancel(future)); } catch (RejectedExecutionException e) { promise.completeExceptionally(e); } }, promise::completeExceptionally ); propagateCancel(promise, this); return promise; } @Override public @NotNull Promise thenComposeAsync(@NotNull ExceptionalFunction> task) { Promise promise = getFactory().unresolved(); thenApplyAsync(task).addDirectListener( nestedPromise -> { if (nestedPromise == null) { promise.complete(null); } else { propagateResult(nestedPromise, promise); propagateCancel(promise, nestedPromise); } }, promise::completeExceptionally ); propagateCancel(promise, this); return promise; } @Override public @NotNull Promise erase() { return thenSupplyAsync(() -> null); } @Override public @NotNull Promise addAsyncListener(@NotNull AsyncPromiseListener listener) { return addAnyListener(listener); } @Override public @NotNull Promise addAsyncListener(@Nullable Consumer successListener, @Nullable Consumer errorListener) { return addAsyncListener((res) -> { if (res.isError()) { if (errorListener != null) errorListener.accept(res.getException()); } else { if (successListener != null) successListener.accept(res.getResult()); } }); } @Override public @NotNull Promise addDirectListener(@NotNull PromiseListener listener) { return addAnyListener(listener); } @Override public @NotNull Promise addDirectListener(@Nullable Consumer successListener, @Nullable Consumer errorListener) { return addDirectListener((res) -> { if (res.isError()) { if (errorListener != null) errorListener.accept(res.getException()); } else { if (successListener != null) successListener.accept(res.getResult()); } }); } private @NotNull Promise addAnyListener(PromiseListener listener) { PromiseCompletion completion; lock.lock(); try { completion = getCompletion(); if (completion == null) { if (listeners == null) listeners = new LinkedList<>(); listeners.add(listener); return this; } } finally { lock.unlock(); } callListener(listener, completion); return this; } private void callListener(PromiseListener listener, PromiseCompletion ctx) { if (listener instanceof AsyncPromiseListener) { try { getExecutor().runAsync(() -> callListenerNow(listener, ctx)); } catch (RejectedExecutionException ignored) { } } else { callListenerNow(listener, ctx); } } private void callListenerNow(PromiseListener listener, PromiseCompletion ctx) { try { listener.handle(ctx); } catch (Exception e) { getLogger().error("Exception caught in promise listener", e); } } @Override public @NotNull Promise onSuccess(@NotNull Consumer listener) { return addAsyncListener(listener, null); } @Override public @NotNull Promise onError(@NotNull Consumer listener) { return addAsyncListener(null, listener); } @Override public @NotNull Promise logExceptions(@NotNull String message) { return onError(e -> getLogger().error(message, e)); } @Override public @NotNull Promise onError(@NotNull Class clazz, @NotNull Consumer listener) { return onError((e) -> { if (clazz.isAssignableFrom(e.getClass())) { //noinspection unchecked listener.accept((E) e); } }); } @Override public @NotNull Promise onCancel(@NotNull Consumer listener) { return onError(CancellationException.class, listener); } @Deprecated @Override public @NotNull Promise timeout(long time, @NotNull TimeUnit unit) { return maxWaitTime(time, unit); } @Override public @NotNull Promise maxWaitTime(long time, @NotNull TimeUnit unit) { try { Exception e = new TimeoutException("Promise stopped waiting after " + time + " " + unit); F future = getExecutor().runAsync(() -> completeExceptionally(e), time, unit); return addDirectListener((_v) -> getExecutor().cancel(future)); } catch (RejectedExecutionException e) { completeExceptionally(e); return this; } } private void handleCompletion(@NotNull PromiseCompletion ctx) { lock.lock(); try { if (!setCompletion(ctx)) return; this.latch.countDown(); if (listeners != null) { for (PromiseListener listener : listeners) { callListener(listener, ctx); } } } finally { lock.unlock(); } } private boolean setCompletion(PromiseCompletion completion) { return this.completion.compareAndSet(null, completion); } @Override public void cancel(@Nullable String message) { completeExceptionally(new CancellationException(message)); } @Override public void complete(@Nullable T result) { handleCompletion(new PromiseCompletion<>(result)); } @Override public void completeExceptionally(@NotNull Throwable result) { handleCompletion(new PromiseCompletion<>(result)); } @Override public boolean isCompleted() { return completion.get() != null; } @Override public @Nullable PromiseCompletion getCompletion() { return completion.get(); } @Override public @NotNull CompletableFuture toFuture() { CompletableFuture future = new CompletableFuture<>(); this.addDirectListener(future::complete, future::completeExceptionally); future.whenComplete((res, e) -> { if (e instanceof CancellationException) { this.cancel(); } }); return future; } }