Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
130 commits
Select commit Hold shift + click to select a range
f81a9c2
Add assertions for tensor shapes.
khatchad Jul 11, 2025
6185d5a
Merge branch 'master' into 267-initial-tensor-dimensions-arent-always…
khatchad Jul 14, 2025
81cfcb6
Set expected tensors to the correct values.
khatchad Jul 14, 2025
c88367b
Merge branch 'master' into 267-initial-tensor-dimensions-arent-always…
khatchad Jul 16, 2025
464a49e
Merge branch 'master' into 267-initial-tensor-dimensions-arent-always…
khatchad Jul 16, 2025
84afd02
Merge branch 'master' into 267-initial-tensor-dimensions-arent-always…
khatchad Jul 16, 2025
3ec08f6
Merge branch 'master' into 267-initial-tensor-dimensions-arent-always…
khatchad Jul 16, 2025
d4d2d23
Remove the hard-coded MNIST input.
khatchad Jul 17, 2025
dd7e9ba
This comment seems incorrect.
khatchad Jul 18, 2025
cc1c191
Additional assertions.
khatchad Jul 18, 2025
30b626b
Progress.
khatchad Jul 18, 2025
96a0c4e
Progress.
khatchad Jul 18, 2025
8e10b2c
Add more shape checks.
khatchad Jul 21, 2025
cc07365
Progress.
khatchad Jul 21, 2025
6423d8d
Merge branch 'master' into 267-initial-tensor-dimensions-arent-always…
khatchad Jul 21, 2025
278abab
Progress.
khatchad Jul 21, 2025
955fcf1
Progress.
khatchad Jul 21, 2025
a97a751
Merge branch 'master' into 267-initial-tensor-dimensions-arent-always…
khatchad Jul 22, 2025
7a16f49
Add tests
khatchad Jul 22, 2025
b5e1802
Add comment.
khatchad Jul 22, 2025
cb03d9e
Add launch config.
khatchad Jul 22, 2025
3122539
Progress.
khatchad Jul 30, 2025
72337c9
Progress.
khatchad Jul 31, 2025
e3d09c4
Merge branch 'master' into 267-initial-tensor-dimensions-arent-always…
khatchad Aug 1, 2025
ec99b02
Merge branch 'master' into 267-initial-tensor-dimensions-arent-always…
khatchad Aug 1, 2025
5f4184c
Remove arguments from test.
khatchad Aug 1, 2025
759245d
Fix test.
khatchad Aug 1, 2025
b40a6cc
Cleanup.
khatchad Aug 1, 2025
7d894dd
Inline variable.
khatchad Aug 1, 2025
9baab67
Throw an exception on unknown dtypes.
khatchad Aug 1, 2025
42f7d9b
Finish the float32 dtype.
khatchad Aug 1, 2025
4a60ffd
Cleanup.
khatchad Aug 1, 2025
61e6617
Use constant for dtype in tests.
khatchad Aug 1, 2025
5047844
Add notes.
khatchad Aug 8, 2025
e9969bd
Format.
khatchad Aug 13, 2025
ad5993c
Remove TODO.
khatchad Aug 13, 2025
969319b
Move TODO.
khatchad Aug 13, 2025
fc5b67c
Inline variable.
khatchad Aug 13, 2025
24196c1
Inline variable.
khatchad Aug 13, 2025
c5f30e1
Alter log.
khatchad Aug 13, 2025
58f65f4
Add test.
khatchad Aug 13, 2025
1bbccc8
Format.
khatchad Aug 13, 2025
addd7d0
Merge branch 'master' into 267-initial-tensor-dimensions-arent-always…
khatchad Aug 14, 2025
b6a5e1b
Merge branch 'master' into 267-initial-tensor-dimensions-arent-always…
khatchad Aug 14, 2025
ba98228
Inline local variable.
khatchad Aug 14, 2025
08b9857
Only add if it's constant.
khatchad Aug 14, 2025
d6d1f6d
Merge branch 'master' into 267-initial-tensor-dimensions-arent-always…
khatchad Aug 18, 2025
ef7d719
Merge branch 'master' into 267-initial-tensor-dimensions-arent-always…
khatchad Aug 18, 2025
14d6117
Hoist variable.
khatchad Aug 18, 2025
d09c1ea
Rename variable.
khatchad Aug 18, 2025
d29c75d
Don't use wildcard in expression.
khatchad Aug 18, 2025
1db0c9c
Start `tf.constant()` inference.
khatchad Aug 18, 2025
c8e5d40
More defensive.
khatchad Aug 18, 2025
2cfa1d3
Update tests.
khatchad Aug 18, 2025
6d93066
Inline local variables.
khatchad Aug 18, 2025
54ccbb8
Test updates.
khatchad Aug 18, 2025
4ad808b
Format.
khatchad Aug 19, 2025
ad2b486
Format and elaborate comments.
khatchad Aug 19, 2025
1ee6df1
Throw an exception if there is an explicit shape argument.
khatchad Aug 19, 2025
d048019
Get rid of failures.
khatchad Aug 19, 2025
c97cf36
Turn generator processing into a class hierarchy.
khatchad Aug 19, 2025
5cc6fb8
Split comment.
khatchad Aug 19, 2025
cff454c
Factor out some common code.
khatchad Aug 19, 2025
5c933ad
Merge branch 'master' into 267-initial-tensor-dimensions-arent-always…
khatchad Aug 20, 2025
17e64d2
Simplify test code.
khatchad Aug 20, 2025
f1c83d0
Add assertions.
khatchad Aug 20, 2025
2623d48
Simplify test code.
khatchad Aug 20, 2025
c11f9ce
Progress.
khatchad Aug 20, 2025
fcf030b
Use doubles.
khatchad Aug 20, 2025
5ceefb8
Rename.
khatchad Aug 20, 2025
7eeed8d
Factor out common code for shapes.
khatchad Aug 20, 2025
7b2dbae
Add method "implementations."
khatchad Aug 20, 2025
1509bf4
Add tests.
khatchad Aug 20, 2025
fc2d4bc
Add file.
khatchad Aug 20, 2025
a7c09e0
Add comments.
khatchad Aug 20, 2025
6c88c94
Remove TODO.
khatchad Aug 21, 2025
1088488
Comments and variable rename.
khatchad Aug 21, 2025
bc446c8
Add value number for dtype.
khatchad Aug 21, 2025
ea5205d
Return -1 for unsupported dtype arguments.
khatchad Aug 21, 2025
e161461
Rename variables.
khatchad Aug 21, 2025
6277b7c
Add docs.
khatchad Aug 21, 2025
9437259
Extract method refactoring.
khatchad Aug 21, 2025
ea949e7
Pull up method refactoring.
khatchad Aug 21, 2025
ff68dae
Extract constant refactoring.
khatchad Aug 21, 2025
0cf87cd
Format.
khatchad Aug 21, 2025
772319f
Fix https://github.com/wala/ML/issues/298.
khatchad Aug 21, 2025
a0277ee
Merge branch 'master' into 267-initial-tensor-dimensions-arent-always…
khatchad Aug 21, 2025
fc152a6
Encapsulate field refactoring.
khatchad Aug 26, 2025
c60f50a
Add assertions.
khatchad Aug 26, 2025
9060c0c
Shorten.
khatchad Aug 26, 2025
a1d1b9c
Fix comment.
khatchad Aug 26, 2025
613fbf5
Pull-up method refactoring.
khatchad Aug 26, 2025
2e1c5e8
Sort members refactoring.
khatchad Aug 26, 2025
8b72fe5
Add docs.
khatchad Aug 26, 2025
b38d460
Extract variable refactoring.
khatchad Aug 26, 2025
b2f814f
Progress.
khatchad Aug 26, 2025
949eca2
Add comment.
khatchad Aug 26, 2025
1210fbf
New test.
khatchad Aug 28, 2025
8a05bc5
New line.
khatchad Aug 28, 2025
fc13736
Merge branch 'master' into 267-initial-tensor-dimensions-arent-always…
khatchad Aug 28, 2025
d7c7b20
Format.
khatchad Aug 28, 2025
f45d1d4
Rename method.
khatchad Aug 28, 2025
8cd367e
Fix test.
khatchad Aug 28, 2025
4450b92
Revert "Rename method."
khatchad Aug 28, 2025
c92b5b2
Add metadata.
khatchad Sep 5, 2025
f3c63ce
New test.
khatchad Sep 5, 2025
fc8f4ba
Cleanup.
khatchad Sep 5, 2025
3a332c1
New test.
khatchad Sep 5, 2025
b91b2f1
Progress.
khatchad Sep 5, 2025
ada8ff6
Add test file.
khatchad Sep 5, 2025
7c28fdd
Merge branch 'master' into 267-initial-tensor-dimensions-arent-always…
khatchad Sep 8, 2025
1781fbc
Black.
khatchad Sep 8, 2025
4e46879
Add asserts.
khatchad Sep 8, 2025
c1db8f2
Add another test case.
khatchad Sep 8, 2025
02acac2
Simplify.
khatchad Sep 8, 2025
be01c4d
Progress.
khatchad Sep 8, 2025
cf80b59
Merge branch 'master' into 267-initial-tensor-dimensions-arent-always…
khatchad Sep 9, 2025
db020e0
Merge branch 'master' into 267-initial-tensor-dimensions-arent-always…
khatchad Sep 9, 2025
ecf7af9
New test.
khatchad Sep 9, 2025
86770a3
Add doc and change method names.
khatchad Sep 9, 2025
1b88e96
More tests.
khatchad Sep 9, 2025
288ded8
Remove test file.
khatchad Sep 9, 2025
911277e
Rename.
khatchad Sep 11, 2025
09a479c
Separate.
khatchad Sep 11, 2025
74b0ef4
Use API.
khatchad Sep 11, 2025
6b098bd
Add dtype inference.
khatchad Sep 11, 2025
13c7ec0
Update tests.
khatchad Sep 11, 2025
216bb9a
Add test for varying dtypes.
khatchad Sep 11, 2025
515fa5b
Add tests for https://github.com/wala/ML/issues/308.
khatchad Sep 12, 2025
d84b52a
Merge branch 'master' into 267-initial-tensor-dimensions-arent-always…
khatchad Sep 12, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

