diff --git a/flink-ml-lib/pom.xml b/flink-ml-lib/pom.xml index 1773fc7d4..0e2cd9287 100644 --- a/flink-ml-lib/pom.xml +++ b/flink-ml-lib/pom.xml @@ -138,6 +138,11 @@ under the License. test test-jar + + it.unimi.dsi + fastutil + 8.5.12 + diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/Message.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/Message.java new file mode 100644 index 000000000..36ca83269 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/Message.java @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.ml.util.Bits; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.lang.reflect.Array; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; + +/** + * {@link Message} is responsible for encoding all information exchanged between {@link + * WorkerOperator} and {@link ServerOperator}. The message format follows this structure: + * + *
`workerId serverId stageId keyLength keys valuesLength values` + * + *
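(An illustrative aside, not part of the patch: the fixed-offset header above can be decoded by hand with the {@code Bits} helper used throughout this change. The literal offsets below mirror the *_OFFSET constants defined in the class; the field-by-field description of the format continues right after this sketch.)

    // Sketch: decoding the fixed-size header of an encoded message.
    static void describeHeader(byte[] bytes) {
        int workerId = Bits.getInt(bytes, 0); // WORKER_ID_OFFSET
        int serverId = Bits.getInt(bytes, 4); // SERVER_ID_OFFSET
        int stageId = Bits.getInt(bytes, 8); // STAGE_ID_OFFSET
        int numKeys = Bits.getInt(bytes, 12); // KVS_OFFSET: the key count precedes the keys
        System.out.printf(
                "worker=%d server=%d stage=%d keys=%d%n", workerId, serverId, stageId, numKeys);
    }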
where the message fields include the worker ID, server ID, stage ID, length of the keys, keys + * themselves, length of the values, and the values. + */ +public class Message { + private static final int WORKER_ID_OFFSET = 0; + private static final int SERVER_ID_OFFSET = Integer.BYTES; + private static final int STAGE_ID_OFFSET = Integer.BYTES + SERVER_ID_OFFSET; + private static final int KVS_OFFSET = Integer.BYTES + STAGE_ID_OFFSET; + + /** The storage of message in bytes. */ + public final byte[] bytes; + + /** Constructs a message instance from the bytes. */ + public Message(byte[] bytes) { + this.bytes = bytes; + } + + /** Constructs a message instance from long keys and double values. */ + public Message(int workerId, int serverId, int stageId, long[] keys, double[] values) { + int sizeInBytes = + KVS_OFFSET + + Bits.getLongArraySizeInBytes(keys) + + Bits.getDoubleArraySizeInBytes(values); + bytes = new byte[sizeInBytes]; + Bits.putInt(bytes, WORKER_ID_OFFSET, workerId); + Bits.putInt(bytes, SERVER_ID_OFFSET, serverId); + Bits.putInt(bytes, STAGE_ID_OFFSET, stageId); + int offset = Bits.putLongArray(keys, bytes, KVS_OFFSET); + Bits.putDoubleArray(values, bytes, offset); + } + + /** Constructs a message instance from long keys and generics values. */ + public Message( + int workerId, + int serverId, + int stageId, + long[] keys, + V[] values, + TypeSerializer serializer) + throws IOException { + ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + DataOutputViewStreamWrapper dataOutputViewStreamWrapper = + new DataOutputViewStreamWrapper(byteArrayOutputStream); + dataOutputViewStreamWrapper.writeInt(workerId); + dataOutputViewStreamWrapper.writeInt(serverId); + dataOutputViewStreamWrapper.writeInt(stageId); + + dataOutputViewStreamWrapper.writeInt(keys.length); + for (long key : keys) { + dataOutputViewStreamWrapper.writeLong(key); + } + dataOutputViewStreamWrapper.writeInt(values.length); + for (V value : values) { + serializer.serialize(value, dataOutputViewStreamWrapper); + } + bytes = byteArrayOutputStream.toByteArray(); + } + + /** Retrieves the keys. */ + public long[] getKeys() { + return Bits.getLongArray(bytes, KVS_OFFSET); + } + + /** Retrieves the values using the given serializer. */ + public V[] getValues(TypeSerializer serializer) throws IOException { + int numIndices = Bits.getInt(bytes, KVS_OFFSET); + int offset = KVS_OFFSET + Integer.BYTES + numIndices * Long.BYTES; + int numValues = Bits.getInt(bytes, offset); + offset += Integer.BYTES; + + // Since the generics got erased, we use reflections to create the array. + V[] result = (V[]) Array.newInstance(serializer.createInstance().getClass(), numValues); + ByteArrayInputStream byteArrayInputStream = + new ByteArrayInputStream(bytes, offset, bytes.length - offset); + DataInputViewStreamWrapper dataInputViewStreamWrapper = + new DataInputViewStreamWrapper(byteArrayInputStream); + for (int i = 0; i < numValues; i++) { + result[i] = serializer.deserialize(dataInputViewStreamWrapper); + } + return result; + } + + /** + * Retrieves the values in double array. + * + *
Note that getting double array in this function using {@link Bits#getDoubleArray(byte[], + * int)} is faster than {@link Message#getValues} by up to 2.3X. + */ + public double[] getValuesInDoubleArray() { + int offset = KVS_OFFSET + Bits.getInt(bytes, KVS_OFFSET) * Long.BYTES + Integer.BYTES; + return Bits.getDoubleArray(bytes, offset); + } + + /** Retrieves the worker id. */ + public int getWorkerId() { + return Bits.getInt(bytes, WORKER_ID_OFFSET); + } + + /** Sets the worker id. */ + public void setWorkerId(int workerId) { + Bits.putInt(bytes, WORKER_ID_OFFSET, workerId); + } + + /** Retrieves the server id. */ + public int getServerId() { + return Bits.getInt(bytes, SERVER_ID_OFFSET); + } + + /** Sets the server id. */ + public void setServerId(int serverId) { + Bits.putInt(bytes, SERVER_ID_OFFSET, serverId); + } + + /** Retrieves the stage id. */ + public int getStageId() { + return Bits.getInt(bytes, STAGE_ID_OFFSET); + } + + /** + * Assembles the received messages from servers according to the server id. Note that these + * message should be the responses from the same stage. + */ + public static Message assembleMessages(Iterator messageIterator) { + List messages = new ArrayList<>(); + while (messageIterator.hasNext()) { + messages.add(new Message(messageIterator.next())); + } + messages.sort(Comparator.comparingInt(Message::getServerId)); + + int numMessages = messages.size(); + int numKeys = 0, numValues = 0; + int numAssembledBytes = 0; + int workerId = -1; + int stageId = -1; + for (Message message : messages) { + if (workerId == -1) { + workerId = message.getWorkerId(); + stageId = message.getStageId(); + } + numKeys += message.getNumKeys(); + numValues += message.getNumValues(); + numAssembledBytes += message.bytes.length; + } + numAssembledBytes -= (numMessages - 1) * (KVS_OFFSET + Integer.BYTES * 2); + byte[] assembledBytes = new byte[numAssembledBytes]; + Bits.putInt(assembledBytes, WORKER_ID_OFFSET, workerId); + Bits.putInt(assembledBytes, STAGE_ID_OFFSET, stageId); + int keysOffset = KVS_OFFSET; + Bits.putInt(assembledBytes, keysOffset, numKeys); + keysOffset += Integer.BYTES; + int valuesOffset = keysOffset + numKeys * Long.BYTES; + Bits.putInt(assembledBytes, valuesOffset, numValues); + valuesOffset += Integer.BYTES; + + for (Message message : messages) { + Tuple2 keysOffsetAndLength = message.getKeysOffsetAndLength(); + System.arraycopy( + message.bytes, + keysOffsetAndLength.f0, + assembledBytes, + keysOffset, + keysOffsetAndLength.f1); + keysOffset += keysOffsetAndLength.f1; + Tuple2 valuesOffsetAndLength = message.getValuesOffSetAndLength(); + System.arraycopy( + message.bytes, + valuesOffsetAndLength.f0, + assembledBytes, + valuesOffset, + valuesOffsetAndLength.f1); + valuesOffset += valuesOffsetAndLength.f1; + } + + Message message = new Message(assembledBytes); + message.setServerId(-1); + return message; + } + + private Tuple2 getKeysOffsetAndLength() { + int start = KVS_OFFSET + Integer.BYTES; + int numBytes = Bits.getInt(bytes, KVS_OFFSET) * Long.BYTES; + return Tuple2.of(start, numBytes); + } + + private Tuple2 getValuesOffSetAndLength() { + int start = + Bits.getInt(bytes, KVS_OFFSET) * Long.BYTES + + KVS_OFFSET + + Integer.BYTES + + Integer.BYTES; + return Tuple2.of(start, bytes.length - start); + } + + private int getNumKeys() { + return Bits.getInt(bytes, KVS_OFFSET); + } + + private int getNumValues() { + return Bits.getInt(bytes, KVS_OFFSET + Integer.BYTES + Long.BYTES * getNumKeys()); + } +} diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java new file mode 100644 index 000000000..d8d3c095c --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.common.ps.sarray.SharedDoubleArray; +import org.apache.flink.ml.common.ps.sarray.SharedLongArray; +import org.apache.flink.streaming.api.operators.Output; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.Preconditions; + +import it.unimi.dsi.fastutil.doubles.DoubleArrayList; +import it.unimi.dsi.fastutil.longs.LongArrayList; + +import javax.annotation.Nullable; + +import java.io.IOException; +import java.util.Arrays; +import java.util.function.Function; + +/** + * ServerAgent resides on each worker. It serves as an agent for {@link WorkerOperator} to talk with + * {@link ServerOperator}. + */ +class ServerAgent { + /** Index of the worker that this agent resides on. */ + private final int workerId; + /** Number of servers to talk to. */ + private final int numServers; + /** Hash function to partition keys to different servers. */ + private final Function<Long, Integer> hashFunc; + /** The collector on this worker. */ + private final Output<StreamRecord<byte[]>> output; + + ServerAgent( + int workerId, + int numServers, + Function<Long, Integer> hashFunc, + Output<StreamRecord<byte[]>> output) { + this.workerId = workerId; + this.numServers = numServers; + this.output = output; + this.hashFunc = hashFunc; + } + + /** Pushes key-value arrays to servers. */ + void push(SharedLongArray keys, SharedDoubleArray values, int stageId) { + Tuple2<LongArrayList[], DoubleArrayList[]> slicedRequests = sliceRequest(keys, values); + LongArrayList[] splitKeys = slicedRequests.f0; + DoubleArrayList[] splitValues = slicedRequests.f1; + for (int serverId = 0; serverId < splitKeys.length; serverId++) { + Message message = + new Message( + workerId, + serverId, + stageId, + splitKeys[serverId].toLongArray(), + splitValues[serverId].toDoubleArray()); + output.collect(new StreamRecord<>(message.bytes)); + } + } + + /** Pulls the values from servers with the specified keys.
*/ + void pull(SharedLongArray keys, int stageId) { + Tuple2<LongArrayList[], DoubleArrayList[]> slicedRequests = sliceRequest(keys, null); + LongArrayList[] splitKeys = slicedRequests.f0; + for (int serverId = 0; serverId < splitKeys.length; serverId++) { + Message message = + new Message( + workerId, + serverId, + stageId, + splitKeys[serverId].toLongArray(), + new double[0]); + output.collect(new StreamRecord<>(message.bytes)); + } + } + + /** + * Pushes the values to servers to apply an all-reduce/reduce-scatter operation. + * + *
Note that the values pushed by this function are not going to update the model, but just + * perform a reduce operation. + */ + <V> void reduce(V[] values, TypeSerializer<V> typeSerializer, int stageId) throws IOException { + int shardSize = values.length / numServers + 1; + for (int serverId = 0; serverId < numServers; serverId++) { + int s = Math.min(serverId * shardSize, values.length); + int e = Math.min(s + shardSize, values.length); + V[] segment = Arrays.copyOfRange(values, s, e); + Message message = + new Message(workerId, serverId, stageId, new long[0], segment, typeSerializer); + output.collect(new StreamRecord<>(message.bytes)); + } + } + + /** + * Splits the push/pull request according to the given sorted keys and the corresponding values. + * + * @param keys keys of push/pull request. + * @param values the push values if not null. + * @return the split requests for each server. + */ + private Tuple2<LongArrayList[], DoubleArrayList[]> sliceRequest( + SharedLongArray keys, @Nullable SharedDoubleArray values) { + LongArrayList[] splitKeys = new LongArrayList[numServers]; + DoubleArrayList[] splitValues = new DoubleArrayList[numServers]; + for (int i = 0; i < numServers; i++) { + splitKeys[i] = new LongArrayList(); + splitValues[i] = new DoubleArrayList(); + } + + int numDoublesPerKey = 0; + if (values != null) { + Preconditions.checkState( + values.size() % keys.size() == 0, + "The number of values per key should be the same."); + numDoublesPerKey = values.size() / keys.size(); + } + + long[] keyArray = keys.elements(); + for (int i = 0; i < keys.size(); i++) { + int serverId = hashFunc.apply(keyArray[i]); + splitKeys[serverId].add(keyArray[i]); + if (values != null) { + for (int j = 0; j < numDoublesPerKey; j++) { + splitValues[serverId].add(values.get(i * numDoublesPerKey + j)); + } + } + } + + return Tuple2.of(splitKeys, splitValues); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java new file mode 100644 index 000000000..4edbbf0c4 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java @@ -0,0 +1,534 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.ml.common.ps; + +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.iteration.IterationListener; +import org.apache.flink.ml.common.ps.iterations.AllReduceStage; +import org.apache.flink.ml.common.ps.iterations.IterationStage; +import org.apache.flink.ml.common.ps.iterations.PullStage; +import org.apache.flink.ml.common.ps.iterations.PullStage.Aggregator; +import org.apache.flink.ml.common.ps.iterations.PushStage; +import org.apache.flink.ml.common.ps.iterations.ReduceScatterStage; +import org.apache.flink.ml.common.ps.updater.ModelUpdater; +import org.apache.flink.ml.util.Bits; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.Collector; +import org.apache.flink.util.OutputTag; +import org.apache.flink.util.Preconditions; + +import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap; +import it.unimi.dsi.fastutil.objects.ObjectIterator; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.LinkedBlockingDeque; + +/** + * The server operator maintains the shared parameters. The shared parameters can be modeled as a + * collection of {key:value} pairs. By default, the keys are evenly distributed across servers + * through hash partitioning. For example, if there are two servers and the keys are {1,2,3,4,5,6}, + * then server-0 maintains keys {1,3,5} and server-1 maintains keys {2,4,6}. + * + *
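(An illustrative aside, not part of the patch: a sketch of such a partitioning function. WorkerOperator later in this diff installs key -> (int) Math.abs(key % numServers); note that under that particular hash the even keys of the example above would land on server-0.)

    // Sketch: modulo-based hash partitioning of keys across servers,
    // using java.util.function.Function.
    int numServers = 2;
    Function<Long, Integer> hashFunc = key -> (int) Math.abs(key % numServers);
    // With two servers: keys {2, 4, 6} -> server-0 and keys {1, 3, 5} -> server-1.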
The server receives push/pull/all-reduce/reduce-scatter requests from {@link WorkerOperator} + * and sends the answers back to {@link WorkerOperator}. It works closely with {@link + * ModelUpdater} in the following way: it calls {@link ModelUpdater#update(long[], double[])} to + * apply the merged push requests, calls {@link ModelUpdater#get(long[])} to answer pull requests, + * and delegates checkpointing of the model to {@link ModelUpdater#snapshotState} and + * {@link ModelUpdater#initializeState}. + * + *
Moreover, it accepts all-reduce/reduce-scatter request from workers and returns the reduced + * result to all workers. Note that the input of all-reduce/reduce-scatter operation is not going to + * be used in {@link ModelUpdater}. + * + * @param output format of model data. + */ +public class ServerOperator extends AbstractStreamOperator + implements OneInputStreamOperator, IterationListener { + /** The iterationStage list. */ + private final List stageList; + /** Number of workers to communicate with. */ + private final int numWorkers; + /** The logic to answer push/pull request from workers. */ + private final ModelUpdater modelUpdater; + /** Output tag of model data. */ + private final OutputTag modelOutputTag; + /** Index of the current server task. */ + private transient int serverId; + /** Thread pool to answer push/pull requests, to decouple the network and computation. */ + private transient ExecutorService singleThreadExecutor; + /** The future objects of thread calls in one epoch. */ + private transient List> futuresInEpoch; + /** + * The pending requests that server needs to send out responses (pull, all-reduce, + * reduce-scatter). + */ + private ListState pendingRequests; + /** + * The push request merged by stage id. We use map to store the merged push request since there + * may be consecutive pushes. + */ + private transient TreeMap accPushesByStage; + + private ListState accPushesByStageState; + + public ServerOperator( + List stageList, + int numWorkers, + ModelUpdater modelUpdater, + OutputTag modelOutputTag) { + this.stageList = stageList; + this.numWorkers = numWorkers; + this.modelUpdater = modelUpdater; + this.modelOutputTag = modelOutputTag; + } + + @Override + public void open() throws Exception { + super.open(); + this.serverId = getRuntimeContext().getIndexOfThisSubtask(); + this.singleThreadExecutor = Executors.newSingleThreadExecutor(); + this.futuresInEpoch = new ArrayList<>(); + } + + @Override + public void processElement(StreamRecord element) throws Exception { + Message message = new Message(element.getValue()); + IterationStage stage = stageList.get(message.getStageId() % stageList.size()); + if (stage instanceof PushStage) { + futuresInEpoch.add(singleThreadExecutor.submit(() -> processPushRequest(message))); + } else if (stage instanceof PullStage + || stage instanceof AllReduceStage + || stage instanceof ReduceScatterStage) { + pendingRequests.add(message.bytes); + } else { + throw new IllegalStateException( + "Illegal iteration stage: " + stage.getClass().getSimpleName() + "."); + } + } + + @SuppressWarnings("unchecked") + @Override + public void onEpochWatermarkIncremented( + int epochWatermark, Context context, Collector collector) throws Exception { + // Waits until the pushes are processed. + for (Future future : futuresInEpoch) { + future.get(); + } + futuresInEpoch.clear(); + // Uses the merged pushes to update model. + for (Long2ObjectOpenHashMap currentAccPush : accPushesByStage.values()) { + if (currentAccPush.size() > 0) { + // The push is not empty. 
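+ // Added note (sketch): the block below flattens the accumulated map into two parallel
+ // arrays before calling ModelUpdater#update. For example, with two doubles per key,
+ // {5 -> [a, b], 9 -> [c, d]} becomes assembledKeys = [5, 9], assembledValues = [a, b, c, d].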
+ int numDoublesPerKey; + Object object = currentAccPush.values().iterator().next(); + if (object instanceof Double) { + numDoublesPerKey = 1; + } else { + numDoublesPerKey = ((double[]) object).length; + } + + ObjectIterator> objectIterator = + currentAccPush.long2ObjectEntrySet().fastIterator(); + + long[] assembledKeys = new long[currentAccPush.size()]; + double[] assembledValues = new double[currentAccPush.size() * numDoublesPerKey]; + + int idx = 0; + if (numDoublesPerKey == 1) { + while (objectIterator.hasNext()) { + Map.Entry entry = + (Map.Entry) objectIterator.next(); + assembledKeys[idx] = entry.getKey(); + assembledValues[idx] = entry.getValue(); + idx++; + } + } else { + while (objectIterator.hasNext()) { + Map.Entry entry = + (Map.Entry) objectIterator.next(); + assembledKeys[idx] = entry.getKey(); + System.arraycopy( + entry.getValue(), + 0, + assembledValues, + idx * numDoublesPerKey, + numDoublesPerKey); + idx++; + } + } + currentAccPush.clear(); + modelUpdater.update(assembledKeys, assembledValues); + } + } + + // Deals with the pending requests, which should be one of Pull, AllReduce, ReduceScatter. + Iterator requestIterator = pendingRequests.get().iterator(); + if (requestIterator.hasNext()) { + Message message = new Message(requestIterator.next()); + int stageId = message.getStageId(); + IterationStage stage = stageList.get(stageId % stageList.size()); + requestIterator = pendingRequests.get().iterator(); + if (stage instanceof PullStage) { + final int blockingQueueCapacity = 20; + LinkedBlockingDeque pullsResponse = + new LinkedBlockingDeque<>(blockingQueueCapacity); + for (byte[] bytes : pendingRequests.get()) { + singleThreadExecutor.submit( + () -> processPullRequest(new Message(bytes), pullsResponse)); + } + int numResponsesSent = 0; + while (numResponsesSent < numWorkers) { + Message response = new Message(pullsResponse.take()); + output.collect(new StreamRecord<>(response.bytes)); + numResponsesSent++; + } + } else if (stage instanceof AllReduceStage) { + processAllReduceRequest(requestIterator); + } else if (stage instanceof ReduceScatterStage) { + processReduceScatterRequest(requestIterator); + } else { + throw new IllegalStateException( + "Illegal iteration stage: " + stage.getClass().getSimpleName() + "."); + } + + pendingRequests.clear(); + } + } + + @Override + public void onIterationTerminated(Context context, Collector collector) { + Iterator modelSegments = modelUpdater.getModelSegments(); + while (modelSegments.hasNext()) { + MT modelSegment = modelSegments.next(); + output.collect(modelOutputTag, new StreamRecord<>(modelSegment)); + } + } + + @SuppressWarnings("unchecked") + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + pendingRequests = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "pendingRequests", + PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO)); + + modelUpdater.initializeState(context); + + accPushesByStageState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "accPushesByStageState", + PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO)); + + // Recovers accPushesByStage from a byte[] stream. + Iterator accPushesInBytes = accPushesByStageState.get().iterator(); + accPushesByStage = new TreeMap<>(); + + if (accPushesInBytes.hasNext()) { + // 4 bytes for number of stages. 
+ byte[] meta = accPushesInBytes.next(); + int offset = 0; + int numberOfStages = Bits.getInt(meta, offset); + for (int i = 0; i < numberOfStages; i++) { + byte[] oneStageMeta = accPushesInBytes.next(); + offset = 0; + int stageId = Bits.getInt(oneStageMeta, offset); + offset += Integer.BYTES; + int sizeOfLong2ObjectMap = Bits.getInt(oneStageMeta, offset); + offset += Integer.BYTES; + int arrayLengthPerObject = Bits.getInt(oneStageMeta, offset); + Long2ObjectOpenHashMap pushes; + if (arrayLengthPerObject == 0) { + pushes = new Long2ObjectOpenHashMap(sizeOfLong2ObjectMap); + } else { + pushes = new Long2ObjectOpenHashMap(sizeOfLong2ObjectMap); + } + accPushesByStage.put(stageId, pushes); + for (int entryId = 0; entryId < sizeOfLong2ObjectMap; entryId++) { + byte[] kvInBytes = accPushesInBytes.next(); + long key = Bits.getLong(kvInBytes, 0); + if (arrayLengthPerObject == 0) { + Double value = Bits.getDouble(kvInBytes, Long.BYTES); + pushes.put(key, value); + } else { + double[] value = Bits.getDoubleArray(kvInBytes, Long.BYTES); + pushes.put(key, value); + } + } + } + } + } + + @SuppressWarnings("unchecked") + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + // Waits until the futures to finish. + for (Future future : futuresInEpoch) { + future.get(); + } + futuresInEpoch.clear(); + modelUpdater.snapshotState(context); + + accPushesByStageState.clear(); + // Writes accPushesByStage to state in the following format: + // numberOfStagesInInt, + // stageIdInInt, sizeOfLong2ObjectMapInInt, arrayLengthPerObject, key-value-long-obj... + // stageIdInInt, sizeOfLong2ObjectMapInInt, arrayLengthPerObject, key-value-long-obj... + if (accPushesByStage.size() > 0) { + int numberOfStages = accPushesByStage.size(); + byte[] meta = new byte[Integer.BYTES]; + Bits.putInt(meta, 0, numberOfStages); + accPushesByStageState.add(meta); + + for (Map.Entry entry : accPushesByStage.entrySet()) { + byte[] oneStageMeta = new byte[Integer.BYTES * 3]; + int offset = 0; + int stageId = entry.getKey(); + Bits.putInt(oneStageMeta, offset, stageId); + offset += Integer.BYTES; + int sizeOfLong2ObjectMap = entry.getValue().size(); + Bits.putInt(oneStageMeta, offset, sizeOfLong2ObjectMap); + offset += Integer.BYTES; + // 0 stands for Double, a non-zero value represents the array length. 
+ int arrayLengthPerObject = 0; + + ObjectIterator> objectIterator = + entry.getValue().long2ObjectEntrySet().fastIterator(); + + if (objectIterator.hasNext()) { + Map.Entry oneEntry = objectIterator.next(); + if (oneEntry.getValue() instanceof double[]) { + arrayLengthPerObject = ((double[]) (oneEntry.getValue())).length; + } + Bits.putInt(oneStageMeta, offset, arrayLengthPerObject); + accPushesByStageState.add(oneStageMeta); + + accPushesByStageState.add(kvToBytes(oneEntry)); + while (objectIterator.hasNext()) { + accPushesByStageState.add(kvToBytes(objectIterator.next())); + } + } + } + } + } + + private static byte[] kvToBytes(Map.Entry obj) { + byte[] bytes; + if (obj.getValue() instanceof double[]) { + double[] value = (double[]) obj.getValue(); + bytes = new byte[Long.BYTES + Bits.getDoubleArraySizeInBytes(value)]; + Bits.putLong(bytes, 0, obj.getKey()); + Bits.putDoubleArray(value, bytes, Long.BYTES); + } else { + bytes = new byte[Long.BYTES + Double.BYTES]; + Bits.putLong(bytes, 0, obj.getKey()); + Bits.putDouble(bytes, Long.BYTES, (Double) obj.getValue()); + } + return bytes; + } + + @SuppressWarnings("unchecked") + private Object processPushRequest(Message message) throws Exception { + long[] keys = message.getKeys(); + int stageId = message.getStageId(); + double[] values = message.getValuesInDoubleArray(); + + accPushesByStage.putIfAbsent(stageId, new Long2ObjectOpenHashMap()); + Long2ObjectOpenHashMap currentAccKvs = accPushesByStage.get(stageId); + + if (keys.length != 0) { + ReduceFunction reduceFunc = + ((PushStage) stageList.get(stageId % stageList.size())).reduceFunc; + if (values.length == keys.length) { + for (int i = 0; i < keys.length; i++) { + if (currentAccKvs.containsKey(keys[i])) { + double currentVal = (Double) currentAccKvs.get(keys[i]); + currentAccKvs.put(keys[i], reduceFunc.reduce(currentVal, values[i])); + } else { + currentAccKvs.put(keys[i], (Double) values[i]); + } + } + } else { + int numDoublesPerKey = values.length / keys.length; + for (int i = 0; i < keys.length; i++) { + if (currentAccKvs.containsKey(keys[i])) { + double[] currentVal = (double[]) currentAccKvs.get(keys[i]); + for (int j = 0; j < numDoublesPerKey; j++) { + currentVal[j] = + reduceFunc.reduce( + currentVal[j], values[i * numDoublesPerKey + j]); + } + } else { + currentAccKvs.put( + keys[i], + Arrays.copyOfRange( + values, + i * numDoublesPerKey, + i * numDoublesPerKey + numDoublesPerKey)); + } + } + } + } + return new Object(); + } + + private void processPullRequest(Message message, LinkedBlockingDeque pullsResponse) { + int workerId = message.getWorkerId(); + long[] keys = message.getKeys(); + Message response; + + if (keys.length == 0) { + // No request on this server. + response = + new Message( + workerId, serverId, message.getStageId(), new long[0], new double[0]); + } else { + double[] pulledValues = modelUpdater.get(keys); + Preconditions.checkState(pulledValues.length % keys.length == 0); + int numDoublesPerKey = pulledValues.length / keys.length; + + double[] aggregatedPullValues = null; + Aggregator aggregator = + ((PullStage) (stageList.get(message.getStageId() % stageList.size()))) + .aggregator; + if (aggregator != null) { + // Processes the pulled values if the aggregator is not null. 
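+ // Added note (sketch): when an aggregator is present, the per-key blocks pulled from the
+ // model are folded into one aggregate array (e.g., an element-wise sum over all requested
+ // keys), so the response carries a single block rather than one block per key.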
+ double[] tmp = new double[numDoublesPerKey]; + for (int i = 0; i < keys.length; i++) { + System.arraycopy(pulledValues, i * numDoublesPerKey, tmp, 0, numDoublesPerKey); + aggregatedPullValues = aggregator.add(tmp, aggregatedPullValues); + } + } else { + aggregatedPullValues = pulledValues; + } + + response = + new Message( + workerId, + serverId, + message.getStageId(), + new long[0], + aggregatedPullValues); + } + while (!pullsResponse.offer(response.bytes)) { + try { + Thread.sleep(10); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + } + + @SuppressWarnings("unchecked") + private void processAllReduceRequest(Iterator requests) throws Exception { + byte[] request = requests.next(); + Message message = new Message(request); + int stageId = message.getStageId(); + AllReduceStage stage = (AllReduceStage) stageList.get(stageId % stageList.size()); + V[] reducedResult = message.getValues(stage.typeSerializer); + ReduceFunction reduceFunction = stage.reducer; + + while (requests.hasNext()) { + message = new Message(requests.next()); + reducedResult = + reduceFunction.reduce(message.getValues(stage.typeSerializer), reducedResult); + } + message = + new Message( + -1, serverId, stageId, new long[0], reducedResult, stage.typeSerializer); + + for (int workerId = 0; workerId < numWorkers; workerId++) { + message.setWorkerId(workerId); + output.collect(new StreamRecord<>(message.bytes)); + } + } + + @SuppressWarnings("unchecked") + private void processReduceScatterRequest(Iterator requests) throws Exception { + byte[] request = requests.next(); + Message message = new Message(request); + int stageId = message.getStageId(); + ReduceScatterStage stage = + (ReduceScatterStage) stageList.get(stageId % stageList.size()); + V[] reducedResult = message.getValues(stage.typeSerializer); + ReduceFunction reduceFunction = stage.reducer; + + while (requests.hasNext()) { + message = new Message(requests.next()); + reducedResult = + reduceFunction.reduce(message.getValues(stage.typeSerializer), reducedResult); + } + + int[] recvCounts = stage.recvCounts; + int totalCnt = Arrays.stream(recvCounts).sum(); + int shardSize = totalCnt / getRuntimeContext().getNumberOfParallelSubtasks() + 1; + int sliceStart = Math.min(serverId * shardSize, totalCnt); + int sliceEnd = Math.min(sliceStart + shardSize, totalCnt); + + int s = 0; + int e; + for (int workerId = 0; workerId < numWorkers; workerId++) { + e = recvCounts[workerId] + s; + + int intersectionStart = Math.max(s, sliceStart); + int interSectionEnd = Math.min(e, sliceEnd); + int copyStart = 0, copyEnd = 0; + if (interSectionEnd > intersectionStart) { + copyStart = intersectionStart - sliceStart; + copyEnd = interSectionEnd - sliceStart; + } + message = + new Message( + workerId, + serverId, + stageId, + new long[0], + Arrays.copyOfRange(reducedResult, copyStart, copyEnd), + stage.typeSerializer); + output.collect(new StreamRecord<>(message.bytes)); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java new file mode 100644 index 000000000..cf935a6bc --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java @@ -0,0 +1,420 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.iteration.IterationListener; +import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.common.ps.iterations.AllReduceStage; +import org.apache.flink.ml.common.ps.iterations.IterationStage; +import org.apache.flink.ml.common.ps.iterations.IterationStageList; +import org.apache.flink.ml.common.ps.iterations.MLSession; +import org.apache.flink.ml.common.ps.iterations.ProcessStage; +import org.apache.flink.ml.common.ps.iterations.PullStage; +import org.apache.flink.ml.common.ps.iterations.PushStage; +import org.apache.flink.ml.common.ps.iterations.ReduceScatterStage; +import org.apache.flink.ml.common.ps.sarray.SharedDoubleArray; +import org.apache.flink.ml.common.ps.sarray.SharedLongArray; +import org.apache.flink.ml.common.ps.utils.ProxySideOutput; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.runtime.util.ResettableIterator; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.Collector; + +import java.util.Iterator; +import java.util.function.Function; + +/** + * The worker operator that executes the iterative machine learning process following {@link + * IterationStageList}. + * + *
In detail, the worker operator is responsible for the following: it caches the training data, + * executes the stages in the {@link IterationStageList} one by one, sends + * push/pull/all-reduce/reduce-scatter requests to servers through {@link ServerAgent}, and + * applies the responses received from servers. + *
+ */ +public class WorkerOperator extends AbstractStreamOperator + implements TwoInputStreamOperator, IterationListener { + /** The user defined iteration logic. */ + private final IterationStageList iterationStages; + /** + * Iteration id in terms of {@link IterationStageList}. When we finished processing all stages + * in stageList, the iteration id increments by one. + */ + private int iterationId; + + /** The id of the stages to execute in iterationStages. */ + private int nextStageToExecute = 0; + + private ListState nextStageToExecuteState; + + /** The agent for each worker to talk with servers. */ + private transient ServerAgent serverAgent; + /** Number of servers that this worker needs to talk to. */ + private final int numServers; + /** The hash function to distribute keys to servers. */ + private transient Function hashFunc; + + /** The cached training data. */ + private ListStateWithCache
trainDataState; + + /** + * Number of segments received from servers for the current request. For each request, a worker + * should receive one segment from each server. + */ + private int numSegmentsReceived = 0; + + private ListState numSegmentsReceivedState; + + /** + * The memory store for pull answer. For a pull request, each received segment will be filled to + * the user provided buffer. + */ + private double[] pulledResult; + + private ListState pulledResultState; + + /** The state store for the all-reduce/reduce-scatter results. */ + private ListState reducedResult; + + public WorkerOperator(IterationStageList iterationStages, int numServers) { + this.iterationStages = iterationStages; + this.numServers = numServers; + } + + @Override + public void open() { + int workerId = getRuntimeContext().getIndexOfThisSubtask(); + int numWorkers = getRuntimeContext().getNumberOfParallelSubtasks(); + this.hashFunc = key -> (int) (Math.abs(key % numServers)); + this.serverAgent = new ServerAgent(workerId, numServers, hashFunc, output); + iterationStages.session.setWorldInfo(workerId, numWorkers); + iterationStages.session.setOutput(new ProxySideOutput(output)); + } + + @Override + public void onEpochWatermarkIncremented( + int epochWatermark, Context context, Collector collector) throws Exception { + if (epochWatermark == 0) { + iterationStages.session.setInputData(new ResettableTrainDataIterator<>(trainDataState)); + nextStageToExecute = processIterationStages(nextStageToExecute, iterationStages); + } + } + + @Override + public void onIterationTerminated(Context context, Collector collector) { + trainDataState.clear(); + } + + @Override + public void processElement1(StreamRecord
streamRecord) throws Exception { + trainDataState.add(streamRecord.getValue()); + } + + @Override + public void processElement2(StreamRecord streamRecord) throws Exception { + Message message = new Message(streamRecord.getValue()); + IterationStage stage = + iterationStages.stageList.get( + nextStageToExecute % iterationStages.stageList.size()); + + boolean proceedToNextStage; + if (stage instanceof PullStage) { + proceedToNextStage = onPullResponse(message, (PullStage) stage); + } else if (stage instanceof AllReduceStage) { + proceedToNextStage = onAllReduceResponse(message, (AllReduceStage) stage); + } else if (stage instanceof ReduceScatterStage) { + proceedToNextStage = onReduceScatterResponse(message, (ReduceScatterStage) stage); + } else { + throw new IllegalStateException( + "Illegal stage type: %s" + stage.getClass().getSimpleName() + "."); + } + + if (proceedToNextStage) { + nextStageToExecute++; + nextStageToExecute = processIterationStages(nextStageToExecute, iterationStages); + } + } + + private boolean onPullResponse(Message message, PullStage pullStage) { + numSegmentsReceived++; + double[] segment = message.getValuesInDoubleArray(); + if (segment.length != 0) { + if (pullStage.aggregator != null) { + if (pulledResult.length == 0) { + pulledResult = segment; + } else { + pulledResult = pullStage.aggregator.merge(segment, pulledResult); + } + } else { + SharedLongArray keys = pullStage.keys.get(); + SharedDoubleArray values = pullStage.values.get(); + int serverId = message.getServerId(); + long[] keysArray = keys.elements(); + + if (pulledResult.length == 0) { + pulledResult = values.elements(); + } + + int numDoublesPerKey = values.size() / keys.size(); + // Copy the response from one server to the result array. + int idxInLocalPull = 0; + for (int i = 0; i < keys.size(); i++) { + if (hashFunc.apply(keysArray[i]) == serverId) { + System.arraycopy( + segment, + idxInLocalPull * numDoublesPerKey, + pulledResult, + i * numDoublesPerKey, + numDoublesPerKey); + idxInLocalPull++; + } + } + } + } + + if (numSegmentsReceived == numServers) { + SharedDoubleArray pullPlaceHolder = pullStage.values.get(); + System.arraycopy( + pulledResult, 0, pullPlaceHolder.elements(), 0, pullPlaceHolder.size()); + + pulledResult = new double[0]; + numSegmentsReceived = 0; + return true; + } + return false; + } + + private boolean onAllReduceResponse(Message message, AllReduceStage allReduceStage) + throws Exception { + numSegmentsReceived++; + reducedResult.add(message.bytes); + + if (numSegmentsReceived == numServers) { + Message assembled = Message.assembleMessages(reducedResult.get().iterator()); + V[] reduceResult = assembled.getValues(allReduceStage.typeSerializer); + System.arraycopy(reduceResult, 0, allReduceStage.recvBuf.get(), 0, reduceResult.length); + reducedResult.clear(); + numSegmentsReceived = 0; + return true; + } + return false; + } + + private boolean onReduceScatterResponse( + Message message, ReduceScatterStage reduceScatterStage) throws Exception { + numSegmentsReceived++; + reducedResult.add(message.bytes); + + if (numSegmentsReceived == numServers) { + Message assembled = Message.assembleMessages(reducedResult.get().iterator()); + V[] reduceResult = assembled.getValues(reduceScatterStage.typeSerializer); + System.arraycopy( + reduceResult, 0, reduceScatterStage.recvBuf.get(), 0, reduceResult.length); + reducedResult.clear(); + numSegmentsReceived = 0; + return true; + } + return false; + } + + @Override + public void initializeState(StateInitializationContext context) throws 
Exception { + super.initializeState(context); + trainDataState = + new ListStateWithCache<>( + (getOperatorConfig().getTypeSerializerIn(0, getClass().getClassLoader())), + getContainingTask(), + getRuntimeContext(), + context, + config.getOperatorID()); + + numSegmentsReceivedState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>("numSegmentsReceivedState", Types.INT)); + numSegmentsReceived = + OperatorStateUtils.getUniqueElement( + numSegmentsReceivedState, "numSegmentsReceivedState") + .orElse(0); + + nextStageToExecuteState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>("nextStageToExecuteState", Types.INT)); + + nextStageToExecute = + OperatorStateUtils.getUniqueElement( + nextStageToExecuteState, "nextStageToExecuteState") + .orElse(0); + + iterationStages.session.initializeState(context); + + pulledResultState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "pulledResultState", + PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO)); + pulledResult = + OperatorStateUtils.getUniqueElement(pulledResultState, "pulledResultState") + .orElse(new double[0]); + + reducedResult = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "reducedResult", + PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO)); + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + + numSegmentsReceivedState.clear(); + numSegmentsReceivedState.add(numSegmentsReceived); + + nextStageToExecuteState.clear(); + nextStageToExecuteState.add(nextStageToExecute); + + trainDataState.snapshotState(context); + iterationStages.session.snapshotState(context); + + pulledResultState.clear(); + pulledResultState.add(pulledResult); + } + + /** + * Processes the stages described in the given iterationStages from the given nextStage id. This + * function processes the stages until it meets a {@link PullStage}, {@link AllReduceStage} or + * {@link ReduceScatterStage}. + * + * @param nextStageToExecute id of the next stage to execute in the given iteration stages. + * @param iterationStages iteration stages used to describe the training logic. + * @return the id of the next stage to execute. + */ + @SuppressWarnings("unchecked") + private int processIterationStages( + int nextStageToExecute, IterationStageList iterationStages) throws Exception { + while (true) { + if (nextStageToExecute > 0 + && nextStageToExecute % iterationStages.stageList.size() == 0) { + iterationId = nextStageToExecute / iterationStages.stageList.size(); + iterationStages.session.setIterationId(iterationId); + if (iterationStages.shouldTerminate.apply(iterationStages.session)) { + return -1; + } + } + IterationStage stage = + iterationStages.stageList.get( + nextStageToExecute % iterationStages.stageList.size()); + + // We are not incrementing nextStageToExecute for + // PullStage/AllReduceStage/ReduceScatterStage, since we + // need to wait for response from servers. 
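+ // Added note (sketch): stages cycle through the list. With a 3-stage list, for example,
+ // nextStageToExecute = 4 executes stageList.get(4 % 3), i.e. stage 1, during
+ // iterationId = 4 / 3 = 1.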
+ if (stage instanceof PullStage) { + PullStage pullStage = ((PullStage) stage); + serverAgent.pull(pullStage.keys.get(), nextStageToExecute); + return nextStageToExecute; + + } else if (stage instanceof AllReduceStage) { + AllReduceStage allReduceStage = (AllReduceStage) stage; + if (iterationId % allReduceStage.executionInterval == 0) { + serverAgent.reduce( + allReduceStage.sendBuf.get(), + allReduceStage.typeSerializer, + nextStageToExecute); + return nextStageToExecute; + } else { + nextStageToExecute++; + } + + } else if (stage instanceof ReduceScatterStage) { + ReduceScatterStage reduceScatterStage = (ReduceScatterStage) stage; + if (iterationId % reduceScatterStage.executionInterval == 0) { + serverAgent.reduce( + reduceScatterStage.sendBuf.get(), + reduceScatterStage.typeSerializer, + nextStageToExecute); + return nextStageToExecute; + } else { + nextStageToExecute++; + } + } else if (stage instanceof PushStage) { + PushStage pushStage = (PushStage) stage; + serverAgent.push(pushStage.keys.get(), pushStage.values.get(), nextStageToExecute); + nextStageToExecute++; + } else if (stage instanceof ProcessStage) { + ((ProcessStage) stage).process(iterationStages.session); + nextStageToExecute++; + } else { + throw new IllegalStateException( + "Illegal type of IterationStage: + " + + stage.getClass().getSimpleName() + + "."); + } + } + } + + /** A resettable iterator for {@link ListStateWithCache}. */ + private static class ResettableTrainDataIterator implements ResettableIterator { + private final ListStateWithCache data; + private Iterator dataIterator; + + public ResettableTrainDataIterator(ListStateWithCache data) throws Exception { + this.data = data; + this.dataIterator = data.get().iterator(); + } + + @Override + public void reset() { + try { + this.dataIterator = data.get().iterator(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public boolean hasNext() { + return dataIterator.hasNext(); + } + + @Override + public T next() { + return dataIterator.next(); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/AllReduceStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/AllReduceStage.java new file mode 100644 index 000000000..c0d3cdf0b --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/AllReduceStage.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.ps.iterations; + +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.util.function.SerializableSupplier; + +import java.util.function.Supplier; + +/** + * This iteration stage is designed to perform an all-reduce operation on the specified array in a + * distributed setting. + * + *
Users can specify how often this operation is conducted by setting the value of the + * "executionInterval" parameter, which determines the frequency of the all-reduce stage. For + * example, if the value of executionInterval is set to 5, the all-reduce stage will be executed + * every 5 iterations. + */ +public final class AllReduceStage implements IterationStage { + public final Supplier sendBuf; + public final Supplier recvBuf; + public final ReduceFunction reducer; + public final TypeSerializer typeSerializer; + public final int executionInterval; + + public AllReduceStage( + SerializableSupplier sendBuf, + SerializableSupplier recvBuf, + ReduceFunction reducer, + TypeSerializer typeSerializer, + int executionInterval) { + this.sendBuf = sendBuf; + this.recvBuf = recvBuf; + this.reducer = reducer; + this.typeSerializer = typeSerializer; + this.executionInterval = executionInterval; + } + + public AllReduceStage( + SerializableSupplier sendBuf, + SerializableSupplier recvBuf, + ReduceFunction reducer, + TypeSerializer typeSerializer) { + this(sendBuf, recvBuf, reducer, typeSerializer, 1); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/IterationStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/IterationStage.java new file mode 100644 index 000000000..d0f23a774 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/IterationStage.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.iterations; + +import java.io.Serializable; + +/** + * Iterative machine learning training usually incurs local computation step (e.g., computing + * gradients) and global communication step (e.g., all-reduce and parameter servers to aggregate the + * updates from workers). + * + *
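(An illustrative aside, not part of the patch: a hedged usage sketch of the AllReduceStage defined above. The buffer size and the element-wise-sum reducer are made up for illustration; DoubleSerializer is Flink's built-in TypeSerializer for Double.)

    // Sketch: an all-reduce over a Double[] buffer, executed only on every 5th iteration.
    Double[] sendBuf = new Double[16];
    Double[] recvBuf = new Double[16];
    AllReduceStage<Double> allReduce =
            new AllReduceStage<>(
                    () -> sendBuf,
                    () -> recvBuf,
                    (a, b) -> {
                        for (int i = 0; i < b.length; i++) {
                            b[i] = a[i] + b[i]; // element-wise sum
                        }
                        return b;
                    },
                    DoubleSerializer.INSTANCE,
                    5);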
To describe the above iteration training process, we model the training process as a sequence + * of iteration stages. An iteration stage could be either local computation or global + * communication. + */ +public interface IterationStage extends Serializable {} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/IterationStageList.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/IterationStageList.java new file mode 100644 index 000000000..e1cd23b7b --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/IterationStageList.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.iterations; + +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.function.SerializableFunction; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Function; + +/** + * A list of iteration stages to express the logic of an iterative machine learning process. + * + *
Note that there should be at least one stage (e.g., {@link PullStage}, {@link AllReduceStage} + * or {@link ReduceScatterStage}) that needs to wait for responses from servers. + */ +public class IterationStageList implements Serializable { + /** The session on each worker. */ + public final T session; + /** The termination criteria. */ + public Function shouldTerminate; + /** The stage list that describes the iterative process. */ + public List stageList; + + public IterationStageList(T session) { + this.stageList = new ArrayList<>(); + this.session = session; + } + + /** Sets the criteria of termination. */ + public IterationStageList setTerminationCriteria( + SerializableFunction shouldTerminate) { + boolean waitServer = false; + for (IterationStage stage : stageList) { + if (stage instanceof PullStage + || stage instanceof AllReduceStage + || stage instanceof ReduceScatterStage) { + waitServer = true; + break; + } + } + Preconditions.checkState( + waitServer, + String.format( + "There should be at least one stage that needs to receive response from servers (i.e., %s, %s, %s).\n", + PullStage.class.getSimpleName(), + AllReduceStage.class.getSimpleName(), + ReduceScatterStage.class.getSimpleName())); + this.shouldTerminate = shouldTerminate; + return this; + } + + /** Adds an iteration stage into the stage list. */ + public IterationStageList addStage(IterationStage stage) { + stageList.add(stage); + return this; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/MLSession.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/MLSession.java new file mode 100644 index 000000000..799a5cb6a --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/MLSession.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.iterations; + +import org.apache.flink.ml.common.ps.WorkerOperator; +import org.apache.flink.ml.common.ps.utils.ProxySideOutput; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.runtime.util.ResettableIterator; +import org.apache.flink.util.OutputTag; + +import java.io.Serializable; +import java.util.List; + +/** + * Stores the session information that is alive during the training process on {@link + * WorkerOperator}. Note that the session information will be updated by each {@link + * IterationStage}. + * + *
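(An illustrative aside, not part of the patch: a hedged sketch of composing a stage list. ComputeLocalSums is a hypothetical ProcessStage; allReduce is the stage sketched earlier and satisfies the wait-for-server precondition of setTerminationCriteria noted above.)

    // Sketch: one local-compute stage plus one server round-trip per iteration,
    // terminating after 100 iterations.
    IterationStageList<MLSessionImpl<double[]>> stages =
            new IterationStageList<>(new MLSessionImpl<double[]>());
    stages.addStage(new ComputeLocalSums())
            .addStage(allReduce)
            .setTerminationCriteria(session -> session.iterationId >= 100);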
+ * <p>Subclasses should take care of snapshotting any object stored in the {@link MLSession} if
+ * the object's write is followed by a {@link PullStage}/{@link AllReduceStage}/{@link
+ * ReduceScatterStage} and the object is later read again by other stages.
+ */
+public interface MLSession extends Serializable {
+    /** Sets the current iteration ID. */
+    default void setIterationId(int iterationId) {}
+
+    /** Sets the worker ID and the total number of workers. */
+    default void setWorldInfo(int workerId, int numWorkers) {}
+
+    /** Sets the training data. */
+    default void setInputData(ResettableIterator<?> inputData) {}
+
+    /** Sets the collector through which users can output records to downstream tasks. */
+    default void setOutput(ProxySideOutput collector) {}
+
+    /**
+     * Retrieves the output tags from the {@link MLSession}, which can be used to output records
+     * from the worker operator.
+     */
+    default List<OutputTag<?>> getOutputTags() {
+        return null;
+    }
+
+    /** Recovers from state. */
+    default void initializeState(StateInitializationContext context) throws Exception {}
+
+    /** Snapshots to state. */
+    default void snapshotState(StateSnapshotContext context) throws Exception {}
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/MLSessionImpl.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/MLSessionImpl.java
new file mode 100644
index 000000000..317f7cb66
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/MLSessionImpl.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.iterations;
+
+import org.apache.flink.runtime.util.ResettableIterator;
+import org.apache.flink.util.OutputTag;
+
+import java.util.List;
+
+/**
+ * The default implementation of {@link MLSession}.
+ *
+ * @param <DT> Data type of input data.
+ */
+public class MLSessionImpl<DT> implements MLSession {
+    /** Current iteration ID. */
+    public int iterationId;
+    /** Index of this worker. */
+    public int workerId;
+    /** Number of workers in total for this distributed ML job. */
+    public int numWorkers;
+    /** The input data. */
+    public ResettableIterator<DT> inputData;
+
+    public List<OutputTag<?>> outputTags;
+
+    /** Constructs an instance with side outputs. */
+    public MLSessionImpl(List<OutputTag<?>> outputTags) {
+        this.outputTags = outputTags;
+    }
+
+    /** Constructs an instance without side outputs. */
+    public MLSessionImpl() {
+        this(null);
+    }
+
+    @Override
+    public List<OutputTag<?>> getOutputTags() {
+        return outputTags;
+    }
+
+    @Override
+    public void setIterationId(int iterationId) {
+        this.iterationId = iterationId;
+    }
+
+    @Override
+    public void setWorldInfo(int workerId, int numWorkers) {
+        this.workerId = workerId;
+        this.numWorkers = numWorkers;
+    }
+
+    @Override
+    public void setInputData(ResettableIterator<?> inputData) {
+        this.inputData = (ResettableIterator<DT>) inputData;
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/ProcessStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/ProcessStage.java
new file mode 100644
index 000000000..8c8810699
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/ProcessStage.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.iterations;
+
+/**
+ * A local computation stage of the training process. The input and output of a {@link
+ * ProcessStage} can be accessed via {@link MLSession}.
+ *
+ * @param <T> Type of the ML session.
+ */
+public abstract class ProcessStage<T extends MLSession> implements IterationStage {
+    /**
+     * Performs the local computation logic using the information from the session. An example
+     * stage could be computing gradients.
+     */
+    public abstract void process(T session) throws Exception;
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/PullStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/PullStage.java
new file mode 100644
index 000000000..8f86c5e5c
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/PullStage.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.iterations;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.ps.sarray.SharedDoubleArray;
+import org.apache.flink.ml.common.ps.sarray.SharedLongArray;
+import org.apache.flink.util.function.SerializableSupplier;
+
+import java.io.Serializable;
+import java.util.function.Supplier;
+
+/**
+ * An iteration stage that aggregates data from servers using the keys given by {@code
+ * PullStage#keys#get()} and stores the aggregated values in {@code PullStage#values#get()}.
+ *
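+ * <p>A minimal sketch of constructing a pull, assuming the session exposes reusable {@code
+ * pullKeys} and {@code pullValues} arrays (the field names are illustrative):
+ *
+ * <pre>{@code
+ * PullStage pull =
+ *         new PullStage(
+ *                 () -> session.pullKeys,     // SharedLongArray: keys to pull
+ *                 () -> session.pullValues);  // SharedDoubleArray: receives the pulled values
+ * }</pre>
+ *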

+ * <p>If the aggregator is null, we simply pull the values specified by the keys.
+ */
+public final class PullStage implements IterationStage {
+    public final Supplier<SharedLongArray> keys;
+    public final Supplier<SharedDoubleArray> values;
+    public final Aggregator<SharedDoubleArray, SharedDoubleArray> aggregator;
+
+    public PullStage(
+            SerializableSupplier<SharedLongArray> keys,
+            SerializableSupplier<SharedDoubleArray> values) {
+        this(keys, values, null);
+    }
+
+    public PullStage(
+            SerializableSupplier<SharedLongArray> keys,
+            SerializableSupplier<SharedDoubleArray> values,
+            Aggregator<SharedDoubleArray, SharedDoubleArray> aggregator) {
+        this.keys = keys;
+        this.values = values;
+        this.aggregator = aggregator;
+    }
+
+    /**
+     * An Aggregator is used to aggregate a set of input elements into a single accumulator.
+     *
+     * @param <IN> The type of the input elements.
+     * @param <ACC> The type of the accumulator.
+     */
+    @Internal
+    public interface Aggregator<IN, ACC> extends Serializable {
+
+        /**
+         * Adds a new input element to the given accumulator and returns the updated accumulator.
+         *
+         * @param in The input element to add.
+         * @param acc The accumulator to update.
+         * @return The updated accumulator.
+         */
+        ACC add(IN in, ACC acc);
+
+        /**
+         * Merges two accumulators and returns the result.
+         *
+         * @param acc1 The first accumulator to merge.
+         * @param acc2 The second accumulator to merge.
+         * @return The merged accumulator.
+         */
+        ACC merge(ACC acc1, ACC acc2);
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/PushStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/PushStage.java
new file mode 100644
index 000000000..3abec6190
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/PushStage.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.iterations;
+
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.ml.common.ps.sarray.SharedDoubleArray;
+import org.apache.flink.ml.common.ps.sarray.SharedLongArray;
+import org.apache.flink.util.function.SerializableSupplier;
+
+import java.util.function.Supplier;
+
+/**
+ * An iteration stage that pushes (keys, values) to servers. Users can specify how values from
+ * different workers are merged via {@code PushStage#reduceFunc}. By default, the values from
+ * different workers are summed.
+ *
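+ * <p>A minimal sketch, passing the default summation as an explicit reduce function (the session
+ * field names are illustrative):
+ *
+ * <pre>{@code
+ * PushStage push =
+ *         new PushStage(
+ *                 () -> session.pushKeys,   // SharedLongArray
+ *                 () -> session.gradients,  // SharedDoubleArray
+ *                 Double::sum);             // how values for the same key are merged
+ * }</pre>
+ *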

+ * <p>Note that the length of the values array must be divisible by the length of the keys array.
+ * Additionally, each value corresponding to a given key must have the same length. For instance,
+ * given the keys {1, 4} and the values {1, 2, 3, 4, 5, 6}, the value for key 1 would be {1, 2, 3},
+ * and the value for key 4 would be {4, 5, 6}.
+ */
+public class PushStage implements IterationStage {
+    public final Supplier<SharedLongArray> keys;
+    public final Supplier<SharedDoubleArray> values;
+
+    /**
+     * The function used to reduce the pushes from all workers. For gradient descent based methods,
+     * this is typically a summation.
+     */
+    public final ReduceFunction<Double> reduceFunc;
+
+    public PushStage(
+            SerializableSupplier<SharedLongArray> keys,
+            SerializableSupplier<SharedDoubleArray> values) {
+        this(keys, values, Double::sum);
+    }
+
+    public PushStage(
+            SerializableSupplier<SharedLongArray> keys,
+            SerializableSupplier<SharedDoubleArray> values,
+            ReduceFunction<Double> reduceFunc) {
+        this.keys = keys;
+        this.values = values;
+        this.reduceFunc = reduceFunc;
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/ReduceScatterStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/ReduceScatterStage.java
new file mode 100644
index 000000000..c660de285
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/iterations/ReduceScatterStage.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.iterations;
+
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.function.SerializableSupplier;
+
+import java.util.function.Supplier;
+
+/**
+ * This iteration stage is designed to perform a reduce-scatter operation on the specified array in
+ * a distributed setting.
+ *
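+ * <p>As a small illustrative example: with two workers whose send buffers are {1, 2} and {3, 4},
+ * and with {@code recvCounts} of {1, 1}, the element-wise reduction {4, 6} is computed; worker 0
+ * then receives {4} and worker 1 receives {6}.
+ *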

+ * <p>Users can specify how often this operation is conducted by setting the value of the
+ * "executionInterval" parameter, which determines the frequency of the reduce-scatter stage. For
+ * example, if the value of executionInterval is set to 5, the reduce-scatter stage will be
+ * executed every 5 iterations.
+ */
+public final class ReduceScatterStage<V> implements IterationStage {
+    public final Supplier<V[]> sendBuf;
+    public final Supplier<V[]> recvBuf;
+    /** The number of elements each worker receives. */
+    public int[] recvCounts;
+
+    public final ReduceFunction<V> reducer;
+    public final TypeSerializer<V> typeSerializer;
+
+    public final int executionInterval;
+
+    public ReduceScatterStage(
+            SerializableSupplier<V[]> sendBuf,
+            SerializableSupplier<V[]> recvBuf,
+            int[] recvCounts,
+            ReduceFunction<V> reducer,
+            TypeSerializer<V> typeSerializer,
+            int executionInterval) {
+        this.sendBuf = sendBuf;
+        this.recvBuf = recvBuf;
+        this.recvCounts = Preconditions.checkNotNull(recvCounts);
+        this.reducer = reducer;
+        this.typeSerializer = typeSerializer;
+        this.executionInterval = executionInterval;
+    }
+
+    public ReduceScatterStage(
+            SerializableSupplier<V[]> sendBuf,
+            SerializableSupplier<V[]> recvBuf,
+            int[] recvCounts,
+            ReduceFunction<V> reducer,
+            TypeSerializer<V> typeSerializer) {
+        this(sendBuf, recvBuf, recvCounts, reducer, typeSerializer, 1);
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/sarray/SharedDoubleArray.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/sarray/SharedDoubleArray.java
new file mode 100644
index 000000000..4a7aa24b0
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/sarray/SharedDoubleArray.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.sarray;
+
+import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
+
+import java.io.Serializable;
+
+/** A resizable double array that can be shared among different iterations for memory efficiency. */
+public class SharedDoubleArray implements Serializable {
+
+    /** The underlying DoubleArrayList that holds the elements. */
+    private final DoubleArrayList doubles;
+
+    /**
+     * Constructs a new SharedDoubleArray that wraps the given double array.
+     *
+     * @param array the double array to wrap
+     */
+    public SharedDoubleArray(double[] array) {
+        doubles = DoubleArrayList.wrap(array);
+    }
+
+    /**
+     * Constructs a new SharedDoubleArray with the given initial capacity.
+     *
+     * @param capacity the initial capacity
+     */
+    public SharedDoubleArray(int capacity) {
+        doubles = new DoubleArrayList(capacity);
+    }
+
+    /** Constructs a new empty SharedDoubleArray. */
+    public SharedDoubleArray() {
+        doubles = new DoubleArrayList();
+    }
+
+    /**
+     * Returns the element at the specified index.
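+     *
+     * <p>A quick sketch of typical use, for illustration only:
+     *
+     * <pre>{@code
+     * SharedDoubleArray arr = new SharedDoubleArray(new double[] {1.0, 2.0});
+     * arr.addAll(new double[] {3.0, 4.0});
+     * arr.get(2);      // 3.0
+     * arr.size();      // 4
+     * arr.elements();  // backing array; may be longer than size()
+     * }</pre>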
+     *
+     * @param index the index of the element to return
+     * @return the element at the specified index
+     */
+    public double get(int index) {
+        return doubles.getDouble(index);
+    }
+
+    /**
+     * Appends the specified element to the end of this array.
+     *
+     * @param v the element to add
+     */
+    public void add(double v) {
+        doubles.add(v);
+    }
+
+    /**
+     * Appends all the elements from the specified double array to the end of this array.
+     *
+     * @param src the double array to append
+     */
+    public void addAll(double[] src) {
+        int sizeBefore = size();
+        doubles.size(sizeBefore + src.length);
+        System.arraycopy(src, 0, elements(), sizeBefore, src.length);
+    }
+
+    /**
+     * Returns the number of valid elements in this array.
+     *
+     * @return the number of valid elements in this array
+     */
+    public int size() {
+        return doubles.size();
+    }
+
+    /**
+     * Sets the size of the array to the provided size. If the new size is larger than the current
+     * size, the newly allocated positions are filled with zeros.
+     *
+     * @param size the new size of the array
+     */
+    public void size(int size) {
+        doubles.size(size);
+    }
+
+    /** Clears the elements in this array. Note that the memory is not recycled. */
+    public void clear() {
+        doubles.clear();
+    }
+
+    /**
+     * Returns a double array containing all the elements in this array. Only the first {@link
+     * SharedDoubleArray#size()} elements are valid.
+     *
+     * @return a double array containing all the elements in this array
+     */
+    public double[] elements() {
+        return doubles.elements();
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/sarray/SharedLongArray.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/sarray/SharedLongArray.java
new file mode 100644
index 000000000..d193890da
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/sarray/SharedLongArray.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.sarray;
+
+import it.unimi.dsi.fastutil.longs.LongArrayList;
+
+import java.io.Serializable;
+
+/** A resizable long array that can be shared among different iterations for memory efficiency. */
+public class SharedLongArray implements Serializable {
+
+    /** The underlying LongArrayList that holds the elements. */
+    private final LongArrayList longs;
+
+    /**
+     * Constructs a new SharedLongArray that wraps the given long array.
+     *
+     * @param array the long array to wrap
+     */
+    public SharedLongArray(long[] array) {
+        longs = LongArrayList.wrap(array);
+    }
+
+    /**
+     * Constructs a new SharedLongArray with the given initial capacity.
+     *
+     * @param capacity the initial capacity
+     */
+    public SharedLongArray(int capacity) {
+        longs = new LongArrayList(capacity);
+    }
+
+    /** Constructs a new empty SharedLongArray. */
+    public SharedLongArray() {
+        longs = new LongArrayList();
+    }
+
+    /**
+     * Returns the element at the specified index.
+     *
+     * @param index the index of the element to return
+     * @return the element at the specified index
+     */
+    public long get(int index) {
+        return longs.getLong(index);
+    }
+
+    /**
+     * Appends the specified element to the end of this array.
+     *
+     * @param v the element to add
+     */
+    public void add(long v) {
+        longs.add(v);
+    }
+
+    /**
+     * Appends all the elements from the specified long array to the end of this array.
+     *
+     * @param src the long array to append
+     */
+    public void addAll(long[] src) {
+        int sizeBefore = size();
+        longs.size(sizeBefore + src.length);
+        System.arraycopy(src, 0, elements(), sizeBefore, src.length);
+    }
+
+    /**
+     * Returns the number of valid elements in this array.
+     *
+     * @return the number of valid elements in this array
+     */
+    public int size() {
+        return longs.size();
+    }
+
+    /**
+     * Resizes this array to the given size. If the new size is larger than the current size, the
+     * newly allocated positions are filled with zeros.
+     *
+     * @param size the new size of the array
+     */
+    public void size(int size) {
+        longs.size(size);
+    }
+
+    /** Clears the elements in this array. Note that the memory is not recycled. */
+    public void clear() {
+        longs.clear();
+    }
+
+    /**
+     * Returns a long array containing the valid elements in this array. Only the first {@link
+     * SharedLongArray#size()} elements are valid.
+     *
+     * @return a long array containing the valid elements in this array
+     */
+    public long[] elements() {
+        return longs.elements();
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2DoubleOpenHashMapSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2DoubleOpenHashMapSerializer.java
new file mode 100644
index 000000000..3e2d3b920
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2DoubleOpenHashMapSerializer.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.typeinfo;
+
+import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+
+import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap;
+
+import java.io.IOException;
+import java.util.Map;
+
+/**
+ * TypeSerializer for {@link Long2DoubleOpenHashMap}.
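+ *
+ * <p>The serialized format is the number of entries, followed by the (long key, double value)
+ * pairs.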
*/ +public class Long2DoubleOpenHashMapSerializer extends TypeSerializer { + + public static final Long2DoubleOpenHashMapSerializer INSTANCE = + new Long2DoubleOpenHashMapSerializer(); + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public TypeSerializer duplicate() { + return INSTANCE; + } + + @Override + public Long2DoubleOpenHashMap createInstance() { + return new Long2DoubleOpenHashMap(); + } + + @Override + public Long2DoubleOpenHashMap copy(Long2DoubleOpenHashMap from) { + return new Long2DoubleOpenHashMap(from); + } + + @Override + public Long2DoubleOpenHashMap copy(Long2DoubleOpenHashMap from, Long2DoubleOpenHashMap reuse) { + return new Long2DoubleOpenHashMap(from); + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(Long2DoubleOpenHashMap map, DataOutputView target) throws IOException { + target.writeInt(map.size()); + for (Map.Entry entry : map.entrySet()) { + target.writeLong(entry.getKey()); + target.writeDouble(entry.getValue()); + } + } + + @Override + public Long2DoubleOpenHashMap deserialize(DataInputView source) throws IOException { + int numEntries = source.readInt(); + Long2DoubleOpenHashMap map = new Long2DoubleOpenHashMap(numEntries); + for (int i = 0; i < numEntries; i++) { + long k = source.readLong(); + double v = source.readDouble(); + map.put(k, v); + } + return map; + } + + @Override + public Long2DoubleOpenHashMap deserialize(Long2DoubleOpenHashMap reuse, DataInputView source) + throws IOException { + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + int numEntries = source.readInt(); + target.writeInt(numEntries); + for (int i = 0; i < numEntries; ++i) { + target.writeLong(source.readLong()); + target.writeDouble(source.readDouble()); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + return true; + } + + @Override + public int hashCode() { + return 0; + } + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new Long2DoubleOpenHashMapSnapshot(); + } + + private static final class Long2DoubleOpenHashMapSnapshot + extends SimpleTypeSerializerSnapshot { + public Long2DoubleOpenHashMapSnapshot() { + super(() -> Long2DoubleOpenHashMapSerializer.INSTANCE); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2DoubleOpenHashMapTypeInfo.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2DoubleOpenHashMapTypeInfo.java new file mode 100644 index 000000000..4acfafdc1 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2DoubleOpenHashMapTypeInfo.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.typeinfo; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeSerializer; + +import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap; + +/** TypeInformation for {@link Long2DoubleOpenHashMap}. */ +public class Long2DoubleOpenHashMapTypeInfo extends TypeInformation { + + public static Long2DoubleOpenHashMapTypeInfo instance = new Long2DoubleOpenHashMapTypeInfo(); + + @Override + public boolean isBasicType() { + return false; + } + + @Override + public boolean isTupleType() { + return false; + } + + @Override + public int getArity() { + return 1; + } + + @Override + public int getTotalFields() { + return 1; + } + + @Override + public Class getTypeClass() { + return Long2DoubleOpenHashMap.class; + } + + @Override + public boolean isKeyType() { + return false; + } + + @Override + public TypeSerializer createSerializer(ExecutionConfig config) { + return Long2DoubleOpenHashMapSerializer.INSTANCE; + } + + @Override + public String toString() { + return "Long2DoubleOpenHashMap Type"; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + + if (obj == null || getClass() != obj.getClass()) { + return false; + } + + return true; + } + + @Override + public int hashCode() { + return 0; + } + + @Override + public boolean canEqual(Object obj) { + return obj instanceof Long2DoubleOpenHashMapTypeInfo; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2ObjectOpenHashMapSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2ObjectOpenHashMapSerializer.java new file mode 100644 index 000000000..12f250d7d --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2ObjectOpenHashMapSerializer.java @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.typeinfo; + +import org.apache.flink.api.common.typeutils.CompositeTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.util.Preconditions; + +import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +/** + * TypeSerializer for {@link Long2ObjectOpenHashMap}. + * + * @param The type of elements in the Long2ObjectOpenHashMap. 
+ */ +public class Long2ObjectOpenHashMapSerializer extends TypeSerializer> { + + private final TypeSerializer elementSerializer; + + public Long2ObjectOpenHashMapSerializer(TypeSerializer elementSerializer) { + this.elementSerializer = Preconditions.checkNotNull(elementSerializer); + } + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public TypeSerializer> duplicate() { + return new Long2ObjectOpenHashMapSerializer<>(elementSerializer.duplicate()); + } + + @Override + public Long2ObjectOpenHashMap createInstance() { + return new Long2ObjectOpenHashMap<>(); + } + + @Override + public Long2ObjectOpenHashMap copy(Long2ObjectOpenHashMap from) { + return new Long2ObjectOpenHashMap<>(from); + } + + @Override + public Long2ObjectOpenHashMap copy( + Long2ObjectOpenHashMap from, Long2ObjectOpenHashMap reuse) { + return new Long2ObjectOpenHashMap<>(from); + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(Long2ObjectOpenHashMap map, DataOutputView target) throws IOException { + target.writeInt(map.size()); + for (Map.Entry entry : map.entrySet()) { + target.writeLong(entry.getKey()); + elementSerializer.serialize(entry.getValue(), target); + } + } + + @Override + public Long2ObjectOpenHashMap deserialize(DataInputView source) throws IOException { + int numEntries = source.readInt(); + Long2ObjectOpenHashMap map = new Long2ObjectOpenHashMap<>(numEntries); + for (int i = 0; i < numEntries; i++) { + long k = source.readLong(); + T v = elementSerializer.deserialize(source); + map.put(k, v); + } + return map; + } + + @Override + public Long2ObjectOpenHashMap deserialize( + Long2ObjectOpenHashMap reuse, DataInputView source) throws IOException { + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + int numEntries = source.readInt(); + target.writeInt(numEntries); + for (int i = 0; i < numEntries; ++i) { + target.writeLong(source.readLong()); + elementSerializer.copy(source, target); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + Long2ObjectOpenHashMapSerializer that = (Long2ObjectOpenHashMapSerializer) o; + return Objects.equals(elementSerializer, that.elementSerializer); + } + + @Override + public int hashCode() { + return Objects.hash(elementSerializer != null ? 
elementSerializer.hashCode() : 0); + } + + @Override + public TypeSerializerSnapshot> snapshotConfiguration() { + return new Long2ObjectOpenHashMapSnapshot<>(this); + } + + private static final class Long2ObjectOpenHashMapSnapshot + extends CompositeTypeSerializerSnapshot< + Long2ObjectOpenHashMap, Long2ObjectOpenHashMapSerializer> { + + private static final int CURRENT_VERSION = 1; + + public Long2ObjectOpenHashMapSnapshot() { + super(Long2ObjectOpenHashMapSerializer.class); + } + + public Long2ObjectOpenHashMapSnapshot(Long2ObjectOpenHashMapSerializer serializer) { + super(serializer); + } + + @Override + protected int getCurrentOuterSnapshotVersion() { + return CURRENT_VERSION; + } + + @Override + protected TypeSerializer[] getNestedSerializers( + Long2ObjectOpenHashMapSerializer outerSerializer) { + return new TypeSerializer[] {outerSerializer.elementSerializer}; + } + + @Override + protected Long2ObjectOpenHashMapSerializer createOuterSerializerWithNestedSerializers( + TypeSerializer[] nestedSerializers) { + TypeSerializer elementSerializer = (TypeSerializer) nestedSerializers[0]; + return new Long2ObjectOpenHashMapSerializer<>(elementSerializer); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2ObjectOpenHashMapTypeInfo.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2ObjectOpenHashMapTypeInfo.java new file mode 100644 index 000000000..d80079cfc --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/typeinfo/Long2ObjectOpenHashMapTypeInfo.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.typeinfo; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeSerializer; + +import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap; + +import java.util.Objects; + +/** + * TypeInformation for {@link Long2ObjectOpenHashMap}. + * + * @param The type of elements in the Long2ObjectOpenHashMap. 
+ */ +public class Long2ObjectOpenHashMapTypeInfo extends TypeInformation> { + + private final TypeInformation elementTypeInfo; + + public Long2ObjectOpenHashMapTypeInfo(TypeInformation elementTypeInfo) { + this.elementTypeInfo = elementTypeInfo; + } + + public TypeInformation getElementTypeInfo() { + return elementTypeInfo; + } + + @Override + public boolean isBasicType() { + return false; + } + + @Override + public boolean isTupleType() { + return false; + } + + @Override + public int getArity() { + return 1; + } + + @Override + public int getTotalFields() { + return 1; + } + + @Override + public Class> getTypeClass() { + return (Class) Long2ObjectOpenHashMap.class; + } + + @Override + public boolean isKeyType() { + return false; + } + + @Override + public TypeSerializer> createSerializer(ExecutionConfig config) { + return new Long2ObjectOpenHashMapSerializer<>(elementTypeInfo.createSerializer(config)); + } + + @Override + public String toString() { + return "Long2ObjectOpenHashMap Type"; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + + if (obj == null || getClass() != obj.getClass()) { + return false; + } + + Long2ObjectOpenHashMapTypeInfo that = (Long2ObjectOpenHashMapTypeInfo) obj; + return Objects.equals(elementTypeInfo, that.elementTypeInfo); + } + + @Override + public int hashCode() { + return Objects.hash(elementTypeInfo != null ? elementTypeInfo.hashCode() : 0); + } + + @Override + public boolean canEqual(Object obj) { + return obj instanceof Long2ObjectOpenHashMapTypeInfo; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/ModelUpdater.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/ModelUpdater.java new file mode 100644 index 000000000..dbb4dd3ce --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/ModelUpdater.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.updater; + +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * A model updater that could be used to update and retrieve model data. + * + *
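+ * <p>For illustration only, a minimal updater might look like the following sketch (the class
+ * name, the fixed learning rate, and the model data type are illustrative):
+ *
+ * <pre>{@code
+ * public class SgdUpdater implements ModelUpdater<Tuple2<Long, double[]>> {
+ *     private final Long2DoubleOpenHashMap model = new Long2DoubleOpenHashMap();
+ *
+ *     public void update(long[] keys, double[] gradients) {
+ *         for (int i = 0; i < keys.length; i++) {
+ *             model.addTo(keys[i], -0.01 * gradients[i]); // one gradient step
+ *         }
+ *     }
+ *
+ *     public double[] get(long[] keys) {
+ *         double[] values = new double[keys.length];
+ *         for (int i = 0; i < keys.length; i++) {
+ *             values[i] = model.get(keys[i]);
+ *         }
+ *         return values;
+ *     }
+ *
+ *     // getModelSegments / initializeState / snapshotState omitted here.
+ * }
+ * }</pre>
+ *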

+ * <p>Note that the model updater should also ensure that the model data is robust to failures,
+ * by writing the model data to snapshots.
+ *
+ * @param <MT> data type of the model.
+ */
+public interface ModelUpdater<MT> extends Serializable {
+    /** Applies the push to update the model data, e.g., using gradients to update the model. */
+    void update(long[] keys, double[] values);
+
+    /** Retrieves the model data of the given keys. */
+    double[] get(long[] keys);
+
+    /**
+     * Returns the model segments. The model segments are continuously updated/retrieved by
+     * push/pull (i.e., {@link ModelUpdater#update(long[], double[])} and {@link
+     * ModelUpdater#get(long[])}).
+     */
+    Iterator<MT> getModelSegments();
+
+    /** Recovers the model data from state. */
+    void initializeState(StateInitializationContext context) throws Exception;
+
+    /** Snapshots the model data to state. */
+    void snapshotState(StateSnapshotContext context) throws Exception;
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/utils/ProxySideOutput.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/utils/ProxySideOutput.java
new file mode 100644
index 000000000..9cba95d0f
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/utils/ProxySideOutput.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.utils;
+
+import org.apache.flink.streaming.api.operators.Output;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+/** A collector that can only output records using {@link OutputTag}. */
+public final class ProxySideOutput {
+    private final Output<StreamRecord<?>> output;
+
+    public ProxySideOutput(Output<StreamRecord<?>> output) {
+        this.output = output;
+    }
+
+    public <OUT> void output(OutputTag<OUT> outputTag, StreamRecord<OUT> record) {
+        Preconditions.checkNotNull(outputTag);
+        output.collect(outputTag, record);
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/utils/TrainingUtils.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/utils/TrainingUtils.java
new file mode 100644
index 000000000..261dba6a8
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/utils/TrainingUtils.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.utils;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.Partitioner;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.ps.Message;
+import org.apache.flink.ml.common.ps.ServerOperator;
+import org.apache.flink.ml.common.ps.WorkerOperator;
+import org.apache.flink.ml.common.ps.iterations.IterationStageList;
+import org.apache.flink.ml.common.ps.iterations.MLSession;
+import org.apache.flink.ml.common.ps.updater.ModelUpdater;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.util.OutputTag;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/** Utility function to describe the iterative training process. */
+public final class TrainingUtils {
+    /**
+     * Executes the iterative machine learning logic described in {@link IterationStageList} and
+     * returns the fitted model data as well as the outputs from the worker operator. The outputs
+     * from the worker operator are specified via {@link MLSession#getOutputTags()}.
+     *
+     * @param inputData the input data.
+     * @param iterationStages the iterative processing logic.
+     * @param modelDataType output type information of model data.
+     * @param modelUpdater the logic to update the model on servers.
+     * @param numServers number of servers.
+     * @return the fitted model data as well as the outputs from the worker operator. The order is
+     *     {modelData, side outputs from workers}. Note that the outputs from the workers share the
+     *     same order as {@link MLSession#getOutputTags()}.
+     * @param <DT> type information of the input data.
+     * @param <MT> type information of the output model data.
+     */
+    public static <DT, MT> DataStreamList train(
+            DataStream<DT>
inputData, + IterationStageList iterationStages, + TypeInformation modelDataType, + ModelUpdater modelUpdater, + int numServers) { + DataStream variableStream = + inputData.getExecutionEnvironment().fromElements(new byte[0]).filter(x -> false); + + return Iterations.iterateBoundedStreamsUntilTermination( + DataStreamList.of(variableStream), + ReplayableDataStreamList.notReplay(inputData), + IterationConfig.newBuilder().build(), + new TrainIterationBody<>(modelUpdater, modelDataType, iterationStages, numServers)); + } + + /** The iteration implementation for training process. */ + private static class TrainIterationBody implements IterationBody { + private final ModelUpdater modelUpdater; + private final TypeInformation modelType; + private final IterationStageList iterationStages; + private final int numServers; + + public TrainIterationBody( + ModelUpdater modelUpdater, + TypeInformation modelType, + IterationStageList iterationStages, + int numServers) { + this.iterationStages = iterationStages; + this.modelType = modelType; + this.modelUpdater = modelUpdater; + this.numServers = numServers; + } + + @Override + @SuppressWarnings("unchecked") + public IterationBodyResult process( + DataStreamList variableStreams, DataStreamList dataStreams) { + DataStream variableStream = variableStreams.get(0); + DataStream trainData = dataStreams.get(0); + final OutputTag modelDataOutputTag = new OutputTag<>("MODEL_OUTPUT", modelType); + + SingleOutputStreamOperator messageToServer = + trainData + .connect(variableStream) + .transform( + "WorkerOp", + PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO, + new WorkerOperator(iterationStages, numServers)); + int numWorkers = messageToServer.getParallelism(); + + SingleOutputStreamOperator messageToWorker = + messageToServer + .partitionCustom( + (Partitioner) + (key, numPartitions) -> key % numPartitions, + (KeySelector) + value -> new Message(value).getServerId()) + .transform( + "ServerOp", + PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO, + new ServerOperator<>( + iterationStages.stageList, + numWorkers, + modelUpdater, + modelDataOutputTag)); + messageToWorker.setParallelism(numServers); + + DataStream feedback = + messageToWorker + .partitionCustom( + (Partitioner) + (key, numPartitions) -> key % numPartitions, + (KeySelector) + value -> new Message(value).getWorkerId()) + .map( + (MapFunction) message -> message, + PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO) + .setParallelism(numWorkers); + + DataStream model = messageToWorker.getSideOutput(modelDataOutputTag); + + List> result = new ArrayList<>(); + result.add(model); + + List> sideOutputTags = iterationStages.session.getOutputTags(); + if (sideOutputTags != null) { + for (OutputTag outputTag : sideOutputTags) { + result.add(messageToServer.getSideOutput(outputTag)); + } + } + + return new IterationBodyResult( + DataStreamList.of(feedback), new DataStreamList(result), null); + } + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/MessageTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/MessageTest.java new file mode 100644 index 000000000..54a2412cc --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/MessageTest.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.common.typeutils.TypeSerializer; + +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Iterator; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +/** Tests {@link org.apache.flink.ml.common.ps.Message}. */ +public class MessageTest { + private Message messageFromBytes; + private Message messageFromArray; + private Message messageFromPojo; + + private TypeSerializer mockPojoTypeSerializer; + + @Before + public void before() throws IOException { + messageFromArray = new Message(1, 0, 1, new long[] {1, 2}, new double[] {1, 2, 3, 4}); + messageFromBytes = new Message(messageFromArray.bytes.clone()); + mockPojoTypeSerializer = Types.POJO(MockPojo.class).createSerializer(new ExecutionConfig()); + messageFromPojo = + new Message( + 1, + 0, + 1, + new long[] {1, 2}, + new MockPojo[] {new MockPojo(1, 1), new MockPojo(2, 2)}, + mockPojoTypeSerializer); + } + + @Test + public void getKeys() { + long[] expectedKeys = new long[] {1, 2}; + assertArrayEquals(expectedKeys, messageFromArray.getKeys()); + assertArrayEquals(expectedKeys, messageFromBytes.getKeys()); + assertArrayEquals(expectedKeys, messageFromPojo.getKeys()); + } + + @Test + public void getValuesInDoubleArray() { + double[] expectedDoubleArray = new double[] {1, 2, 3, 4}; + assertArrayEquals(expectedDoubleArray, messageFromArray.getValuesInDoubleArray(), 1e-7); + assertArrayEquals(expectedDoubleArray, messageFromBytes.getValuesInDoubleArray(), 1e-7); + } + + @Test + public void getValues() throws IOException { + MockPojo[] expectedPojos = new MockPojo[] {new MockPojo(1, 1), new MockPojo(2, 2)}; + assertArrayEquals(expectedPojos, messageFromPojo.getValues(mockPojoTypeSerializer)); + } + + @Test + public void getWorkerId() { + int expectedWorkerId = 1; + assertEquals(expectedWorkerId, messageFromArray.getWorkerId()); + assertEquals(expectedWorkerId, messageFromBytes.getWorkerId()); + assertEquals(expectedWorkerId, messageFromPojo.getWorkerId()); + } + + @Test + public void setWorkerId() { + messageFromArray.setWorkerId(2); + messageFromBytes.setWorkerId(2); + messageFromPojo.setWorkerId(2); + int expectedWorkerId = 2; + assertEquals(expectedWorkerId, messageFromArray.getWorkerId()); + assertEquals(expectedWorkerId, messageFromBytes.getWorkerId()); + assertEquals(expectedWorkerId, messageFromPojo.getWorkerId()); + } + + @Test + public void getServerId() { + int expectedServerId = 0; + assertEquals(expectedServerId, messageFromArray.getServerId()); + assertEquals(expectedServerId, messageFromBytes.getServerId()); + assertEquals(expectedServerId, messageFromPojo.getServerId()); + } + + @Test + public void setServerId() { + messageFromArray.setServerId(2); + 
messageFromBytes.setServerId(2); + messageFromPojo.setServerId(2); + int expectedServerId = 2; + assertEquals(expectedServerId, messageFromArray.getServerId()); + assertEquals(expectedServerId, messageFromBytes.getServerId()); + assertEquals(expectedServerId, messageFromPojo.getServerId()); + } + + @Test + public void getStagedId() { + int expectedStageId = 1; + assertEquals(expectedStageId, messageFromArray.getStageId()); + assertEquals(expectedStageId, messageFromBytes.getStageId()); + assertEquals(expectedStageId, messageFromPojo.getStageId()); + } + + @Test + public void assembleMessages() { + int numServers = 4; + Message[] messages = new Message[numServers]; + for (int i = 0; i < numServers; i++) { + messages[i] = + new Message( + 1, + i, + 0, + new long[] {i * 2, i * 2 + 1}, + new double[] {i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3}); + } + + Iterator bytes = Arrays.stream(messages).map(x -> x.bytes).iterator(); + Message assembledMessage = Message.assembleMessages(bytes); + + assertEquals(-1, assembledMessage.getServerId()); + assertEquals(1, assembledMessage.getWorkerId()); + assertEquals(0, assembledMessage.getStageId()); + + long[] expectedKeys = new long[numServers * 2]; + for (int i = 0; i < expectedKeys.length; i++) { + expectedKeys[i] = i; + } + assertArrayEquals(expectedKeys, assembledMessage.getKeys()); + + double[] expectedValues = new double[numServers * 4]; + for (int i = 0; i < expectedValues.length; i++) { + expectedValues[i] = i; + } + assertArrayEquals(expectedValues, assembledMessage.getValuesInDoubleArray(), 1e-7); + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/MockPojo.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/MockPojo.java new file mode 100644 index 000000000..6666e3d78 --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/MockPojo.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps; + +/** Mock pojo class to test all reduce. 
*/ +public class MockPojo { + public int i; + public int j; + + public MockPojo(int i, int j) { + this.i = i; + this.j = j; + } + + public MockPojo() {} + + @Override + public String toString() { + return i + "-" + j; + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof MockPojo) { + MockPojo other = (MockPojo) obj; + return i == other.i && j == other.j; + } + return false; + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/TrainingUtilsTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/TrainingUtilsTest.java new file mode 100644 index 000000000..5ee1965d3 --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/TrainingUtilsTest.java @@ -0,0 +1,601 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.common.ps.iterations.AllReduceStage; +import org.apache.flink.ml.common.ps.iterations.IterationStageList; +import org.apache.flink.ml.common.ps.iterations.MLSessionImpl; +import org.apache.flink.ml.common.ps.iterations.ProcessStage; +import org.apache.flink.ml.common.ps.iterations.PullStage; +import org.apache.flink.ml.common.ps.iterations.PushStage; +import org.apache.flink.ml.common.ps.iterations.ReduceScatterStage; +import org.apache.flink.ml.common.ps.sarray.SharedDoubleArray; +import org.apache.flink.ml.common.ps.sarray.SharedLongArray; +import org.apache.flink.ml.common.ps.typeinfo.Long2ObjectOpenHashMapTypeInfo; +import org.apache.flink.ml.common.ps.updater.ModelUpdater; +import org.apache.flink.ml.common.ps.utils.ProxySideOutput; +import org.apache.flink.ml.common.ps.utils.TrainingUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.util.TestUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import 
org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.test.util.TestBaseUtils; +import org.apache.flink.util.OutputTag; +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.function.SerializableSupplier; + +import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap; +import org.apache.commons.collections.IteratorUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.function.Supplier; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +/** Tests {@link TrainingUtils}. */ +public class TrainingUtilsTest { + private static final int NUM_WORKERS = 2; + private static final int NUM_SERVERS = 6; + private static final int MAX_ITER = 3; + private static final int NUM_DOUBLES_PER_KEY = 2; + private DataStream inputData; + StreamExecutionEnvironment env; + + @Before + public void before() { + env = TestUtils.getExecutionEnvironment(); + env.setParallelism(NUM_WORKERS); + inputData = + env.fromCollection( + Arrays.asList( + Vectors.dense(1, 1, 1, 1), + Vectors.dense(2, 2, 2, 2), + Vectors.dense(3, 3, 3, 3), + Vectors.dense(4, 4, 4, 4))) + .map(x -> x, DenseVectorTypeInfo.INSTANCE); + } + + @Test + public void testPushSumAndPullAgg() throws Exception { + MockSession mockSession = new MockSession(); + + IterationStageList stageList = + new IterationStageList<>(mockSession) + .addStage( + new PushStage( + () -> new SharedLongArray(new long[] {1, 4}), + () -> new SharedDoubleArray(new double[] {1, 1, 4, 4}))) + .addStage( + new PullStage( + () -> new SharedLongArray(new long[] {1, 3, 4}), + () -> { + mockSession.pullResult.size(4); + return mockSession.pullResult; + }, + new MockAggregator())) + .addStage( + new ResultChecker( + () -> { + double[] expectedResult = new double[4]; + Arrays.fill( + expectedResult, + (mockSession.iterationId + 1) + * (mockSession.iterationId + 1) + * 68); + return Arrays.equals( + expectedResult, + trimToArray(mockSession.pullResult)); + })) + .setTerminationCriteria(session -> session.iterationId >= MAX_ITER); + + DataStreamList resultList = + TrainingUtils.train( + inputData, + stageList, + new TupleTypeInfo<>( + Types.LONG, + PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO), + new MockModelUpdater(NUM_DOUBLES_PER_KEY), + NUM_SERVERS); + + DataStream> modelStream = resultList.get(0); + List> collectedModelData = + IteratorUtils.toList(modelStream.executeAndCollect()); + List> expectedModelData = + Arrays.asList( + Tuple2.of( + 1L, new double[] {NUM_WORKERS * MAX_ITER, NUM_WORKERS * MAX_ITER}), + Tuple2.of(3L, new double[] {0, 0}), + Tuple2.of( + 4L, + new double[] { + NUM_WORKERS * MAX_ITER * 4, NUM_WORKERS * MAX_ITER * 4 + })); + + verifyModelData(expectedModelData, collectedModelData); + } + + @Test + public void testPushMinAndPull() throws Exception { + MockSession mockSession = new MockSession(); + + IterationStageList stageList = + new IterationStageList<>(mockSession) + .addStage( + new PushStage( + () -> new SharedLongArray(new long[] {1, 4}), + () -> new SharedDoubleArray(new double[] {1, 1, 4, 4}), + Double::min)) + .addStage( + new PullStage( + () -> new SharedLongArray(new long[] {1, 3}), + () -> { + 
+
+    @Test
+    public void testPushMinAndPull() throws Exception {
+        MockSession mockSession = new MockSession();
+
+        IterationStageList<MockSession> stageList =
+                new IterationStageList<>(mockSession)
+                        .addStage(
+                                new PushStage(
+                                        () -> new SharedLongArray(new long[] {1, 4}),
+                                        () -> new SharedDoubleArray(new double[] {1, 1, 4, 4}),
+                                        Double::min))
+                        .addStage(
+                                new PullStage(
+                                        () -> new SharedLongArray(new long[] {1, 3}),
+                                        () -> {
+                                            mockSession.pullResult.size(4);
+                                            return mockSession.pullResult;
+                                        }))
+                        .addStage(
+                                new ResultChecker(
+                                        () ->
+                                                Arrays.equals(
+                                                        new double[] {
+                                                            mockSession.iterationId + 1,
+                                                            mockSession.iterationId + 1,
+                                                            0,
+                                                            0
+                                                        },
+                                                        trimToArray(mockSession.pullResult))))
+                        .setTerminationCriteria(session -> session.iterationId >= MAX_ITER);
+
+        DataStreamList resultList =
+                TrainingUtils.train(
+                        inputData,
+                        stageList,
+                        new TupleTypeInfo<>(
+                                Types.LONG,
+                                PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO),
+                        new MockModelUpdater(NUM_DOUBLES_PER_KEY),
+                        NUM_SERVERS);
+        DataStream<Tuple2<Long, double[]>> modelStream = resultList.get(0);
+        List<Tuple2<Long, double[]>> collectedModelData =
+                IteratorUtils.toList(modelStream.executeAndCollect());
+        List<Tuple2<Long, double[]>> expectedModelData =
+                Arrays.asList(
+                        Tuple2.of(1L, new double[] {MAX_ITER, MAX_ITER}),
+                        Tuple2.of(3L, new double[] {0, 0}),
+                        Tuple2.of(4L, new double[] {MAX_ITER * 4, MAX_ITER * 4}));
+
+        verifyModelData(expectedModelData, collectedModelData);
+    }
+
+    @Test
+    public void testAllReduce() throws Exception {
+        ExecutionConfig executionConfig = inputData.getExecutionEnvironment().getConfig();
+        int executionInterval = 2;
+        TypeSerializer<MockPojo> mockPojoTypeSerializer =
+                Types.POJO(MockPojo.class).createSerializer(executionConfig);
+        MockSession mockSession = new MockSession();
+
+        IterationStageList<MockSession> stageList =
+                new IterationStageList<>(mockSession)
+                        .addStage(new MockInitStage())
+                        .addStage(
+                                new AllReduceStage<>(
+                                        () -> mockSession.allReduceInputAndResult,
+                                        () -> mockSession.allReduceInputAndResult,
+                                        (ReduceFunction<MockPojo[]>) TrainingUtilsTest::sumPojo,
+                                        mockPojoTypeSerializer,
+                                        executionInterval))
+                        .addStage(
+                                new ResultChecker(
+                                        () -> {
+                                            if (mockSession.iterationId % executionInterval == 0) {
+                                                MockPojo[] reduceResult =
+                                                        mockSession.allReduceInputAndResult;
+                                                Assert.assertEquals(2, reduceResult.length);
+                                                MockPojo expectedPojo =
+                                                        new MockPojo(
+                                                                NUM_WORKERS
+                                                                        * (mockSession.iterationId
+                                                                                        / executionInterval
+                                                                                + 1),
+                                                                NUM_WORKERS
+                                                                        * (mockSession.iterationId
+                                                                                        / executionInterval
+                                                                                + 1)
+                                                                        * 2);
+                                                Assert.assertEquals(expectedPojo, reduceResult[0]);
+                                                Assert.assertEquals(expectedPojo, reduceResult[1]);
+                                            }
+                                            return true;
+                                        }))
+                        .setTerminationCriteria(session -> session.iterationId >= MAX_ITER);
+
+        DataStreamList resultList =
+                TrainingUtils.train(
+                        inputData,
+                        stageList,
+                        new TupleTypeInfo<>(
+                                Types.LONG,
+                                Types.LONG,
+                                PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO),
+                        new MockModelUpdater(NUM_DOUBLES_PER_KEY),
+                        NUM_SERVERS);
+        DataStream<Tuple2<Long, double[]>> modelStream = resultList.get(0);
+        List<Tuple2<Long, double[]>> modelData =
+                IteratorUtils.toList(modelStream.executeAndCollect());
+        Assert.assertEquals(0, modelData.size());
+    }
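+
+    /**
+     * Editorial sketch (not part of the original change): traces the all-reduce expectation
+     * above. The same buffer is supplied as both input and result, so each execution sums the
+     * workers' current buffers in place: (1, 2) becomes (2, 4) after iteration 0 and (4, 8)
+     * after iteration 2, the only two executions that happen with executionInterval = 2 and
+     * MAX_ITER = 3, matching NUM_WORKERS * (iterationId / executionInterval + 1) * (1, 2).
+     */
+    private static MockPojo traceAllReduce(int numExecutions) {
+        MockPojo buffer = new MockPojo(1, 2);
+        for (int k = 0; k < numExecutions; k++) {
+            // All workers hold identical buffers, so the element-wise sum over
+            // NUM_WORKERS workers multiplies each field by NUM_WORKERS.
+            buffer = new MockPojo(buffer.i * NUM_WORKERS, buffer.j * NUM_WORKERS);
+        }
+        return buffer;
+    }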
+
+    @Test
+    public void testReduceScatter() throws Exception {
+        ExecutionConfig executionConfig = inputData.getExecutionEnvironment().getConfig();
+        int executionInterval = 2;
+        TypeSerializer<MockPojo> mockPojoTypeSerializer =
+                Types.POJO(MockPojo.class).createSerializer(executionConfig);
+        MockSession mockSession =
+                new MockSession(
+                        Collections.singletonList(
+                                new OutputTag<>(
+                                        "reduceScatter",
+                                        new TupleTypeInfo<>(
+                                                Types.INT,
+                                                Types.INT,
+                                                Types.OBJECT_ARRAY(Types.POJO(MockPojo.class))))));
+
+        IterationStageList<MockSession> stageList =
+                new IterationStageList<>(mockSession)
+                        .addStage(new MockInitStage())
+                        .addStage(
+                                new ReduceScatterStage<>(
+                                        () -> mockSession.reduceScatterInput,
+                                        () -> mockSession.reduceScatterResult,
+                                        new int[] {1, 1},
+                                        (ReduceFunction<MockPojo[]>) TrainingUtilsTest::sumPojo,
+                                        mockPojoTypeSerializer,
+                                        executionInterval))
+                        .addStage(
+                                new ResultChecker(
+                                        () -> {
+                                            if (mockSession.iterationId % executionInterval == 0) {
+                                                MockPojo[] reduceResult =
+                                                        mockSession.reduceScatterResult;
+                                                Assert.assertEquals(1, reduceResult.length);
+                                                MockPojo expectedPojo =
+                                                        new MockPojo(NUM_WORKERS, NUM_WORKERS * 2);
+                                                Assert.assertEquals(expectedPojo, reduceResult[0]);
+                                            }
+                                            return true;
+                                        }))
+                        .setTerminationCriteria(session -> session.iterationId >= MAX_ITER);
+
+        DataStreamList resultList =
+                TrainingUtils.train(
+                        inputData,
+                        stageList,
+                        new TupleTypeInfo<>(
+                                Types.LONG,
+                                Types.LONG,
+                                PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO),
+                        new MockModelUpdater(NUM_DOUBLES_PER_KEY),
+                        NUM_SERVERS);
+        DataStream<Tuple2<Long, double[]>> modelStream = resultList.get(0);
+        List<Tuple2<Long, double[]>> modelData =
+                IteratorUtils.toList(modelStream.executeAndCollect());
+        Assert.assertEquals(0, modelData.size());
+    }
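+
+    /**
+     * Editorial sketch (not part of the original change): with recvCounts {1, 1}, reduce-scatter
+     * sums the two-element input element-wise across workers and hands each worker exactly one
+     * element of the reduced array. Since {@link MockInitStage} re-initializes the input to
+     * (1, 2) on every iteration, each worker always receives MockPojo(NUM_WORKERS, NUM_WORKERS * 2).
+     */
+    private static MockPojo expectedScatteredElement() {
+        MockPojo sum = new MockPojo(0, 0);
+        for (int w = 0; w < NUM_WORKERS; w++) {
+            sum.i += 1; // every worker contributes MockPojo(1, 2)
+            sum.j += 2;
+        }
+        return sum; // = MockPojo(NUM_WORKERS, NUM_WORKERS * 2)
+    }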
+
+    @Test
+    public void readTrainDataAndOutput() throws Exception {
+        MockSession mockSession =
+                new MockSession(
+                        Collections.singletonList(
+                                new OutputTag<>(
+                                        "numOfTrainData",
+                                        new TupleTypeInfo<>(Types.INT, Types.INT, Types.INT))));
+
+        IterationStageList<MockSession> stageList =
+                new IterationStageList<>(mockSession)
+                        .addStage(new ReadDataStage())
+                        .addStage(
+                                new AllReduceStage<>(
+                                        () -> mockSession.numDataScanned,
+                                        () -> mockSession.numDataScanned,
+                                        TrainingUtilsTest::sumIntArray,
+                                        IntSerializer.INSTANCE))
+                        .addStage(new MockOutputStage<>(() -> mockSession.numDataScanned[0]))
+                        .setTerminationCriteria(session -> session.iterationId >= MAX_ITER);
+
+        DataStreamList resultList =
+                TrainingUtils.train(
+                        inputData,
+                        stageList,
+                        new TupleTypeInfo<>(
+                                Types.LONG,
+                                PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO),
+                        new MockModelUpdater(NUM_DOUBLES_PER_KEY),
+                        NUM_SERVERS);
+
+        DataStream<Tuple3<Integer, Integer, Integer>> pulledStream = resultList.get(1);
+        List<Tuple3<Integer, Integer, Integer>> pulls =
+                IteratorUtils.toList(pulledStream.executeAndCollect());
+
+        List<Tuple3<Integer, Integer, Integer>> expectedPulls = new ArrayList<>();
+        int numDataScanned = 4;
+        for (int i = 0; i < MAX_ITER; i++) {
+            for (int w = 0; w < NUM_WORKERS; w++) {
+                expectedPulls.add(Tuple3.of(i, w, numDataScanned));
+            }
+        }
+        Comparator<Tuple3<Integer, Integer, Integer>> comparator =
+                (o1, o2) -> {
+                    int cmp = Integer.compare(o1.f0, o2.f0);
+                    if (cmp == 0) {
+                        cmp = Integer.compare(o1.f1, o2.f1);
+                        if (cmp == 0) {
+                            cmp = Integer.compare(o1.f2, o2.f2);
+                        }
+                    }
+                    return cmp;
+                };
+        TestBaseUtils.compareResultCollections(expectedPulls, pulls, comparator);
+    }
+
+    /** The session state that each worker can access. */
+    private static class MockSession extends MLSessionImpl<DenseVector> {
+        public MockPojo[] allReduceInputAndResult;
+        public MockPojo[] reduceScatterInput;
+        public MockPojo[] reduceScatterResult;
+        public SharedDoubleArray pullResult;
+        private ProxySideOutput output;
+        private Integer[] numDataScanned;
+
+        @Override
+        public void setOutput(ProxySideOutput output) {
+            this.output = output;
+        }
+
+        public MockSession(List<OutputTag<?>> outputTags) {
+            super(outputTags);
+            pullResult = new SharedDoubleArray();
+            this.numDataScanned = new Integer[1];
+        }
+
+        public MockSession() {
+            this(null);
+        }
+    }
+
+    /** The model updater on servers. */
+    private static class MockModelUpdater implements ModelUpdater<Tuple2<Long, double[]>> {
+        private final int numDoublesPerKey;
+        private Long2ObjectOpenHashMap<double[]> model;
+        private ListState<Long2ObjectOpenHashMap<double[]>> modelDataState;
+
+        public MockModelUpdater(int numDoublesPerKey) {
+            this.numDoublesPerKey = numDoublesPerKey;
+            this.model = new Long2ObjectOpenHashMap<>();
+        }
+
+        @Override
+        public void update(long[] keys, double[] values) {
+            Preconditions.checkState(keys.length * numDoublesPerKey == values.length);
+            for (int i = 0; i < keys.length; i++) {
+                long index = keys[i];
+                model.putIfAbsent(index, new double[numDoublesPerKey]);
+                double[] oneDimModel = model.get(index);
+                for (int j = 0; j < numDoublesPerKey; j++) {
+                    oneDimModel[j] += values[i * numDoublesPerKey + j];
+                }
+            }
+        }
+
+        @Override
+        public double[] get(long[] keys) {
+            double[] values = new double[keys.length * numDoublesPerKey];
+            for (int i = 0; i < keys.length; i++) {
+                long index = keys[i];
+                model.putIfAbsent(index, new double[numDoublesPerKey]);
+                double[] oneDimModel = model.get(index);
+                for (int j = 0; j < numDoublesPerKey; j++) {
+                    values[i * numDoublesPerKey + j] += oneDimModel[j];
+                }
+            }
+            return values;
+        }
+
+        @Override
+        public Iterator<Tuple2<Long, double[]>> getModelSegments() {
+            return model.long2ObjectEntrySet().stream()
+                    .map(x -> Tuple2.of(x.getLongKey(), x.getValue()))
+                    .iterator();
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "modelDataState",
+                                            new Long2ObjectOpenHashMapTypeInfo<>(
+                                                    PrimitiveArrayTypeInfo
+                                                            .DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO)));
+            model =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelDataState")
+                            .orElse(new Long2ObjectOpenHashMap<>());
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            modelDataState.clear();
+            modelDataState.add(model);
+        }
+    }
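+
+    /**
+     * Editorial sketch (not part of the original change): illustrates the flattened key/value
+     * layout that {@link MockModelUpdater} expects, i.e. numDoublesPerKey consecutive doubles
+     * per key.
+     */
+    private static void modelUpdaterLayoutExample() {
+        MockModelUpdater updater = new MockModelUpdater(2);
+        // Keys {1, 4} with two doubles per key: values are laid out as
+        // {v1_0, v1_1, v4_0, v4_1}.
+        updater.update(new long[] {1, 4}, new double[] {1, 1, 4, 4});
+        // get() returns the same flattened layout, here {1.0, 1.0, 4.0, 4.0}.
+        double[] values = updater.get(new long[] {1, 4});
+        assertArrayEquals(new double[] {1, 1, 4, 4}, values, 1e-7);
+    }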
+
+    /** A stage that initializes the input buffers for all-reduce and reduce-scatter. */
+    private static class MockInitStage extends ProcessStage<MockSession> {
+
+        @Override
+        public void process(MockSession session) {
+            if (session.iterationId == 0) {
+                session.allReduceInputAndResult = new MockPojo[2];
+                session.allReduceInputAndResult[0] = new MockPojo(1, 2);
+                session.allReduceInputAndResult[1] = new MockPojo(1, 2);
+            }
+
+            session.reduceScatterInput = new MockPojo[2];
+            session.reduceScatterInput[0] = new MockPojo(1, 2);
+            session.reduceScatterInput[1] = new MockPojo(1, 2);
+            session.reduceScatterResult = new MockPojo[1];
+        }
+    }
+
+    /** A stage that scans the training data and counts the number of data points. */
+    private static class ReadDataStage extends ProcessStage<MockSession> {
+
+        @Override
+        public void process(MockSession session) throws Exception {
+            session.numDataScanned[0] = 0;
+            while (session.inputData.hasNext()) {
+                session.inputData.next();
+                session.numDataScanned[0]++;
+            }
+            session.inputData.reset();
+        }
+    }
+
+    /** A stage that checks the values of intermediate results. */
+    private static class ResultChecker extends ProcessStage<MockSession> {
+        Supplier<Boolean> checker;
+
+        public ResultChecker(SerializableSupplier<Boolean> checker) {
+            this.checker = checker;
+        }
+
+        @Override
+        public void process(MockSession session) {
+            Preconditions.checkState(checker.get());
+        }
+    }
+
+    /** A stage that outputs non-model data to downstream tasks. */
+    private static class MockOutputStage<T> extends ProcessStage<MockSession> {
+
+        private final SerializableSupplier<T> outputSupplier;
+
+        public MockOutputStage(SerializableSupplier<T> outputSupplier) {
+            this.outputSupplier = outputSupplier;
+        }
+
+        @Override
+        public void process(MockSession session) {
+            OutputTag<Tuple3<Integer, Integer, T>> outputTag =
+                    (OutputTag<Tuple3<Integer, Integer, T>>) session.getOutputTags().get(0);
+            session.output.output(
+                    outputTag,
+                    new StreamRecord<>(
+                            Tuple3.of(
+                                    session.iterationId, session.workerId, outputSupplier.get())));
+        }
+    }
+
+    /** An aggregator that can be applied to the values of a pull request. */
+    private static class MockAggregator implements PullStage.Aggregator<double[], double[]> {
+        @Override
+        public double[] add(double[] in, double[] acc) {
+            if (acc == null) {
+                acc = new double[in.length * in.length];
+            }
+
+            for (int i = 0; i < in.length; i++) {
+                for (int j = 0; j < in.length; j++) {
+                    acc[i * in.length + j] += in[i] * in[j];
+                }
+            }
+            return acc;
+        }
+
+        @Override
+        public double[] merge(double[] acc1, double[] acc2) {
+            for (int i = 0; i < acc1.length; i++) {
+                acc2[i] += acc1[i];
+            }
+            return acc2;
+        }
+    }
+
+    private void verifyModelData(
+            List<Tuple2<Long, double[]>> expected, List<Tuple2<Long, double[]>> actual) {
+        assertEquals(expected.size(), actual.size());
+        expected.sort(Comparator.comparingLong(x -> x.f0));
+        actual.sort(Comparator.comparingLong(x -> x.f0));
+        for (int i = 0; i < expected.size(); i++) {
+            assertEquals(expected.get(i).f0, actual.get(i).f0);
+            assertArrayEquals(expected.get(i).f1, actual.get(i).f1, 1e-7);
+        }
+    }
+
+    private static MockPojo[] sumPojo(MockPojo[] d1, MockPojo[] d2) {
+        Preconditions.checkArgument(d1.length == d2.length);
+        for (int i = 0; i < d1.length; i++) {
+            d2[i].i += d1[i].i;
+            d2[i].j += d1[i].j;
+        }
+        return d2;
+    }
+
+    private static Integer[] sumIntArray(Integer[] d1, Integer[] d2) {
+        Preconditions.checkArgument(d1.length == d2.length);
+        for (int i = 0; i < d1.length; i++) {
+            d2[i] += d1[i];
+        }
+        return d2;
+    }
+
+    private static double[] trimToArray(SharedDoubleArray array) {
+        return Arrays.copyOfRange(array.elements(), 0, array.size());
+    }
+}
diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/Bits.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/Bits.java
index 8de3a44d4..3beae41ca 100644
--- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/Bits.java
+++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/Bits.java
@@ -44,6 +44,13 @@ public static double getDouble(byte[] b, int off) {
         return Double.longBitsToDouble(getLong(b, off));
     }
 
+    public static int getInt(byte[] b, int off) {
+        return ((b[off + 3] & 0xFF))
+                + ((b[off + 2] & 0xFF) << 8)
+                + ((b[off + 1] & 0xFF) << 16)
+                + ((b[off]) << 24);
+    }
+
     /*
      * Methods for packing primitive values into byte arrays starting at given
      * offsets.
@@ -63,4 +70,75 @@ public static void putLong(byte[] b, int off, long val) {
     public static void putDouble(byte[] b, int off, double val) {
         putLong(b, off, Double.doubleToLongBits(val));
     }
+
+    public static void putInt(byte[] b, int off, int val) {
+        b[off + 3] = (byte) (val);
+        b[off + 2] = (byte) (val >>> 8);
+        b[off + 1] = (byte) (val >>> 16);
+        b[off] = (byte) (val >>> 24);
+    }
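+
+    /*
+     * Editorial sketch (not part of the original change): getInt/putInt mirror the
+     * big-endian layout of java.io.Bits, so the two methods round-trip:
+     *
+     *   byte[] b = new byte[Integer.BYTES];
+     *   Bits.putInt(b, 0, 0x01020304);   // b = {0x01, 0x02, 0x03, 0x04}
+     *   int v = Bits.getInt(b, 0);       // v == 0x01020304
+     */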
+
+    /** Gets a long array from the byte array starting from the given offset. */
+    public static long[] getLongArray(byte[] bytes, int offset) {
+        int size = Bits.getInt(bytes, offset);
+        offset += Integer.BYTES;
+        long[] result = new long[size];
+        for (int i = 0; i < size; i++) {
+            result[i] = Bits.getLong(bytes, offset);
+            offset += Long.BYTES;
+        }
+        return result;
+    }
+
+    /**
+     * Puts a long array into the byte array starting from the given offset.
+     *
+     * @return the next position to write to.
+     */
+    public static int putLongArray(long[] array, byte[] bytes, int offset) {
+        Bits.putInt(bytes, offset, array.length);
+        offset += Integer.BYTES;
+        for (int i = 0; i < array.length; i++) {
+            Bits.putLong(bytes, offset, array[i]);
+            offset += Long.BYTES;
+        }
+        return offset;
+    }
+
+    /** Returns the size of a long array in bytes. */
+    public static int getLongArraySizeInBytes(long[] array) {
+        return Integer.BYTES + array.length * Long.BYTES;
+    }
+
+    /** Gets a double array from the byte array starting from the given offset. */
+    public static double[] getDoubleArray(byte[] bytes, int offset) {
+        int size = Bits.getInt(bytes, offset);
+        offset += Integer.BYTES;
+        double[] result = new double[size];
+        for (int i = 0; i < size; i++) {
+            result[i] = Bits.getDouble(bytes, offset);
+            offset += Double.BYTES;
+        }
+        return result;
+    }
+
+    /**
+     * Puts a double array into the byte array starting from the given offset.
+     *
+     * @return the next position to write to.
+     */
+    public static int putDoubleArray(double[] array, byte[] bytes, int offset) {
+        Bits.putInt(bytes, offset, array.length);
+        offset += Integer.BYTES;
+        for (int i = 0; i < array.length; i++) {
+            Bits.putDouble(bytes, offset, array[i]);
+            offset += Double.BYTES;
+        }
+        return offset;
+    }
+
+    /** Returns the size of a double array in bytes. */
+    public static int getDoubleArraySizeInBytes(double[] array) {
+        return Integer.BYTES + array.length * Double.BYTES;
+    }
 }
diff --git a/flink-ml-uber/pom.xml b/flink-ml-uber/pom.xml
index 212527d48..1d6629c9d 100644
--- a/flink-ml-uber/pom.xml
+++ b/flink-ml-uber/pom.xml
@@ -94,6 +94,7 @@ under the License.
                                     <include>org.apache.flink:flink-ml-lib-${flink.main.version}</include>
                                     <include>org.apache.flink:flink-ml-benchmark-${flink.main.version}</include>
                                     <include>dev.ludovic.netlib:blas</include>
+                                    <include>it.unimi.dsi:fastutil</include>