diff --git a/build.gradle b/build.gradle index 327d71d..4252922 100644 --- a/build.gradle +++ b/build.gradle @@ -38,7 +38,6 @@ subprojects { testRuntimeOnly 'org.junit.platform:junit-platform-launcher' testImplementation 'io.projectreactor:reactor-core:3.6.4' testImplementation 'org.junit.jupiter:junit-jupiter:5.8.1' - testImplementation 'org.slf4j:slf4j-api:2.0.12' testImplementation 'ch.qos.logback:logback-classic:1.5.3' } diff --git a/futur-api/src/main/java/dev/tommyjs/futur/promise/AbstractPromise.java b/futur-api/src/main/java/dev/tommyjs/futur/promise/AbstractPromise.java index 05b611e..5fa166f 100644 --- a/futur-api/src/main/java/dev/tommyjs/futur/promise/AbstractPromise.java +++ b/futur-api/src/main/java/dev/tommyjs/futur/promise/AbstractPromise.java @@ -12,16 +12,21 @@ import org.slf4j.Logger; import java.util.Collection; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.ReentrantLock; import java.util.function.Consumer; public abstract class AbstractPromise implements Promise { private final AtomicReference>> listeners; private final AtomicReference> completion; + private final CountDownLatch latch; + private final ReentrantLock lock; public AbstractPromise() { this.listeners = new AtomicReference<>(); this.completion = new AtomicReference<>(); + this.latch = new CountDownLatch(1); + this.lock = new ReentrantLock(); } protected static void propagateResult(Promise from, Promise to) { @@ -57,24 +62,14 @@ public abstract class AbstractPromise implements Promise { @Override public T join(long timeoutMillis) throws TimeoutException { - PromiseCompletion completion; - long start = System.currentTimeMillis(); - long remainingTimeout = timeoutMillis; - - synchronized (this.completion) { - completion = this.completion.get(); - while (completion == null && remainingTimeout > 0) { - try { - this.completion.wait(remainingTimeout); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - - completion = this.completion.get(); - remainingTimeout = timeoutMillis - (System.currentTimeMillis() - start); - } + try { + //noinspection ResultOfMethodCallIgnored + this.latch.await(timeoutMillis, TimeUnit.MILLISECONDS); + } catch (InterruptedException e) { + throw new RuntimeException(e); } + PromiseCompletion completion = getCompletion(); if (completion == null) throw new TimeoutException("Promise stopped waiting after " + timeoutMillis + "ms"); @@ -320,7 +315,8 @@ public abstract class AbstractPromise implements Promise { } private @NotNull Promise addAnyListener(PromiseListener listener) { - synchronized (completion) { + lock.lock(); + try { PromiseCompletion completion = getCompletion(); if (completion != null) { callListener(listener, completion); @@ -328,6 +324,8 @@ public abstract class AbstractPromise implements Promise { listeners.compareAndSet(null, new ConcurrentLinkedQueue<>()); listeners.get().add(listener); } + } finally { + lock.unlock(); } return this; @@ -392,17 +390,19 @@ public abstract class AbstractPromise implements Promise { } private void handleCompletion(@NotNull PromiseCompletion ctx) { - synchronized (completion) { - if (!setCompletion(ctx)) return; - - completion.notifyAll(); + if (!setCompletion(ctx)) return; + lock.lock(); + try { + this.latch.countDown(); Collection> listeners = this.listeners.get(); if (listeners != null) { for (PromiseListener listener : listeners) { callListener(listener, ctx); } } + } finally { + lock.unlock(); } } diff --git a/futur-api/src/test/java/dev/tommyjs/futur/PromiseTests.java b/futur-api/src/test/java/dev/tommyjs/futur/PromiseTests.java index 790cf7a..ffe6f01 100644 --- a/futur-api/src/test/java/dev/tommyjs/futur/PromiseTests.java +++ b/futur-api/src/test/java/dev/tommyjs/futur/PromiseTests.java @@ -160,10 +160,10 @@ public final class PromiseTests { public void testRace() throws TimeoutException { assert pfac.race( List.of( - pfac.start().thenSupplyDelayedAsync(() -> true, 50, TimeUnit.MILLISECONDS), - pfac.start().thenSupplyDelayedAsync(() -> false, 150, TimeUnit.MILLISECONDS) + pfac.start().thenSupplyDelayedAsync(() -> true, 150, TimeUnit.MILLISECONDS), + pfac.start().thenSupplyDelayedAsync(() -> false, 200, TimeUnit.MILLISECONDS) ) - ).join(100L); + ).join(300L); } }