add cancellation and refractor PromiseFactory

This commit is contained in:
WhatCats
2024-04-04 17:36:35 +02:00
parent e6eee4e849
commit 5bbcfdc9b3
24 changed files with 670 additions and 445 deletions

View File

@@ -10,12 +10,14 @@ import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import java.util.Collection;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
public abstract class AbstractPromise<T> implements Promise<T> {
public abstract class AbstractPromise<T, F> implements Promise<T> {
private final Collection<PromiseListener<T>> listeners;
private final AtomicReference<PromiseCompletion<T>> completion;
@@ -25,14 +27,14 @@ public abstract class AbstractPromise<T> implements Promise<T> {
this.completion = new AtomicReference<>();
}
protected abstract PromiseExecutor getExecutor();
public abstract @NotNull AbstractPromiseFactory<F> getFactory();
protected abstract Logger getLogger();
protected @NotNull PromiseExecutor<F> getExecutor() {
return getFactory().getExecutor();
}
@Deprecated
@Override
public T join(long interval, long timeoutMillis) throws TimeoutException {
return join(timeoutMillis);
protected @NotNull Logger getLogger() {
return getFactory().getLogger();
}
@Override
@@ -56,7 +58,7 @@ public abstract class AbstractPromise<T> implements Promise<T> {
}
if (completion == null)
throw new TimeoutException("Promise timed out after " + timeoutMillis + "ms");
throw new TimeoutException("Promise stopped waiting after " + timeoutMillis + "ms");
return joinCompletion(completion);
}
@@ -113,57 +115,47 @@ public abstract class AbstractPromise<T> implements Promise<T> {
@Override
public <V> @NotNull Promise<V> thenApplySync(@NotNull ExceptionalFunction<T, V> task) {
Promise<V> promise = getFactory().unresolved();
addListener(ctx -> {
if (ctx.isError()) {
//noinspection ConstantConditions
promise.completeExceptionally(ctx.getException());
return;
}
Runnable runnable = createRunnable(ctx, promise, task);
getExecutor().runSync(runnable, 0L, TimeUnit.MILLISECONDS);
});
addListener(
res -> {
Runnable runnable = createRunnable(res, promise, task);
F future = getExecutor().runSync(runnable);
promise.onCancel((e) -> getExecutor().cancel(future));
},
promise::completeExceptionally
);
addChild(promise);
return promise;
}
@Override
public <V> @NotNull Promise<V> thenApplyDelayedSync(@NotNull ExceptionalFunction<T, V> task, long delay, @NotNull TimeUnit unit) {
Promise<V> promise = getFactory().unresolved();
addListener(ctx -> {
if (ctx.isError()) {
//noinspection ConstantConditions
promise.completeExceptionally(ctx.getException());
return;
}
Runnable runnable = createRunnable(ctx, promise, task);
getExecutor().runSync(runnable, delay, unit);
});
addListener(
res -> {
Runnable runnable = createRunnable(res, promise, task);
F future = getExecutor().runSync(runnable, delay, unit);
promise.onCancel((e) -> getExecutor().cancel(future));
},
promise::completeExceptionally
);
addChild(promise);
return promise;
}
@Override
public <V> @NotNull Promise<V> thenComposeSync(@NotNull ExceptionalFunction<T, @NotNull Promise<V>> task) {
Promise<V> promise = getFactory().unresolved();
thenApplySync(task).thenConsumeAsync(nestedPromise -> {
nestedPromise.addListener(ctx1 -> {
if (ctx1.isError()) {
//noinspection ConstantConditions
promise.completeExceptionally(ctx1.getException());
return;
}
promise.complete(ctx1.getResult());
});
}).addListener(ctx2 -> {
if (ctx2.isError()) {
//noinspection ConstantConditions
promise.completeExceptionally(ctx2.getException());
}
});
thenApplySync(task).addListener(
nestedPromise -> {
nestedPromise.propagateResult(promise);
nestedPromise.addChild(promise);
},
promise::completeExceptionally
);
addChild(promise);
return promise;
}
@@ -220,83 +212,66 @@ public abstract class AbstractPromise<T> implements Promise<T> {
@Override
public <V> @NotNull Promise<V> thenApplyAsync(@NotNull ExceptionalFunction<T, V> task) {
Promise<V> promise = getFactory().unresolved();
addListener(ctx -> {
if (ctx.isError()) {
//noinspection ConstantConditions
promise.completeExceptionally(ctx.getException());
return;
}
Runnable runnable = createRunnable(ctx, promise, task);
getExecutor().runAsync(runnable, 0L, TimeUnit.MILLISECONDS);
});
addListener(
(res) -> {
Runnable runnable = createRunnable(res, promise, task);
F future = getExecutor().runAsync(runnable);
promise.onCancel((e) -> getExecutor().cancel(future));
},
promise::completeExceptionally
);
addChild(promise);
return promise;
}
@Override
public <V> @NotNull Promise<V> thenApplyDelayedAsync(@NotNull ExceptionalFunction<T, V> task, long delay, @NotNull TimeUnit unit) {
Promise<V> promise = getFactory().unresolved();
addListener(ctx -> {
Runnable runnable = createRunnable(ctx, promise, task);
getExecutor().runAsync(runnable, delay, unit);
});
addListener(
res -> {
Runnable runnable = createRunnable(res, promise, task);
F future = getExecutor().runAsync(runnable, delay, unit);
promise.onCancel((e) -> getExecutor().cancel(future));
},
promise::completeExceptionally
);
addChild(promise);
return promise;
}
@Override
public <V> @NotNull Promise<V> thenComposeAsync(@NotNull ExceptionalFunction<T, Promise<V>> task) {
Promise<V> promise = getFactory().unresolved();
thenApplyAsync(task).thenConsumeAsync(nestedPromise -> {
nestedPromise.addListener(ctx1 -> {
if (ctx1.isError()) {
//noinspection ConstantConditions
promise.completeExceptionally(ctx1.getException());
return;
}
promise.complete(ctx1.getResult());
});
}).addListener(ctx2 -> {
if (ctx2.isError()) {
//noinspection ConstantConditions
promise.completeExceptionally(ctx2.getException());
}
});
thenApplyAsync(task).addListener(
nestedPromise -> {
nestedPromise.propagateResult(promise);
nestedPromise.addChild(promise);
},
promise::completeExceptionally
);
addChild(promise);
return promise;
}
private <V> @NotNull Runnable createRunnable(@NotNull PromiseCompletion<T> ctx, @NotNull Promise<V> promise, @NotNull ExceptionalFunction<T, V> task) {
private <V> @NotNull Runnable createRunnable(T result, @NotNull Promise<V> promise, @NotNull ExceptionalFunction<T, V> task) {
return () -> {
if (ctx.isError()) {
//noinspection ConstantConditions
promise.completeExceptionally(ctx.getException());
return;
}
if (promise.isCompleted()) return;
try {
V result = task.apply(ctx.getResult());
promise.complete(result);
V nextResult = task.apply(result);
promise.complete(nextResult);
} catch (Throwable e) {
promise.completeExceptionally(e);
}
};
}
@Override
public @NotNull Promise<T> logExceptions() {
return logExceptions("Exception caught in promise chain");
}
@Override
public @NotNull Promise<T> logExceptions(@NotNull String message) {
return addListener(ctx -> {
if (ctx.isError()) {
getLogger().error(message, ctx.getException());
}
});
return onError(e -> getLogger().error(message, e));
}
@Override
@@ -320,19 +295,51 @@ public abstract class AbstractPromise<T> implements Promise<T> {
}
@Override
public @NotNull Promise<T> timeout(long time, @NotNull TimeUnit unit) {
getExecutor().runAsync(() -> {
if (!isCompleted()) {
completeExceptionally(new TimeoutException("Promise timed out after " + time + " " + unit));
public @NotNull Promise<T> addListener(@Nullable Consumer<T> successListener, @Nullable Consumer<Throwable> errorListener) {
return addListener((res) -> {
if (res.isError()) {
if (errorListener != null) errorListener.accept(res.getException());
} else {
if (successListener != null) successListener.accept(res.getResult());
}
}, time, unit);
return this;
});
}
@Override
public @NotNull Promise<T> timeout(long ms) {
return timeout(ms, TimeUnit.MILLISECONDS);
public @NotNull Promise<T> onSuccess(@NotNull Consumer<T> listener) {
return addListener(listener, null);
}
@Override
public @NotNull Promise<T> onError(@NotNull Consumer<Throwable> listener) {
return addListener(null, listener);
}
@Override
public <E extends Throwable> @NotNull Promise<T> onError(@NotNull Class<E> clazz, @NotNull Consumer<E> listener) {
return onError((e) -> {
if (clazz.isAssignableFrom(e.getClass())) {
//noinspection unchecked
listener.accept((E) e);
}
});
}
@Override
public @NotNull Promise<T> onCancel(@NotNull Consumer<CancellationException> listener) {
return onError(CancellationException.class, listener);
}
@Deprecated
@Override
public @NotNull Promise<T> timeout(long time, @NotNull TimeUnit unit) {
return maxWaitTime(time, unit);
}
@Override
public @NotNull Promise<T> maxWaitTime(long time, @NotNull TimeUnit unit) {
F future = getExecutor().runAsync(() -> completeExceptionally(new TimeoutException("Promise stopped waiting after " + time + " " + unit)), time, unit);
return onError(e -> getExecutor().cancel(future));
}
private void handleCompletion(@NotNull PromiseCompletion<T> ctx) {
@@ -358,6 +365,26 @@ public abstract class AbstractPromise<T> implements Promise<T> {
return this.completion.compareAndSet(null, completion);
}
@Override
public void addChild(@NotNull Promise<?> child) {
child.onCancel((e) -> this.cancel(e.getMessage()));
}
@Override
public void propagateResult(@NotNull Promise<T> target) {
addListener(target::complete, target::completeExceptionally);
}
@Override
public void cancel() {
completeExceptionally(new CancellationException());
}
@Override
public void cancel(@NotNull String message) {
completeExceptionally(new CancellationException(message));
}
@Override
public void complete(@Nullable T result) {
handleCompletion(new PromiseCompletion<>(result));