diff --git a/rir/src/api.cpp b/rir/src/api.cpp index 22570da6d..487dce219 100644 --- a/rir/src/api.cpp +++ b/rir/src/api.cpp @@ -601,10 +601,7 @@ REXPORT SEXP rirCreateSimpleIntContext() { return res; } -REXPORT SEXP playground() { - - return R_NilValue; -} +REXPORT SEXP playground() { return R_NilValue; } bool startup() { initializeRuntime(); diff --git a/rir/src/compiler/native/builtins.cpp b/rir/src/compiler/native/builtins.cpp index bd3f627c9..74513b669 100644 --- a/rir/src/compiler/native/builtins.cpp +++ b/rir/src/compiler/native/builtins.cpp @@ -2247,7 +2247,7 @@ bool deoptChaosTriggerImpl(bool deoptTrue) { void checkTypeImpl(SEXP val, uint64_t type, const char* msg) { assert(pir::Parameter::RIR_CHECK_PIR_TYPES); pir::PirType typ(type); - if (!typ.isInstance(val)) { + if (!typ.isInstance(val, true)) { std::cerr << "type assert failed\n"; std::cerr << "got " << pir::PirType(val) << " but expected a " << typ << ":\n"; diff --git a/rir/src/compiler/pir/instruction.cpp b/rir/src/compiler/pir/instruction.cpp index e18e28c05..42947010d 100644 --- a/rir/src/compiler/pir/instruction.cpp +++ b/rir/src/compiler/pir/instruction.cpp @@ -877,6 +877,19 @@ PirType CallSafeBuiltin::inferType(const Instruction::GetType& getType) const { inferred = PirType(RType::vec).orAttribsOrObj(); } + if ("dim" == name) { + if (!getType(callArg(0).val()) + .maybeObj()) { // TODO: is it necessary to check this? + // inferred = (PirType(RType::integer) | RType::nil) + // .notMissing() + // .orAttribsOrObj(); + + // if (!getType(callArg(0).val()).maybeHasAttrs()) { + // inferred = PirType(RType::nil); + // } + } + } + if (inferred != PirType::bottom()) return inferred & type; diff --git a/rir/src/compiler/pir/type.cpp b/rir/src/compiler/pir/type.cpp index a5a8b19ee..8d6c0aea3 100644 --- a/rir/src/compiler/pir/type.cpp +++ b/rir/src/compiler/pir/type.cpp @@ -95,14 +95,17 @@ void PirType::merge(SEXPTYPE sexptype) { // contain NaN for the benefit, so we simple assume they do static const R_xlen_t MAX_SIZE_OF_VECTOR_FOR_NAN_CHECK = 1; -static bool maybeContainsNAOrNaN(SEXP vector) { +static bool maybeContainsNAOrNaN(SEXP vector, bool exact) { if (TYPEOF(vector) == CHARSXP) { return vector == NA_STRING; } else if (TYPEOF(vector) == INTSXP || TYPEOF(vector) == REALSXP || TYPEOF(vector) == LGLSXP || TYPEOF(vector) == CPLXSXP || TYPEOF(vector) == STRSXP) { - if (XLENGTH(vector) > MAX_SIZE_OF_VECTOR_FOR_NAN_CHECK) { - return true; + + if (!exact) { + if (XLENGTH(vector) > MAX_SIZE_OF_VECTOR_FOR_NAN_CHECK) { + return true; + } } for (int i = 0; i < XLENGTH(vector); i++) { switch (TYPEOF(vector)) { @@ -137,7 +140,7 @@ static bool maybeContainsNAOrNaN(SEXP vector) { } } -PirType::PirType(SEXP e) : flags_(topRTypeFlags()), t_(RTypeSet()) { +PirType::PirType(SEXP e, bool exact) : flags_(topRTypeFlags()), t_(RTypeSet()) { if (e == R_MissingArg) { *this = theMissingValue(); @@ -179,7 +182,7 @@ PirType::PirType(SEXP e) : flags_(topRTypeFlags()), t_(RTypeSet()) { if (t != LISTSXP && t != EXTERNALSXP && t != BCODESXP && t != LANGSXP) if (Rf_xlength(e) == 1) flags_.reset(TypeFlags::maybeNotScalar); - if (!maybeContainsNAOrNaN(e)) + if (!maybeContainsNAOrNaN(e, exact)) flags_.reset(TypeFlags::maybeNAOrNaN); } @@ -210,14 +213,14 @@ void PirType::merge(const ObservedValues& other) { *this = orSexpTypes(any()); } -bool PirType::isInstance(SEXP val) const { +bool PirType::isInstance(SEXP val, bool exact) const { if (isRType()) { if (TYPEOF(val) == PROMSXP) { assert(!Rf_isObject(val)); if (maybePromiseWrapped() && !maybeLazy()) { auto v = PRVALUE(val); return v != R_UnboundValue && - notMissing().forced().isInstance(v); + notMissing().forced().isInstance(v, exact); } return maybe(RType::prom) || (maybeLazy() && maybePromiseWrapped()); } @@ -229,7 +232,9 @@ bool PirType::isInstance(SEXP val) const { return IS_SIMPLE_SCALAR(val, LGLSXP) && LOGICAL(val)[0] != NA_LOGICAL; } - return PirType(val).isA(*this); + + return PirType(val, exact).isA(*this); + ; } else { std::cerr << "can't check val is instance of " << *this << ", value:\n"; Rf_PrintValue(val); diff --git a/rir/src/compiler/pir/type.h b/rir/src/compiler/pir/type.h index 5202344ff..26ffe1888 100644 --- a/rir/src/compiler/pir/type.h +++ b/rir/src/compiler/pir/type.h @@ -179,7 +179,7 @@ struct PirType { // cppcheck-suppress noExplicitConstructor constexpr PirType(const NativeTypeSet& t) : t_(t) {} - explicit PirType(SEXP); + explicit PirType(SEXP, bool exact = false); constexpr PirType(const PirType& other) : flags_(other.flags_), t_(other.t_) {} @@ -641,7 +641,8 @@ struct PirType { // NULL return RType::nil; } - if (isA((num() | RType::str | RType::list | RType::code) + if (isA((num() | RType::str | RType::list | + RType::code /* | RType::nil */) .orAttribsOrObj())) { // If the index is out of bounds, NA is returned (even if both args // are non-NA) so we must add orNAOrNaN() @@ -750,7 +751,7 @@ struct PirType { } // Is val an instance of this type? - bool isInstance(SEXP val) const; + bool isInstance(SEXP val, bool exact = false) const; void print(std::ostream& out = std::cout) const; diff --git a/rir/src/compiler/pir/value.cpp b/rir/src/compiler/pir/value.cpp index bd9380e56..19513c97e 100644 --- a/rir/src/compiler/pir/value.cpp +++ b/rir/src/compiler/pir/value.cpp @@ -56,6 +56,7 @@ void Value::callArgTypeToContext(Context& assumptions, unsigned i) const { arg = arg->cFollowCastsAndForce(); if (!MkArg::Cast(arg)) check(arg); + } void Value::checkReplace(Value* replace) const {