more optimized promise implementation for completed promises

This commit is contained in:
tommyskeff
2025-01-09 18:41:05 +00:00
parent 0eb9190621
commit ae15089b3d
8 changed files with 423 additions and 289 deletions

View File

@@ -9,47 +9,29 @@ import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.Objects;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.AbstractQueuedSynchronizer;
import java.util.function.Consumer;
@SuppressWarnings({"FieldMayBeFinal", "unchecked"})
public abstract class AbstractPromise<T, FS, FA> implements CompletablePromise<T> {
private static final VarHandle COMPLETION_HANDLE;
private static final VarHandle LISTENERS_HANDLE;
static {
try {
MethodHandles.Lookup lookup = MethodHandles.lookup();
COMPLETION_HANDLE = lookup.findVarHandle(AbstractPromise.class, "completion", PromiseCompletion.class);
LISTENERS_HANDLE = lookup.findVarHandle(AbstractPromise.class, "listeners", Collection.class);
} catch (ReflectiveOperationException e) {
throw new ExceptionInInitializerError(e);
}
}
private final Sync sync;
private volatile Collection<PromiseListener<T>> listeners;
private volatile PromiseCompletion<T> completion;
public AbstractPromise() {
this.sync = new Sync();
this.listeners = Collections.EMPTY_LIST;
this.completion = null;
}
public abstract class AbstractPromise<T, FS, FA> implements Promise<T> {
public abstract @NotNull AbstractPromiseFactory<FS, FA> getFactory();
private void runCompleter(@NotNull CompletablePromise<?> promise, @NotNull ExceptionalRunnable completer) {
protected abstract @NotNull Promise<T> addAnyListener(@NotNull PromiseListener<T> listener);
protected @NotNull Logger getLogger() {
return getFactory().getLogger();
}
protected void callListener(@NotNull PromiseListener<T> listener, @NotNull PromiseCompletion<T> cmp) {
if (listener instanceof AsyncPromiseListener) {
callListenerAsync(listener, cmp);
} else {
callListenerNow(listener, cmp);
}
}
protected void runCompleter(@NotNull CompletablePromise<?> promise, @NotNull ExceptionalRunnable completer) {
try {
completer.run();
} catch (Error e) {
@@ -60,11 +42,8 @@ public abstract class AbstractPromise<T, FS, FA> implements CompletablePromise<T
}
}
private <V> @NotNull Runnable createCompleter(
T result,
@NotNull CompletablePromise<V> promise,
@NotNull ExceptionalFunction<T, V> completer
) {
protected <V> @NotNull Runnable createCompleter(T result, @NotNull CompletablePromise<V> promise,
@NotNull ExceptionalFunction<T, V> completer) {
return () -> {
if (!promise.isCompleted()) {
runCompleter(promise, () -> promise.complete(completer.apply(result)));
@@ -72,43 +51,37 @@ public abstract class AbstractPromise<T, FS, FA> implements CompletablePromise<T
};
}
protected @NotNull Logger getLogger() {
return getFactory().getLogger();
protected <V> @NotNull CompletablePromise<V> createLinked() {
CompletablePromise<V> promise = getFactory().unresolved();
PromiseUtil.propagateCancel(promise, this);
return promise;
}
@Override
public T get() throws InterruptedException, ExecutionException {
sync.acquireSharedInterruptibly(1);
return joinCompletion();
}
@Override
public T get(long time, @NotNull TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
boolean success = sync.tryAcquireSharedNanos(1, unit.toNanos(time));
if (!success) {
throw new TimeoutException("Promise stopped waiting after " + time + " " + unit);
}
return joinCompletion();
}
@Override
public T await() {
protected void callListenerAsync(PromiseListener<T> listener, PromiseCompletion<T> res) {
try {
sync.acquireSharedInterruptibly(1);
} catch (InterruptedException e) {
throw new RuntimeException(e);
getFactory().getAsyncExecutor().run(() -> callListenerNow(listener, res));
} catch (RejectedExecutionException ignored) {
} catch (Exception e) {
getLogger().warn("Exception caught while running promise listener", e);
}
PromiseCompletion<T> completion = Objects.requireNonNull(getCompletion());
if (completion.isSuccess()) return completion.getResult();
throw new CompletionException(completion.getException());
}
private T joinCompletion() throws ExecutionException {
PromiseCompletion<T> completion = Objects.requireNonNull(getCompletion());
if (completion.isSuccess()) return completion.getResult();
throw new ExecutionException(completion.getException());
protected void callListenerNow(PromiseListener<T> listener, PromiseCompletion<T> res) {
try {
listener.handle(res);
} catch (Error e) {
getLogger().error("Error caught in promise listener", e);
throw e;
} catch (Throwable e) {
getLogger().error("Exception caught in promise listener", e);
}
}
protected void callListenerAsyncLastResort(PromiseListener<T> listener, PromiseCompletion<T> completion) {
try {
getFactory().getAsyncExecutor().run(() -> callListenerNow(listener, completion));
} catch (Throwable ignored) {
}
}
@Override
@@ -141,33 +114,68 @@ public abstract class AbstractPromise<T, FS, FA> implements CompletablePromise<T
@Override
public <V> @NotNull Promise<V> thenApply(@NotNull ExceptionalFunction<T, V> task) {
CompletablePromise<V> promise = getFactory().unresolved();
addDirectListener(
res -> createCompleter(res, promise, task).run(),
promise::completeExceptionally
);
PromiseCompletion<T> completion = getCompletion();
if (completion == null) {
CompletablePromise<V> promise = createLinked();
addDirectListener(
res -> createCompleter(res, promise, task).run(),
promise::completeExceptionally
);
PromiseUtil.propagateCancel(promise, this);
return promise;
return promise;
} else if (completion.isSuccess()) {
try {
V result = task.apply(completion.getResult());
return getFactory().resolve(result);
} catch (Exception e) {
return getFactory().error(e);
}
} else {
Throwable ex = completion.getException();
assert ex != null;
return getFactory().error(ex);
}
}
@Override
public <V> @NotNull Promise<V> thenCompose(@NotNull ExceptionalFunction<T, Promise<V>> task) {
CompletablePromise<V> promise = getFactory().unresolved();
thenApply(task).addDirectListener(
nestedPromise -> {
if (nestedPromise == null) {
promise.complete(null);
} else {
PromiseUtil.propagateCompletion(nestedPromise, promise);
PromiseUtil.propagateCancel(promise, nestedPromise);
}
},
promise::completeExceptionally
);
PromiseCompletion<T> completion = getCompletion();
if (completion == null) {
CompletablePromise<V> promise = createLinked();
thenApply(task).addDirectListener(
result -> {
if (result == null) {
promise.complete(null);
} else {
PromiseUtil.propagateCompletion(result, promise);
PromiseUtil.propagateCancel(promise, result);
}
},
promise::completeExceptionally
);
PromiseUtil.propagateCancel(promise, this);
return promise;
return promise;
} else if (completion.isSuccess()) {
try {
Promise<V> result = task.apply(completion.getResult());
if (result == null) {
return getFactory().resolve(null);
} else if (result.isCompleted()) {
return result;
} else {
CompletablePromise<V> promise = createLinked();
PromiseUtil.propagateCompletion(result, promise);
PromiseUtil.propagateCancel(promise, result);
return promise;
}
} catch (Exception e) {
return getFactory().error(e);
}
} else {
Throwable ex = completion.getException();
assert ex != null;
return getFactory().error(ex);
}
}
@Override
@@ -214,7 +222,7 @@ public abstract class AbstractPromise<T, FS, FA> implements CompletablePromise<T
@Override
public <V> @NotNull Promise<V> thenApplySync(@NotNull ExceptionalFunction<T, V> task) {
CompletablePromise<V> promise = getFactory().unresolved();
CompletablePromise<V> promise = createLinked();
addDirectListener(
res -> runCompleter(promise, () -> {
Runnable runnable = createCompleter(res, promise, task);
@@ -224,13 +232,12 @@ public abstract class AbstractPromise<T, FS, FA> implements CompletablePromise<T
promise::completeExceptionally
);
PromiseUtil.propagateCancel(promise, this);
return promise;
}
@Override
public <V> @NotNull Promise<V> thenApplyDelayedSync(@NotNull ExceptionalFunction<T, V> task, long delay, @NotNull TimeUnit unit) {
CompletablePromise<V> promise = getFactory().unresolved();
CompletablePromise<V> promise = createLinked();
addDirectListener(
res -> runCompleter(promise, () -> {
Runnable runnable = createCompleter(res, promise, task);
@@ -240,13 +247,12 @@ public abstract class AbstractPromise<T, FS, FA> implements CompletablePromise<T
promise::completeExceptionally
);
PromiseUtil.propagateCancel(promise, this);
return promise;
}
@Override
public <V> @NotNull Promise<V> thenComposeSync(@NotNull ExceptionalFunction<T, Promise<V>> task) {
CompletablePromise<V> promise = getFactory().unresolved();
CompletablePromise<V> promise = createLinked();
thenApplySync(task).addDirectListener(
nestedPromise -> {
if (nestedPromise == null) {
@@ -259,7 +265,6 @@ public abstract class AbstractPromise<T, FS, FA> implements CompletablePromise<T
promise::completeExceptionally
);
PromiseUtil.propagateCancel(promise, this);
return promise;
}
@@ -307,7 +312,7 @@ public abstract class AbstractPromise<T, FS, FA> implements CompletablePromise<T
@Override
public <V> @NotNull Promise<V> thenApplyAsync(@NotNull ExceptionalFunction<T, V> task) {
CompletablePromise<V> promise = getFactory().unresolved();
CompletablePromise<V> promise = createLinked();
addDirectListener(
(res) -> runCompleter(promise, () -> {
Runnable runnable = createCompleter(res, promise, task);
@@ -317,13 +322,12 @@ public abstract class AbstractPromise<T, FS, FA> implements CompletablePromise<T
promise::completeExceptionally
);
PromiseUtil.propagateCancel(promise, this);
return promise;
}
@Override
public <V> @NotNull Promise<V> thenApplyDelayedAsync(@NotNull ExceptionalFunction<T, V> task, long delay, @NotNull TimeUnit unit) {
CompletablePromise<V> promise = getFactory().unresolved();
CompletablePromise<V> promise = createLinked();
addDirectListener(
res -> runCompleter(promise, () -> {
Runnable runnable = createCompleter(res, promise, task);
@@ -333,13 +337,12 @@ public abstract class AbstractPromise<T, FS, FA> implements CompletablePromise<T
promise::completeExceptionally
);
PromiseUtil.propagateCancel(promise, this);
return promise;
}
@Override
public <V> @NotNull Promise<V> thenComposeAsync(@NotNull ExceptionalFunction<T, Promise<V>> task) {
CompletablePromise<V> promise = getFactory().unresolved();
CompletablePromise<V> promise = createLinked();
thenApplyAsync(task).addDirectListener(
nestedPromise -> {
if (nestedPromise == null) {
@@ -352,7 +355,6 @@ public abstract class AbstractPromise<T, FS, FA> implements CompletablePromise<T
promise::completeExceptionally
);
PromiseUtil.propagateCancel(promise, this);
return promise;
}
@@ -401,51 +403,6 @@ public abstract class AbstractPromise<T, FS, FA> implements CompletablePromise<T
});
}
private @NotNull Promise<T> addAnyListener(PromiseListener<T> listener) {
Collection<PromiseListener<T>> prev = listeners, next = null;
for (boolean haveNext = false; ; ) {
if (!haveNext) {
next = prev == Collections.EMPTY_LIST ? new ConcurrentLinkedQueue<>() : prev;
if (next != null) next.add(listener);
}
if (LISTENERS_HANDLE.weakCompareAndSet(this, prev, next))
break;
haveNext = (prev == (prev = listeners));
}
if (next == null) {
if (listener instanceof AsyncPromiseListener) {
callListenerAsync(listener, Objects.requireNonNull(getCompletion()));
} else {
callListenerNow(listener, Objects.requireNonNull(getCompletion()));
}
}
return this;
}
private void callListenerAsync(PromiseListener<T> listener, PromiseCompletion<T> res) {
try {
getFactory().getAsyncExecutor().run(() -> callListenerNow(listener, res));
} catch (RejectedExecutionException ignored) {
} catch (Exception e) {
getLogger().warn("Exception caught while running promise listener", e);
}
}
private void callListenerNow(PromiseListener<T> listener, PromiseCompletion<T> res) {
try {
listener.handle(res);
} catch (Error e) {
getLogger().error("Error caught in promise listener", e);
throw e;
} catch (Throwable e) {
getLogger().error("Exception caught in promise listener", e);
}
}
@Override
public @NotNull Promise<T> onSuccess(@NotNull Consumer<T> listener) {
return addAsyncListener(listener, null);
@@ -489,92 +446,11 @@ public abstract class AbstractPromise<T, FS, FA> implements CompletablePromise<T
@Override
public @NotNull Promise<T> orDefault(@NotNull ExceptionalFunction<Throwable, T> function) {
CompletablePromise<T> promise = getFactory().unresolved();
addDirectListener(promise::complete, e -> {
try {
T result = function.apply(e);
promise.complete(result);
} catch (Exception ex) {
promise.completeExceptionally(ex);
}
});
PromiseUtil.propagateCancel(promise, this);
CompletablePromise<T> promise = createLinked();
addDirectListener(promise::complete, e -> runCompleter(promise, () -> promise.complete(function.apply(e))));
return promise;
}
@Override
public @NotNull Promise<T> timeout(long time, @NotNull TimeUnit unit) {
Exception e = new CancellationException("Promise timed out after " + time + " " + unit.toString().toLowerCase());
return completeExceptionallyDelayed(e, time, unit);
}
@Override
public @NotNull Promise<T> maxWaitTime(long time, @NotNull TimeUnit unit) {
Exception e = new TimeoutException("Promise stopped waiting after " + time + " " + unit.toString().toLowerCase());
return completeExceptionallyDelayed(e, time, unit);
}
private Promise<T> completeExceptionallyDelayed(Throwable e, long delay, TimeUnit unit) {
runCompleter(this, () -> {
FA future = getFactory().getAsyncExecutor().run(() -> completeExceptionally(e), delay, unit);
addDirectListener(_ -> getFactory().getAsyncExecutor().cancel(future));
});
return this;
}
private void handleCompletion(@NotNull PromiseCompletion<T> cmp) {
if (!COMPLETION_HANDLE.compareAndSet(this, null, cmp)) return;
sync.releaseShared(1);
Iterator<PromiseListener<T>> iter = ((Iterable<PromiseListener<T>>) LISTENERS_HANDLE.getAndSet(this, null)).iterator();
try {
while (iter.hasNext()) {
PromiseListener<T> listener = iter.next();
if (listener instanceof AsyncPromiseListener) {
callListenerAsync(listener, cmp);
} else {
callListenerNow(listener, cmp);
}
}
} finally {
iter.forEachRemaining(v -> callListenerAsyncLastResort(v, cmp));
}
}
private void callListenerAsyncLastResort(PromiseListener<T> listener, PromiseCompletion<T> completion) {
try {
getFactory().getAsyncExecutor().run(() -> callListenerNow(listener, completion));
} catch (Throwable ignored) {
}
}
@Override
public void cancel(@NotNull CancellationException e) {
completeExceptionally(e);
}
@Override
public void complete(@Nullable T result) {
handleCompletion(new PromiseCompletion<>(result));
}
@Override
public void completeExceptionally(@NotNull Throwable result) {
handleCompletion(new PromiseCompletion<>(result));
}
@Override
public boolean isCompleted() {
return completion != null;
}
@Override
public @Nullable PromiseCompletion<T> getCompletion() {
return completion;
}
@Override
public @NotNull CompletableFuture<T> toFuture() {
CompletableFuture<T> future = new CompletableFuture<>();
@@ -591,31 +467,4 @@ public abstract class AbstractPromise<T, FS, FA> implements CompletablePromise<T
private static class DeferredExecutionException extends ExecutionException {
}
private static final class Sync extends AbstractQueuedSynchronizer {
private Sync() {
setState(1);
}
@Override
protected int tryAcquireShared(int acquires) {
return getState() == 0 ? 1 : -1;
}
@Override
protected boolean tryReleaseShared(int releases) {
int c1, c2;
do {
c1 = getState();
if (c1 == 0) {
return false;
}
c2 = c1 - 1;
} while (!compareAndSetState(c1, c2));
return c2 == 0;
}
}
}