handle null values in join promise iterator

This commit is contained in:
WhatCats
2026-03-12 13:12:46 +01:00
parent c4d596f99d
commit 23f529b8dd
9 changed files with 48 additions and 27 deletions

View File

@@ -6,7 +6,7 @@ plugins {
subprojects {
group = 'dev.tommyjs'
version = '2.5.3'
version = '2.5.4'
apply plugin: 'java-library'
apply plugin: 'com.github.johnrengelman.shadow'

View File

@@ -20,12 +20,12 @@ public class CompletionJoiner extends PromiseJoiner<Promise<?>, Void, Void, List
}
@Override
protected Void getChildKey(Promise<?> value) {
protected Void getChildKey(@NotNull Promise<?> value) {
return null;
}
@Override
protected @NotNull Promise<Void> getChildPromise(Promise<?> value) {
protected @NotNull Promise<Void> getChildPromise(@NotNull Promise<?> value) {
//noinspection unchecked
return (Promise<Void>) value;
}

View File

@@ -5,6 +5,7 @@ import dev.tommyjs.futur.promise.PromiseCompletion;
import dev.tommyjs.futur.promise.PromiseFactory;
import dev.tommyjs.futur.util.ConcurrentResultArray;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.*;
@@ -20,12 +21,12 @@ public class MappedResultJoiner<K, V> extends PromiseJoiner<Map.Entry<K, Promise
}
@Override
protected K getChildKey(Map.Entry<K, Promise<V>> entry) {
protected K getChildKey(@NotNull Map.Entry<K, Promise<V>> entry) {
return entry.getKey();
}
@Override
protected @NotNull Promise<V> getChildPromise(Map.Entry<K, Promise<V>> entry) {
protected @Nullable Promise<V> getChildPromise(@NotNull Map.Entry<K, Promise<V>> entry) {
return entry.getValue();
}

View File

@@ -6,6 +6,7 @@ import dev.tommyjs.futur.promise.PromiseCompletion;
import dev.tommyjs.futur.promise.PromiseFactory;
import dev.tommyjs.futur.util.PromiseUtil;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.Iterator;
import java.util.concurrent.atomic.AtomicInteger;
@@ -18,35 +19,40 @@ public abstract class PromiseJoiner<T, Key, Value, Result> {
this.joined = factory.unresolved();
}
protected abstract Key getChildKey(T value);
protected abstract Key getChildKey(@NotNull T value);
protected abstract @NotNull Promise<Value> getChildPromise(T value);
protected abstract @Nullable Promise<Value> getChildPromise(@NotNull T value);
protected abstract void onChildComplete(int index, Key key, @NotNull PromiseCompletion<Value> completion);
protected abstract Result getResult();
protected void join(@NotNull Iterator<T> promises) {
protected void join(@NotNull Iterator<@Nullable T> promises) {
assert !joined.isCompleted();
AtomicInteger count = new AtomicInteger();
int i = 0;
do {
if (joined.isCompleted()) {
promises.forEachRemaining(v -> getChildPromise(v).cancel());
return;
T value = promises.next();
if (value == null) {
continue;
}
T value = promises.next();
Promise<Value> p = getChildPromise(value);
if (!p.isCompleted()) {
PromiseUtil.cancelOnComplete(joined, p);
Promise<Value> promise = getChildPromise(value);
if (promise == null) {
continue;
}
if (!promise.isCompleted()) {
PromiseUtil.cancelOnComplete(joined, promise);
}
count.incrementAndGet();
Key key = getChildKey(value);
int index = i++;
p.addAsyncListener(res -> {
promise.addAsyncListener(res -> {
onChildComplete(index, key, res);
if (res.isError()) {
assert res.getException() != null;

View File

@@ -20,12 +20,12 @@ public class ResultJoiner<T> extends PromiseJoiner<Promise<T>, Void, T, List<T>>
}
@Override
protected Void getChildKey(Promise<T> value) {
protected Void getChildKey(@NotNull Promise<T> value) {
return null;
}
@Override
protected @NotNull Promise<T> getChildPromise(Promise<T> value) {
protected @NotNull Promise<T> getChildPromise(@NotNull Promise<T> value) {
return value;
}

View File

@@ -15,12 +15,12 @@ public class VoidJoiner extends PromiseJoiner<Promise<?>, Void, Void, Void> {
}
@Override
protected Void getChildKey(Promise<?> value) {
protected Void getChildKey(@NotNull Promise<?> value) {
return null;
}
@Override
protected @NotNull Promise<Void> getChildPromise(Promise<?> value) {
protected @NotNull Promise<Void> getChildPromise(@NotNull Promise<?> value) {
//noinspection unchecked
return (Promise<Void>) value;
}

View File

@@ -70,7 +70,7 @@ public abstract class AbstractPromiseFactory implements PromiseFactory {
return resolve(Collections.emptyList());
}
return new ResultJoiner<>(this, promises, expectedSize).joined();
return new ResultJoiner<V>(this, promises, expectedSize).joined();
}
@Override

View File

@@ -183,7 +183,7 @@ public interface PromiseFactory {
* @param promises the input promises
* @return the combined promise
*/
default <K, V> @NotNull Promise<Map<K, V>> combineMapped(@NotNull Map.Entry<K, Promise<V>>... promises) {
default <K, V> @NotNull Promise<Map<K, V>> combineMapped(@Nullable Map.Entry<K, Promise<V>>... promises) {
return combineMapped(Arrays.spliterator(promises));
}
@@ -298,7 +298,7 @@ public interface PromiseFactory {
* @param promises the input promises
* @return the combined promise
*/
default <V> @NotNull Promise<List<V>> combine(@NotNull Promise<V>... promises) {
default <V> @NotNull Promise<List<V>> combine(@Nullable Promise<V>... promises) {
return combine(Arrays.spliterator(promises));
}
@@ -352,7 +352,7 @@ public interface PromiseFactory {
* @param promises the input promises
* @return the combined promise
*/
default @NotNull Promise<List<PromiseCompletion<?>>> allSettled(@NotNull Promise<?>... promises) {
default @NotNull Promise<List<PromiseCompletion<?>>> allSettled(@Nullable Promise<?>... promises) {
return allSettled(Arrays.spliterator(promises));
}
@@ -398,7 +398,7 @@ public interface PromiseFactory {
* @param promises the input promises
* @return the combined promise
*/
default @NotNull Promise<Void> all(@NotNull Promise<?>... promises) {
default @NotNull Promise<Void> all(@Nullable Promise<?>... promises) {
return all(Arrays.asList(promises).iterator());
}
@@ -482,7 +482,7 @@ public interface PromiseFactory {
* @param ignoreErrors whether to ignore promises that complete exceptionally
* @return the combined promise
*/
default <V> @NotNull Promise<V> race(boolean ignoreErrors, @NotNull Promise<V>... promises) {
default <V> @NotNull Promise<V> race(boolean ignoreErrors, @Nullable Promise<V>... promises) {
return race(Arrays.asList(promises), ignoreErrors);
}
@@ -494,7 +494,7 @@ public interface PromiseFactory {
* @param promises the input promises
* @return the combined promise
*/
default <V> @NotNull Promise<V> race(@NotNull Promise<V>... promises) {
default <V> @NotNull Promise<V> race(@Nullable Promise<V>... promises) {
return race(false, promises);
}

View File

@@ -8,6 +8,7 @@ import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.AbstractMap;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
@@ -116,6 +117,19 @@ public final class PromiseTests {
assert result.equals(unsizedIntStream(1000).boxed().toList());
}
@Test
public void testCombineNull() {
var result1 = promises.combine(Arrays.asList(null, null, null)).await();
var result2 = promises.combineMapped(Arrays.asList(null, null, null)).await();
var result3 = promises.combineMapped(List.of(new AbstractMap.SimpleEntry<>(null, null))).await();
var result4 = promises.combineMapped(List.of(new AbstractMap.SimpleEntry<>(null, promises.resolve(true)))).await();
assert result1.isEmpty();
assert result2.isEmpty();
assert result3.isEmpty();
assert result4.get(null) == true;
}
@Test
public void testCombineUtil() throws TimeoutException, ExecutionException, InterruptedException {
promises.all(