cleanup AbstractPromise (still messy)

This commit is contained in:
WhatCats
2025-01-10 22:14:59 +01:00
parent 4d01a8a418
commit 5447e06455
3 changed files with 149 additions and 127 deletions

View File

@@ -12,6 +12,8 @@ import org.slf4j.Logger;
import java.util.concurrent.*; import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
public abstract class AbstractPromise<T, FS, FA> implements Promise<T> { public abstract class AbstractPromise<T, FS, FA> implements Promise<T> {
@@ -31,17 +33,40 @@ public abstract class AbstractPromise<T, FS, FA> implements Promise<T> {
} }
} }
protected void runCompleter(@NotNull CompletablePromise<?> promise, @NotNull ExceptionalRunnable completer) { protected <V> V supplySafe(@NotNull ExceptionalSupplier<V> supplier, @NotNull Function<Throwable, V> handler) {
try { try {
completer.run(); return supplier.get();
} catch (Error e) { } catch (Error error) {
promise.completeExceptionally(e); // Rethrow error so the Thread can shut down
throw e; throw error;
} catch (Throwable e) { } catch (Throwable e) {
promise.completeExceptionally(e); return handler.apply(e);
} }
} }
protected void runSafe(@NotNull ExceptionalRunnable runnable, @NotNull Consumer<Throwable> handler) {
try {
runnable.run();
} catch (Error error) {
handler.accept(error);
// Rethrow error so the Thread can shut down
throw error;
} catch (Throwable e) {
handler.accept(e);
}
}
protected void runCompleter(@NotNull CompletablePromise<?> promise, @NotNull ExceptionalRunnable completer) {
runSafe(completer, promise::completeExceptionally);
}
protected <V> V useCompletion(Supplier<V> unresolved, Function<T, V> completed, Function<Throwable, V> failed) {
PromiseCompletion<T> completion = getCompletion();
if (completion == null) return unresolved.get();
else if (completion.isSuccess()) return completed.apply(completion.getResult());
else return failed.apply(completion.getException());
}
protected <V> @NotNull Runnable createCompleter(T result, @NotNull CompletablePromise<V> promise, protected <V> @NotNull Runnable createCompleter(T result, @NotNull CompletablePromise<V> promise,
@NotNull ExceptionalFunction<T, V> completer) { @NotNull ExceptionalFunction<T, V> completer) {
return () -> { return () -> {
@@ -67,14 +92,7 @@ public abstract class AbstractPromise<T, FS, FA> implements Promise<T> {
} }
protected void callListenerNow(PromiseListener<T> listener, PromiseCompletion<T> res) { protected void callListenerNow(PromiseListener<T> listener, PromiseCompletion<T> res) {
try { runSafe(() -> listener.handle(res), e -> getLogger().error("Exception caught in promise listener", e));
listener.handle(res);
} catch (Error e) {
getLogger().error("Error caught in promise listener", e);
throw e;
} catch (Throwable e) {
getLogger().error("Exception caught in promise listener", e);
}
} }
protected void callListenerAsyncLastResort(PromiseListener<T> listener, PromiseCompletion<T> completion) { protected void callListenerAsyncLastResort(PromiseListener<T> listener, PromiseCompletion<T> completion) {
@@ -84,16 +102,27 @@ public abstract class AbstractPromise<T, FS, FA> implements Promise<T> {
} }
} }
protected T joinCompletionChecked() throws ExecutionException {
PromiseCompletion<T> completion = getCompletion();
assert completion != null;
if (completion.isSuccess()) return completion.getResult();
throw new ExecutionException(completion.getException());
}
protected T joinCompletionUnchecked() {
PromiseCompletion<T> completion = getCompletion();
assert completion != null;
if (completion.isSuccess()) return completion.getResult();
throw new CompletionException(completion.getException());
}
@Override @Override
public @NotNull Promise<T> fork() { public @NotNull Promise<T> fork() {
PromiseCompletion<T> completion = getCompletion(); if (isCompleted()) return this;
if (completion == null) {
CompletablePromise<T> fork = getFactory().unresolved(); CompletablePromise<T> fork = getFactory().unresolved();
PromiseUtil.propagateCompletion(this, fork); PromiseUtil.propagateCompletion(this, fork);
return fork; return fork;
} else {
return this;
}
} }
@Override @Override
@@ -119,68 +148,61 @@ public abstract class AbstractPromise<T, FS, FA> implements Promise<T> {
@Override @Override
public <V> @NotNull Promise<V> thenApply(@NotNull ExceptionalFunction<T, V> task) { public <V> @NotNull Promise<V> thenApply(@NotNull ExceptionalFunction<T, V> task) {
PromiseCompletion<T> completion = getCompletion(); return useCompletion(
if (completion == null) { () -> {
CompletablePromise<V> promise = createLinked(); CompletablePromise<V> promise = createLinked();
addDirectListener( addDirectListener(
res -> createCompleter(res, promise, task).run(), res -> createCompleter(res, promise, task).run(),
promise::completeExceptionally promise::completeExceptionally
); );
return promise; return promise;
} else if (completion.isSuccess()) { },
try { result -> supplySafe(
V result = task.apply(completion.getResult()); () -> getFactory().resolve(task.apply(result)),
return getFactory().resolve(result); getFactory()::error
} catch (Exception e) { ),
return getFactory().error(e); getFactory()::error
} );
} else {
Throwable ex = completion.getException();
assert ex != null;
return getFactory().error(ex);
}
} }
@Override @Override
public <V> @NotNull Promise<V> thenCompose(@NotNull ExceptionalFunction<T, Promise<V>> task) { public <V> @NotNull Promise<V> thenCompose(@NotNull ExceptionalFunction<T, Promise<V>> task) {
PromiseCompletion<T> completion = getCompletion(); return useCompletion(
if (completion == null) { () -> {
CompletablePromise<V> promise = createLinked(); CompletablePromise<V> promise = createLinked();
thenApply(task).addDirectListener( thenApply(task).addDirectListener(
result -> { result -> {
if (result == null) { if (result == null) {
promise.complete(null); promise.complete(null);
} else {
PromiseUtil.propagateCompletion(result, promise);
PromiseUtil.propagateCancel(promise, result);
}
},
promise::completeExceptionally
);
return promise;
},
result -> supplySafe(
() -> {
Promise<V> nested = task.apply(result);
if (nested == null) {
return getFactory().resolve(null);
} else if (nested.isCompleted()) {
return nested;
} else { } else {
PromiseUtil.propagateCompletion(result, promise); CompletablePromise<V> promise = createLinked();
PromiseUtil.propagateCancel(promise, result); PromiseUtil.propagateCompletion(nested, promise);
PromiseUtil.propagateCancel(promise, nested);
return promise;
} }
}, },
promise::completeExceptionally getFactory()::error
); ),
getFactory()::error
return promise; );
} else if (completion.isSuccess()) {
try {
Promise<V> result = task.apply(completion.getResult());
if (result == null) {
return getFactory().resolve(null);
} else if (result.isCompleted()) {
return result;
} else {
CompletablePromise<V> promise = createLinked();
PromiseUtil.propagateCompletion(result, promise);
PromiseUtil.propagateCancel(promise, result);
return promise;
}
} catch (Exception e) {
return getFactory().error(e);
}
} else {
Throwable ex = completion.getException();
assert ex != null;
return getFactory().error(ex);
}
} }
@Override @Override
@@ -451,36 +473,38 @@ public abstract class AbstractPromise<T, FS, FA> implements Promise<T> {
@Override @Override
public @NotNull Promise<T> orDefault(@NotNull ExceptionalFunction<Throwable, T> function) { public @NotNull Promise<T> orDefault(@NotNull ExceptionalFunction<Throwable, T> function) {
PromiseCompletion<T> completion = getCompletion(); return useCompletion(
if (completion == null) { () -> {
CompletablePromise<T> promise = createLinked(); CompletablePromise<T> promise = createLinked();
addDirectListener(promise::complete, e -> runCompleter(promise, () -> promise.complete(function.apply(e)))); addDirectListener(promise::complete, e -> runCompleter(promise, () -> promise.complete(function.apply(e))));
return promise; return promise;
} else if (completion.isSuccess()) { },
return getFactory().resolve(completion.getResult()); getFactory()::resolve,
} else { getFactory()::error
try { );
return getFactory().resolve(function.apply(completion.getException()));
} catch (Exception e) {
return getFactory().error(e);
}
}
} }
@Override @Override
public @NotNull CompletableFuture<T> toFuture() { public @NotNull CompletableFuture<T> toFuture() {
CompletableFuture<T> future = new CompletableFuture<>(); return useCompletion(
addDirectListener(future::complete, future::completeExceptionally); () -> {
future.whenComplete((_, e) -> { CompletableFuture<T> future = new CompletableFuture<>();
if (e instanceof CancellationException) { addDirectListener(future::complete, future::completeExceptionally);
cancel(); future.whenComplete((_, e) -> {
} if (e instanceof CancellationException) {
}); cancel();
}
});
return future; return future;
},
CompletableFuture::completedFuture,
CompletableFuture::failedFuture
);
} }
private static class DeferredExecutionException extends ExecutionException { private static class DeferredExecutionException extends ExecutionException {
} }
} }

View File

@@ -42,12 +42,6 @@ public abstract class BasePromise<T, FS, FA> extends AbstractPromise<T, FS, FA>
this.listeners = Collections.EMPTY_LIST; this.listeners = Collections.EMPTY_LIST;
} }
protected T joinCompletion() throws ExecutionException {
PromiseCompletion<T> completion = Objects.requireNonNull(getCompletion());
if (completion.isSuccess()) return completion.getResult();
throw new ExecutionException(completion.getException());
}
protected void handleCompletion(@NotNull PromiseCompletion<T> cmp) { protected void handleCompletion(@NotNull PromiseCompletion<T> cmp) {
if (!COMPLETION_HANDLE.compareAndSet(this, null, cmp)) return; if (!COMPLETION_HANDLE.compareAndSet(this, null, cmp)) return;
sync.releaseShared(1); sync.releaseShared(1);
@@ -99,31 +93,36 @@ public abstract class BasePromise<T, FS, FA> extends AbstractPromise<T, FS, FA>
@Override @Override
public T get() throws InterruptedException, ExecutionException { public T get() throws InterruptedException, ExecutionException {
sync.acquireSharedInterruptibly(1); if (!isCompleted()) {
return joinCompletion(); sync.acquireSharedInterruptibly(1);
}
return joinCompletionChecked();
} }
@Override @Override
public T get(long time, @NotNull TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { public T get(long time, @NotNull TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
boolean success = sync.tryAcquireSharedNanos(1, unit.toNanos(time)); if (!isCompleted()) {
if (!success) { boolean success = sync.tryAcquireSharedNanos(1, unit.toNanos(time));
throw new TimeoutException("Promise stopped waiting after " + time + " " + unit); if (!success) {
throw new TimeoutException("Promise stopped waiting after " + time + " " + unit);
}
} }
return joinCompletion(); return joinCompletionChecked();
} }
@Override @Override
public T await() { public T await() {
try { if (!isCompleted()) {
sync.acquireSharedInterruptibly(1); try {
} catch (InterruptedException e) { sync.acquireSharedInterruptibly(1);
throw new RuntimeException(e); } catch (InterruptedException e) {
throw new RuntimeException(e);
}
} }
PromiseCompletion<T> completion = Objects.requireNonNull(getCompletion()); return joinCompletionUnchecked();
if (completion.isSuccess()) return completion.getResult();
throw new CompletionException(completion.getException());
} }
@Override @Override

View File

@@ -3,6 +3,7 @@ package dev.tommyjs.futur.promise;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import java.util.concurrent.CancellationException; import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
public abstract class CompletedPromise<T, FS, FA> extends AbstractPromise<T, FS, FA> { public abstract class CompletedPromise<T, FS, FA> extends AbstractPromise<T, FS, FA> {
@@ -28,36 +29,34 @@ public abstract class CompletedPromise<T, FS, FA> extends AbstractPromise<T, FS,
@Override @Override
public @NotNull Promise<T> timeout(long time, @NotNull TimeUnit unit) { public @NotNull Promise<T> timeout(long time, @NotNull TimeUnit unit) {
// Promise is already completed so can't time out
return this; return this;
} }
@Override @Override
public @NotNull Promise<T> maxWaitTime(long time, @NotNull TimeUnit unit) { public @NotNull Promise<T> maxWaitTime(long time, @NotNull TimeUnit unit) {
// Promise is already completed so can't time out
return this; return this;
} }
@Override @Override
public void cancel(@NotNull CancellationException exception) { public void cancel(@NotNull CancellationException exception) {
// Promise is already completed so can't be cancelled
} }
@Override @Override
public T get() { public T get() throws ExecutionException {
return null; return joinCompletionChecked();
} }
@Override @Override
public T get(long timeout, @NotNull TimeUnit unit) { public T get(long timeout, @NotNull TimeUnit unit) throws ExecutionException {
return null; return joinCompletionChecked();
} }
@Override @Override
public T await() { public T await() {
return null; return joinCompletionUnchecked();
}
@Override
public @NotNull Promise<T> fork() {
return this;
} }
@Override @Override