small promise optimizations

This commit is contained in:
tommyskeff
2025-01-07 14:22:33 +00:00
parent 8dbbc66de4
commit fbeef9833b

View File

@@ -9,24 +9,40 @@ import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable; import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger; import org.slf4j.Logger;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.Iterator; import java.util.Iterator;
import java.util.Objects; import java.util.Objects;
import java.util.concurrent.*; import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.AbstractQueuedSynchronizer;
import java.util.function.Consumer; import java.util.function.Consumer;
public abstract class AbstractPromise<T, FS, FA> implements CompletablePromise<T> { public abstract class AbstractPromise<T, FS, FA> implements CompletablePromise<T> {
private static final VarHandle COMPLETION_HANDLE;
static {
try {
MethodHandles.Lookup lookup = MethodHandles.lookup();
COMPLETION_HANDLE = lookup.findVarHandle(AbstractPromise.class, "completion", PromiseCompletion.class);
} catch (ReflectiveOperationException e) {
throw new ExceptionInInitializerError(e);
}
}
private final AtomicReference<Collection<PromiseListener<T>>> listeners; private final AtomicReference<Collection<PromiseListener<T>>> listeners;
private final AtomicReference<PromiseCompletion<T>> completion; private final Sync sync;
private final CountDownLatch latch;
@SuppressWarnings("FieldMayBeFinal")
private volatile PromiseCompletion<T> completion;
public AbstractPromise() { public AbstractPromise() {
this.listeners = new AtomicReference<>(Collections.emptyList()); this.listeners = new AtomicReference<>(Collections.emptyList());
this.completion = new AtomicReference<>(); this.sync = new Sync();
this.latch = new CountDownLatch(1); this.completion = null;
} }
public abstract @NotNull AbstractPromiseFactory<FS, FA> getFactory(); public abstract @NotNull AbstractPromiseFactory<FS, FA> getFactory();
@@ -60,13 +76,13 @@ public abstract class AbstractPromise<T, FS, FA> implements CompletablePromise<T
@Override @Override
public T get() throws InterruptedException, ExecutionException { public T get() throws InterruptedException, ExecutionException {
this.latch.await(); sync.acquireSharedInterruptibly(1);
return joinCompletion(); return joinCompletion();
} }
@Override @Override
public T get(long time, @NotNull TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { public T get(long time, @NotNull TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
boolean success = this.latch.await(time, unit); boolean success = sync.tryAcquireSharedNanos(1, unit.toNanos(time));
if (!success) { if (!success) {
throw new TimeoutException("Promise stopped waiting after " + time + " " + unit); throw new TimeoutException("Promise stopped waiting after " + time + " " + unit);
} }
@@ -77,7 +93,7 @@ public abstract class AbstractPromise<T, FS, FA> implements CompletablePromise<T
@Override @Override
public T await() { public T await() {
try { try {
this.latch.await(); sync.acquireSharedInterruptibly(1);
} catch (InterruptedException e) { } catch (InterruptedException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
@@ -472,36 +488,29 @@ public abstract class AbstractPromise<T, FS, FA> implements CompletablePromise<T
return this; return this;
} }
private void handleCompletion(@NotNull PromiseCompletion<T> ctx) { private void handleCompletion(@NotNull PromiseCompletion<T> cmp) {
if (!setCompletion(ctx)) return; if (!COMPLETION_HANDLE.compareAndSet(this, null, cmp)) return;
latch.countDown(); sync.releaseShared(1);
Iterator<PromiseListener<T>> iter = listeners.getAndSet(null).iterator(); Iterator<PromiseListener<T>> iter = listeners.getAndSet(null).iterator();
while (iter.hasNext()) {
PromiseListener<T> listener = iter.next();
try {
if (listener instanceof AsyncPromiseListener) {
callListenerAsync(listener, ctx);
} else {
callListenerNow(listener, ctx);
}
} finally {
iter.forEachRemaining(v -> callListenerAsyncLastResort(v, ctx));
}
}
}
private void callListenerAsyncLastResort(PromiseListener<T> listener, PromiseCompletion<T> ctx) {
try { try {
getFactory().getAsyncExecutor().run(() -> callListenerNow(listener, ctx)); while (iter.hasNext()) {
} catch (Throwable ignored) { PromiseListener<T> listener = iter.next();
if (listener instanceof AsyncPromiseListener) {
callListenerAsync(listener, cmp);
} else {
callListenerNow(listener, cmp);
}
}
} finally {
iter.forEachRemaining(v -> callListenerAsyncLastResort(v, cmp));
} }
} }
private boolean setCompletion(PromiseCompletion<T> completion) { private void callListenerAsyncLastResort(PromiseListener<T> listener, PromiseCompletion<T> completion) {
return this.completion.compareAndSet(null, completion); try {
getFactory().getAsyncExecutor().run(() -> callListenerNow(listener, completion));
} catch (Throwable ignored) {}
} }
@Override @Override
@@ -521,12 +530,12 @@ public abstract class AbstractPromise<T, FS, FA> implements CompletablePromise<T
@Override @Override
public boolean isCompleted() { public boolean isCompleted() {
return completion.get() != null; return completion != null;
} }
@Override @Override
public @Nullable PromiseCompletion<T> getCompletion() { public @Nullable PromiseCompletion<T> getCompletion() {
return completion.get(); return completion;
} }
@Override @Override
@@ -543,7 +552,33 @@ public abstract class AbstractPromise<T, FS, FA> implements CompletablePromise<T
} }
private static class DeferredExecutionException extends ExecutionException { private static class DeferredExecutionException extends ExecutionException {
}
private static final class Sync extends AbstractQueuedSynchronizer {
private Sync() {
setState(1);
}
@Override
protected int tryAcquireShared(int acquires) {
return getState() == 0 ? 1 : -1;
}
@Override
protected boolean tryReleaseShared(int releases) {
int c1, c2;
do {
c1 = getState();
if (c1 == 0) {
return false;
}
c2 = c1 - 1;
} while(!compareAndSetState(c1, c2));
return c2 == 0;
}
} }
} }