18 changes: 18 additions & 0 deletions com.ibm.wala.cast.python.ml/data/tensorflow.xml
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@
<putfield class="LRoot" field="sparse_tensor" fieldType="LRoot" ref="framework" value="sparse_tensor" />
<new def="constant_op" class="Lobject" />
<putfield class="LRoot" field="constant_op" fieldType="LRoot" ref="framework" value="constant_op" />
<!-- https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/dtypes -->
<new def="dtypes" class="Lobject" />
<putfield class="LRoot" field="dtypes" fieldType="LRoot" ref="framework" value="dtypes" />
<new def="ragged_factory_ops" class="Lobject" />
<putfield class="LRoot" field="ragged_factory_ops" fieldType="LRoot" ref="ragged" value="ragged_factory_ops" />
<new def="ragged_math_ops" class="Lobject" />
Expand Down Expand Up @@ -218,6 +221,10 @@
<new def="constant" class="Ltensorflow/functions/constant" />
<putfield class="LRoot" field="constant" fieldType="LRoot" ref="x" value="constant" />
<putfield class="LRoot" field="constant" fieldType="LRoot" ref="constant_op" value="constant" />
<!-- https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/dtypes#float32 -->
<new def="float32" class="Ltensorflow/dtypes/DType" />
<putfield class="LRoot" field="float32" fieldType="LRoot" ref="x" value="float32" />
<putfield class="LRoot" field="float32" fieldType="LRoot" ref="dtypes" value="float32" />
<new def="ragged_constant" class="Ltensorflow/functions/ragged_constant" />
<putfield class="LRoot" field="constant" fieldType="LRoot" ref="ragged" value="ragged_constant" />
<putfield class="LRoot" field="constant" fieldType="LRoot" ref="ragged_factory_ops" value="ragged_constant" />
Expand Down Expand Up @@ -315,6 +322,17 @@
</method>
</class>
</package>
<package name="tensorflow/dtypes">
<class name="DType" allocatable="true">
<!-- https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/dtypes/DType -->
<method name="do" descriptor="()LRoot;" numArgs="3" paramNames="self type_enum handle_data">
<new def="obj" class="Ltensorflow//DType" />
<putfield class="LRoot" field="type_enum" fieldType="LRoot" ref="obj" value="type_enum" />
<putfield class="LRoot" field="handle_data" fieldType="LRoot" ref="obj" value="handle_data" />
<return value="obj" />
</method>
</class>
</package>
<package name="tensorflow/objects">
<class name="feature" allocatable="true" />
</package>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -447,11 +447,11 @@ public String toString() {
};
}

private final Map<PointsToSetVariable, TensorType> init;
private final Map<PointsToSetVariable, Set<TensorType>> init;

public TensorTypeAnalysis(
Graph<PointsToSetVariable> G,
Map<PointsToSetVariable, TensorType> init,
Map<PointsToSetVariable, Set<TensorType>> init,
Map<PointsToSetVariable, TensorType> reshapeTypes,
Map<PointsToSetVariable, TensorType> set_shapes,
Set<PointsToSetVariable> conv2ds,
Expand Down Expand Up @@ -480,7 +480,8 @@ protected TensorVariable[] makeStmtRHS(int size) {
protected void initializeVariables() {
super.initializeVariables();
for (PointsToSetVariable src : init.keySet()) {
getOut(src).state.add(init.get(src));
Set<TensorType> tensorTypes = init.get(src);
getOut(src).state.addAll(tensorTypes);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package com.ibm.wala.cast.python.ml.client;

import com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType;
import com.ibm.wala.cast.python.ml.types.TensorType.Dimension;
import com.ibm.wala.ipa.callgraph.CGNode;
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder;
import java.util.EnumSet;
import java.util.List;
import java.util.Set;

/**
* Represents a call to the <code>constant()</code> function in TensorFlow.
*
* @see <a href="https://www.tensorflow.org/api_docs/python/tf/constant">constant()</a>.
* @author <a href="mailto:khatchad@hunter.cuny.edu">Raffi Khatchadourian</a>
*/
public class Constant extends TensorGenerator {

private static final int VALUE_NUMBER_FOR_VALUE_ARGUMENT = 2;

private static final int VALUE_NUMBER_FOR_DTYPE_ARGUMENT = 3;

private static final int VALUE_NUMBER_FOR_SHAPE_ARGUMENT = 4;

public Constant(PointsToSetVariable source, CGNode node) {
super(source, node);
}

@Override
protected Set<List<Dimension<?>>> getDefaultShapes(PropagationCallGraphBuilder builder) {
// If the shape argument is not specified, then the shape is inferred from the shape of value.
// TODO: Handle keyword arguments.
return getShapes(builder, this.getValueNumberForValueArgument());
}

@Override
protected EnumSet<DType> getDefaultDTypes(PropagationCallGraphBuilder builder) {
// If the dtype argument is not specified, then the type is inferred from the type of value.
// TODO: Handle keyword arguments.
return getDTypes(builder, this.getValueNumberForValueArgument());
}

@Override
protected int getValueNumberForDTypeArgument() {
return VALUE_NUMBER_FOR_DTYPE_ARGUMENT;
}

protected int getValueNumberForValueArgument() {
return VALUE_NUMBER_FOR_VALUE_ARGUMENT;
}

@Override
protected int getValueNumberForShapeArgument() {
// Shapes can also be specified as an explicit argument. Here, we examine the third explicit
// argument (recall that the first argument is implicit and corresponds to the called
// function's name).
return VALUE_NUMBER_FOR_SHAPE_ARGUMENT;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package com.ibm.wala.cast.python.ml.client;

import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType.FLOAT32;

import com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType;
import com.ibm.wala.cast.python.ml.types.TensorType.Dimension;
import com.ibm.wala.ipa.callgraph.CGNode;
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder;
import java.util.EnumSet;
import java.util.List;
import java.util.Set;

/**
* A generator for tensors created by the `ones()` function in TensorFlow.
*
* @see <a href="https://www.tensorflow.org/api_docs/python/tf/ones">TensorFlow ones() API</a>.
* @author <a href="mailto:khatchad@hunter.cuny.edu">Raffi Khatchadourian</a>
*/
public class Ones extends TensorGenerator {

private static final int VALUE_NUMBER_FOR_SHAPE_ARGUMENT = 2;

private static final int VALUE_NUMBER_FOR_DTYPE_ARGUMENT = 3;

public Ones(PointsToSetVariable source, CGNode node) {
super(source, node);
}

@Override
protected EnumSet<DType> getDefaultDTypes(PropagationCallGraphBuilder builder) {
LOGGER.info(
"No dtype specified for source: " + source + ". Using default dtype of: " + FLOAT32 + " .");

// Use the default dtype of float32.
return EnumSet.of(FLOAT32);
}

@Override
protected Set<List<Dimension<?>>> getDefaultShapes(PropagationCallGraphBuilder builder) {
throw new UnsupportedOperationException(
"Shapes for ones() are mandatory and must be provided explicitly.");
}

@Override
protected int getValueNumberForShapeArgument() {
return VALUE_NUMBER_FOR_SHAPE_ARGUMENT; // The shape is in the first explicit argument.
}

@Override
protected int getValueNumberForDTypeArgument() {
return VALUE_NUMBER_FOR_DTYPE_ARGUMENT; // The dtype is in the second explicit argument.
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -667,10 +667,9 @@ public TensorTypeAnalysis performAnalysis(PropagationCallGraphBuilder builder)
Set<PointsToSetVariable> sources =
getDataflowSources(dataflow, builder.getCallGraph(), builder.getPointerAnalysis());

TensorType mnistData = TensorType.mnistInput();
Map<PointsToSetVariable, TensorType> init = HashMapFactory.make();
Map<PointsToSetVariable, Set<TensorType>> init = HashMapFactory.make();

for (PointsToSetVariable v : sources) init.put(v, mnistData);
for (PointsToSetVariable v : sources) init.put(v, getTensorTypes(v, builder));

Map<PointsToSetVariable, TensorType> placeholders = null;
try {
Expand All @@ -681,7 +680,7 @@ public TensorTypeAnalysis performAnalysis(PropagationCallGraphBuilder builder)
logger.fine("Placeholders: " + placeholders);

for (Map.Entry<PointsToSetVariable, TensorType> e : placeholders.entrySet())
init.put(e.getKey(), e.getValue());
init.put(e.getKey(), Set.of(e.getValue()));

Map<PointsToSetVariable, TensorType> setCalls = HashMapFactory.make();
Map<PointsToSetVariable, TensorType> set_shapes = getShapeSourceCalls(set_shape, builder, 1);
Expand Down Expand Up @@ -722,6 +721,24 @@ public TensorTypeAnalysis performAnalysis(PropagationCallGraphBuilder builder)
return tt;
}

/**
* Returns the set of possible {@link TensorType}s that the given {@link PointsToSetVariable} can
* take on.
*
* @param source The dataflow source to analyze.
* @param builder The {@link PropagationCallGraphBuilder} used to build the call graph and pointer
* analysis.
* @return A set of {@link TensorType}s that the given {@link PointsToSetVariable} can take on.
* Empty set is returned if the possible tensor types cannot be determined.
*/
private Set<TensorType> getTensorTypes(
PointsToSetVariable source, PropagationCallGraphBuilder builder) {
logger.info("Getting tensor types for source: " + source + ".");

TensorGenerator generator = TensorGeneratorFactory.getGenerator(source);
return generator.getTensorTypes(builder);
}

private Map<PointsToSetVariable, TensorType> handleShapeSourceOp(
PropagationCallGraphBuilder builder,
Graph<PointsToSetVariable> dataflow,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
package com.ibm.wala.cast.python.ml.client;

import static java.util.function.Function.identity;

import com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType;
import com.ibm.wala.cast.python.ml.types.TensorType.Dimension;
import com.ibm.wala.cast.python.ml.types.TensorType.NumericDim;
import com.ibm.wala.ipa.callgraph.CGNode;
import com.ibm.wala.ipa.callgraph.propagation.ConstantKey;
import com.ibm.wala.ipa.callgraph.propagation.InstanceKey;
import com.ibm.wala.ipa.callgraph.propagation.PointerAnalysis;
import com.ibm.wala.ipa.callgraph.propagation.PointerKey;
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder;
import com.ibm.wala.util.collections.HashSetFactory;
import com.ibm.wala.util.debug.UnimplementedError;
import com.ibm.wala.util.intset.OrdinalSet;
import java.util.EnumSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.StreamSupport;

/**
* A representation of the TensorFlow range operation.
*
* <p>This class is used to generate a tensor that contains a sequence of numbers, similar to the
* range function in Python.
*
* @see <a href="https://www.tensorflow.org/api_docs/python/tf/range">TensorFlow range
* documentation</a>.
* @author <a href="mailto:khatchad@hunter.cuny.edu">Raffi Khatchadourian</a>
*/
public class Range extends TensorGenerator {

public Range(PointsToSetVariable source, CGNode node) {
super(source, node);
}

@Override
protected Set<List<Dimension<?>>> getShapes(PropagationCallGraphBuilder builder) {
Set<List<Dimension<?>>> ret = HashSetFactory.make();
PointerAnalysis<InstanceKey> pointerAnalysis = builder.getPointerAnalysis();

// The shape of a range tensor is always a 1D tensor with the length equal to the number of
// elements in the range.
// For example, `tf.range(5)` produces a tensor with shape (5,).

double start = 0; // Default start value.
double limit = start; // Default limit value.
double delta = 1; // Default step value.

// There are two versions of the `range` function:
// 1. `tf.range(limit)` - generates a range from 0 to limit
// 2. `tf.range(start, limit, delta)` - generates a range from start to limit with a step of
// delta.

// First, decide which version of the `range` function is being called based on the number of
// numeric arguments.j
// TODO: Handle keyword arguments.

int numOfNumericPositionalArgs = getNumberOfNumericPositionalArgs(pointerAnalysis);

if (numOfNumericPositionalArgs == 1) {
// it must *just* be `limit`.
PointerKey limitPK = pointerAnalysis.getHeapModel().getPointerKeyForLocal(node, 2);
OrdinalSet<InstanceKey> limitPointsToSet = pointerAnalysis.getPointsToSet(limitPK);

assert !limitPointsToSet.isEmpty() : "Expected a non-empty points-to set for limit.";

for (InstanceKey limitIK : limitPointsToSet)
if (limitIK instanceof ConstantKey) {
limit = ((Number) ((ConstantKey<?>) limitIK).getValue()).doubleValue();
int shape = (int) Math.ceil((limit - start) / delta);
ret.add(List.of(new NumericDim(shape))); // Add the shape as a 1D tensor.
} else
throw new IllegalStateException(
"Expected a " + ConstantKey.class + " for limit, but got: " + limitIK + ".");
} else
// TODO: Handle more cases.
throw new UnimplementedError(
"Currently cannot handle more than one numeric positional argument for range().");

return ret;
}

private int getNumberOfNumericPositionalArgs(PointerAnalysis<InstanceKey> pointerAnalysis) {
int ret = 0;
int explicitArgumentIndex = 2; // Start from the first explicit argument.

while (true) {
PointerKey pk =
pointerAnalysis.getHeapModel().getPointerKeyForLocal(node, explicitArgumentIndex);
OrdinalSet<InstanceKey> pointsToSet = pointerAnalysis.getPointsToSet(pk);

if (pointsToSet.isEmpty()) break; // End of positional arguments.

// Check if the pointsToSet contains numeric values.
boolean allNumeric =
StreamSupport.stream(pointsToSet.spliterator(), false)
.filter(ik -> ik instanceof ConstantKey)
.map(ik -> (ConstantKey<?>) ik)
.map(ConstantKey::getValue)
.allMatch(v -> v instanceof Number); // Check if all values are numeric.

if (!allNumeric) break; // There's some argument that is not numeric for this argument.

ret++; // Increment the count of numeric positional arguments.
explicitArgumentIndex++; // Move to the next explicit argument.
}

return ret;
}

@Override
protected EnumSet<DType> getDefaultDTypes(PropagationCallGraphBuilder builder) {
// The dtype of the resulting tensor is inferred from the inputs unless it is provided
// explicitly.

// TODO: Handle keyword arguments.
int numberOfNumericPositionalArgs =
getNumberOfNumericPositionalArgs(builder.getPointerAnalysis());

EnumSet<DType> types =
IntStream.range(0, numberOfNumericPositionalArgs)
.map(i -> i + 2) // Positional arguments start at index 2.
.mapToObj(val -> getDTypes(builder, val).stream())
.flatMap(identity())
.distinct()
.collect(Collectors.toCollection(() -> EnumSet.noneOf(DType.class)));

// FIXME: We can't tell the difference here between varying dtypes in a single call and that of
// possible varying dtypes values from the points-to graph. Below, we are treating it as these
// values lie in a single call, but that may not be the case.

if (types.contains(DType.FLOAT64)) return EnumSet.of(DType.FLOAT64);
else if (types.contains(DType.FLOAT32)) return EnumSet.of(DType.FLOAT32);
else if (types.contains(DType.INT64)) return EnumSet.of(DType.INT64);
else if (types.contains(DType.INT32)) return EnumSet.of(DType.INT32);

throw new IllegalStateException(
"Expected at least one numeric dtype for range(), but got: " + types + ".");
}

@Override
protected Set<List<Dimension<?>>> getDefaultShapes(PropagationCallGraphBuilder builder) {
throw new UnsupportedOperationException(
"Shapes for range() are derived from mandatory numeric arguments and must be provided"
+ " explicitly.");
}

@Override
protected int getValueNumberForShapeArgument() {
throw new UnsupportedOperationException(
"Range does not have a shape argument. Its shape is derived from the numeric arguments.");
}

@Override
protected int getValueNumberForDTypeArgument() {
// TODO: We need a value number for the dtype argument. Also, that value number can differ
// depending on the version of the `range` function being called.

return -1; // Positional dtype argument for range() is not yet implemented.
}
}
Loading
Loading