From 1f6bdec195352ad7869ec8a652f6c02c21b14926 Mon Sep 17 00:00:00 2001 From: Ryan Keck Date: Mon, 5 Jan 2026 10:54:54 -0500 Subject: [PATCH 1/4] Closes 4570: Simplify and extend logic in doBigIntBinOpvs and doBigIntBinOpvsBoolReturn --- src/BinOp.chpl | 1172 ++++++++++++++++++------------------------ src/OperatorMsg.chpl | 272 +++++++++- 2 files changed, 763 insertions(+), 681 deletions(-) diff --git a/src/BinOp.chpl b/src/BinOp.chpl index 7b670efeadf..65c1c42612a 100644 --- a/src/BinOp.chpl +++ b/src/BinOp.chpl @@ -18,83 +18,84 @@ module BinOp const omLogger = new Logger(logLevel, logChannel); proc splitType(type dtype) param : int { - // 0 -> bool, 1 -> uint, 2 -> int, 3 -> real - - if dtype == bool then return 0; - else if dtype == uint(8) then return 1; - else if dtype == uint(16) then return 1; - else if dtype == uint(32) then return 1; - else if dtype == uint(64) then return 1; - else if dtype == int(8) then return 2; - else if dtype == int(16) then return 2; - else if dtype == int(32) then return 2; - else if dtype == int(64) then return 2; - else if dtype == real(32) then return 3; - else if dtype == real(64) then return 3; - else return 0; + // 0 -> bool, 1 -> uint, 2 -> int, 3 -> real + + if dtype == bool then return 0; + else if dtype == uint(8) then return 1; + else if dtype == uint(16) then return 1; + else if dtype == uint(32) then return 1; + else if dtype == uint(64) then return 1; + else if dtype == int(8) then return 2; + else if dtype == int(16) then return 2; + else if dtype == int(32) then return 2; + else if dtype == int(64) then return 2; + else if dtype == real(32) then return 3; + else if dtype == real(64) then return 3; + else if dtype == bigint then return 4; + else return 0; - } + } - proc mySafeCast(type dtype1, type dtype2) type { - param typeKind1 = splitType(dtype1); - param bitSize1 = if dtype1 == bool then 8 else numBits(dtype1); - param typeKind2 = splitType(dtype2); - param bitSize2 = if dtype2 == bool then 8 else numBits(dtype2); - - if typeKind1 == 2 && typeKind2 == 1 && bitSize1 <= bitSize2 { - select bitSize2 { - when 64 { return real(64); } - when 32 { return int(64); } - when 16 { return int(32); } - when 8 { return int(16); } - } + proc mySafeCast(type dtype1, type dtype2) type { + param typeKind1 = splitType(dtype1); + param bitSize1 = if dtype1 == bool then 8 else numBits(dtype1); + param typeKind2 = splitType(dtype2); + param bitSize2 = if dtype2 == bool then 8 else numBits(dtype2); + + if typeKind1 == 2 && typeKind2 == 1 && bitSize1 <= bitSize2 { + select bitSize2 { + when 64 { return real(64); } + when 32 { return int(64); } + when 16 { return int(32); } + when 8 { return int(16); } } + } - if typeKind2 == 2 && typeKind1 == 1 && bitSize2 <= bitSize1 { - select bitSize1 { - when 64 { return real(64); } - when 32 { return int(64); } - when 16 { return int(32); } - when 8 { return int(16); } - } + if typeKind2 == 2 && typeKind1 == 1 && bitSize2 <= bitSize1 { + select bitSize1 { + when 64 { return real(64); } + when 32 { return int(64); } + when 16 { return int(32); } + when 8 { return int(16); } } + } - if dtype1 == real(32) && (dtype2 == int(32) || dtype2 == uint(32)) { - return real(64); - } + if dtype1 == real(32) && (dtype2 == int(32) || dtype2 == uint(32)) { + return real(64); + } - if dtype2 == real(32) && (dtype1 == int(32) || dtype1 == uint(32)) { - return real(64); - } + if dtype2 == real(32) && (dtype1 == int(32) || dtype1 == uint(32)) { + return real(64); + } - if typeKind1 == 3 || typeKind2 == 3 { - select max(bitSize1, bitSize2) { - when 64 { return real(64); } - when 32 { return real(32); } - } + if typeKind1 == 3 || typeKind2 == 3 { + select max(bitSize1, bitSize2) { + when 64 { return real(64); } + when 32 { return real(32); } } + } - if typeKind1 == 2 || typeKind2 == 2 { - select max(bitSize1, bitSize2) { - when 64 { return int(64); } - when 32 { return int(32); } - when 16 { return int(16); } - when 8 { return int(8); } - } + if typeKind1 == 2 || typeKind2 == 2 { + select max(bitSize1, bitSize2) { + when 64 { return int(64); } + when 32 { return int(32); } + when 16 { return int(16); } + when 8 { return int(8); } } + } - if typeKind1 == 1 || typeKind2 == 1 { - select max(bitSize1, bitSize2) { - when 64 { return uint(64); } - when 32 { return uint(32); } - when 16 { return uint(16); } - when 8 { return uint(8); } - } + if typeKind1 == 1 || typeKind2 == 1 { + select max(bitSize1, bitSize2) { + when 64 { return uint(64); } + when 32 { return uint(32); } + when 16 { return uint(16); } + when 8 { return uint(8); } } + } - return bool; + return bool; - } + } /* Helper function to ensure that floor division cases are handled in accordance with numpy @@ -225,7 +226,7 @@ module BinOp :returns: (MsgTuple) :throws: `UndefinedSymbolError(name)` */ - proc doBinOpvv(l, r, type lType, type rType, type etype, op: string, pn, st): MsgTuple throws { + proc _doBinOpVLeftVec(l, r, type lType, type rType, type etype, op: string, pn, st): MsgTuple throws { var e = makeDistArray((...l.tupShape), etype); const nie = notImplementedError(pn,l.dtype,op,r.dtype); @@ -358,57 +359,65 @@ module BinOp } - proc doBinOpvs(l, val, type lType, type rType, type etype, op: string, pn, st): MsgTuple throws { - var e = makeDistArray((...l.tupShape), etype); + // Public entrypoints (vv/vs/sv) delegate to vector-left kernels. + proc doBinOpvv(l, r, type lType, type rType, type etype, op: string, pn, st): MsgTuple throws { + return _doBinOpVLeftVec(l, r, lType, rType, etype, op, pn, st); + } + + proc _doBinOpVLeftScalar(vec, val, type lType, type rType, type etype, + op: string, pn, st, + param swapOperands: bool = false): MsgTuple throws { + var e = makeDistArray((...vec.tupShape), etype); - const nie = notImplementedError(pn,"%s %s %s".format(type2str(l.a.eltType),op,type2str(val.type))); + const nie = if !swapOperands + then notImplementedError(pn,"%s %s %s".format(type2str(vec.a.eltType), op, type2str(val.type))) + else notImplementedError(pn,"%s %s %s".format(type2str(val.type), op, type2str(vec.a.eltType))); type castType = mySafeCast(lType, rType); - // The compiler complains that maybe etype is bool if it gets down below this - // without returning, so we have to kind of chunk this next piece off. - - // For similar reasons, everything else is kinda split off into its own thing. - // The compiler has no common sense about things (and that's not really its fault) - if etype == bool { if boolOps.contains(op) { - select op { - - when "<" { e = (l.a: castType) < (val: castType); } - when "<=" { e = (l.a: castType) <= (val: castType); } - when ">" { e = (l.a: castType) > (val: castType); } - when ">=" { e = (l.a: castType) >= (val: castType); } - when "==" { e = (l.a: castType) == (val: castType); } - when "!=" { e = (l.a: castType) != (val: castType); } - otherwise do return MsgTuple.error(nie); // Shouldn't happen - + when "<" { e = if !swapOperands then (vec.a: castType) < (val: castType) + else (val: castType) < (vec.a: castType); } + when "<=" { e = if !swapOperands then (vec.a: castType) <= (val: castType) + else (val: castType) <= (vec.a: castType); } + when ">" { e = if !swapOperands then (vec.a: castType) > (val: castType) + else (val: castType) > (vec.a: castType); } + when ">=" { e = if !swapOperands then (vec.a: castType) >= (val: castType) + else (val: castType) >= (vec.a: castType); } + when "==" { e = if !swapOperands then (vec.a: castType) == (val: castType) + else (val: castType) == (vec.a: castType); } + when "!=" { e = if !swapOperands then (vec.a: castType) != (val: castType) + else (val: castType) != (vec.a: castType); } + otherwise do return MsgTuple.error(nie); } return st.insert(new shared SymEntry(e)); } if lType == bool && rType == bool { - if !doBoolBoolBitOp(op, e, l.a, val) { + // Bit-ops on bools are commutative for the supported operators. + if !doBoolBoolBitOp(op, e, vec.a, val) { return MsgTuple.error(nie); } return st.insert(new shared SymEntry(e)); } return MsgTuple.error(nie); - } - else if lType == bool && rType == bool && etype == uint(8) { // Both bools is kinda weird + else if lType == bool && rType == bool && etype == uint(8) { + // Both bools is kinda weird. select op { - when "%" { e = (0: uint(8)); } // numpy has these as int(8), but Arkouda doesn't really support that type. - when "//" { e = (l.a & val): uint(8); } - when "**" { e = (!l.a & val): uint(8); } - when "<<" { e = (l.a: uint(8)) << (val: uint(8)); } - when ">>" { e = (l.a: uint(8)) >> (val: uint(8)); } + when "%" { e = (0: uint(8)); } + when "//" { e = if !swapOperands then (vec.a & val): uint(8) else (val & vec.a): uint(8); } + when "**" { e = if !swapOperands then (!vec.a & val): uint(8) else (!val & vec.a): uint(8); } + when "<<" { e = if !swapOperands then (vec.a: uint(8)) << (val: uint(8)) + else (val: uint(8)) << (vec.a: uint(8)); } + when ">>" { e = if !swapOperands then (vec.a: uint(8)) >> (val: uint(8)) + else (val: uint(8)) >> (vec.a: uint(8)); } otherwise do return MsgTuple.error(nie); - // >>> and <<< could probably be implemented as int(8) or uint(8) things } return st.insert(new shared SymEntry(e)); } @@ -416,197 +425,115 @@ module BinOp else if etype == real(32) || etype == real(64) { select op { - when "*" { e = (l.a: etype * val: etype): etype; } - when "+" { e = (l.a: etype + val: etype): etype; } - when "-" { e = (l.a: etype - val: etype): etype; } - when "/" { e = (l.a: etype / val: etype): etype; } + when "*" { e = (vec.a: etype * val: etype): etype; } // commutative + when "+" { e = (vec.a: etype + val: etype): etype; } // commutative + when "-" { e = if !swapOperands then (vec.a: etype - val: etype): etype + else (val: etype - vec.a: etype): etype; } + when "/" { e = if !swapOperands then (vec.a: etype / val: etype): etype + else (val: etype / vec.a: etype): etype; } when "%" { ref ea = e; - ref la = l.a; - [(ei,li) in zip(ea,la)] ei = modHelper(li: etype, val: etype): etype; + ref va = vec.a; + if !swapOperands { + [(ei, li) in zip(ea, va)] ei = modHelper(li: etype, val: etype): etype; + } else { + [(ei, ri) in zip(ea, va)] ei = modHelper(val: etype, ri: etype): etype; + } } when "//" { ref ea = e; - ref la = l.a; - [(ei,li) in zip(ea,la)] ei = floorDivisionHelper(li: etype, val: etype): etype; + ref va = vec.a; + if !swapOperands { + [(ei, li) in zip(ea, va)] ei = floorDivisionHelper(li: etype, val: etype): etype; + } else { + [(ei, ri) in zip(ea, va)] ei = floorDivisionHelper(val: etype, ri: etype): etype; + } } when "**" { - e = ((l.a: etype) ** (val: etype)): etype; + e = if !swapOperands then ((vec.a: etype) ** (val: etype)): etype + else ((val: etype) ** (vec.a: etype)): etype; } otherwise do return MsgTuple.error(nie); } return st.insert(new shared SymEntry(e)); - } else { - select op { - when "|" { e = (l.a | val): etype; } - when "&" { e = (l.a & val): etype; } - when "*" { e = (l.a * val): etype; } - when "^" { e = (l.a ^ val): etype; } - when "+" { e = (l.a + val): etype; } - when "-" { e = (l.a - val): etype; } - when "/" { e = (l.a: etype) / (val: etype); } + when "|" { e = (vec.a | val): etype; } + when "&" { e = (vec.a & val): etype; } + when "*" { e = (vec.a * val): etype; } + when "^" { e = (vec.a ^ val): etype; } + when "+" { e = (vec.a + val): etype; } + when "-" { e = if !swapOperands then (vec.a - val): etype else (val - vec.a): etype; } + when "/" { e = if !swapOperands then (vec.a: etype) / (val: etype) + else (val: etype) / (vec.a: etype); } when "%" { ref ea = e; - ref la = l.a; - [(ei,li) in zip(ea,la)] ei = if val != 0 then li%val else 0; + ref va = vec.a; + if !swapOperands { + [(ei, li) in zip(ea, va)] ei = if val != 0 then li%val else 0; + } else { + [(ei, ri) in zip(ea, va)] ei = if ri != 0 then val%ri else 0; + } } when "//" { ref ea = e; - ref la = l.a; - [(ei,li) in zip(ea,la)] ei = if val != 0 then (li/val): etype else 0: etype; + ref va = vec.a; + if !swapOperands { + [(ei, li) in zip(ea, va)] ei = if val != 0 then (li/val): etype else 0: etype; + } else { + [(ei, ri) in zip(ea, va)] ei = if ri != 0 then (val/ri): etype else 0: etype; + } } when "**" { - if val < 0 - then return MsgTuple.error("Attempt to exponentiate base of type Int or UInt to negative exponent"); - e = (l.a: etype) ** (val: etype); + if !swapOperands { + if val < 0 then return MsgTuple.error("Attempt to exponentiate base of type Int or UInt to negative exponent"); + e = (vec.a: etype) ** (val: etype); + } else { + if || reduce (vec.a < 0) then return MsgTuple.error("Attempt to exponentiate base of type Int or UInt to negative exponent"); + e = (val: etype) ** (vec.a: etype); + } } when "<<" { ref ea = e; - ref la = l.a; - [(ei,li) in zip(ea,la)] if (0 <= val && val < numBits(etype)) then ei = ((li: etype) << (val: etype)): etype; + ref va = vec.a; + if !swapOperands { + [(ei, li) in zip(ea, va)] if (0 <= val && val < numBits(etype)) then ei = ((li: etype) << (val: etype)): etype; + } else { + [(ei, ri) in zip(ea, va)] if (0 <= ri && ri < numBits(etype)) then ei = ((val: etype) << (ri: etype)): etype; + } } when ">>" { ref ea = e; - ref la = l.a; - [(ei,li) in zip(ea,la)] if (0 <= val && val < numBits(etype)) then ei = ((li: etype) >> (val: etype)): etype; - } - when "<<<" { e = rotl(l.a: etype, val: etype); } - when ">>>" { e = rotr(l.a: etype, val: etype); } - otherwise do return MsgTuple.error(nie); - } - return st.insert(new shared SymEntry(e)); - } - - return MsgTuple.error(nie); - } - - proc doBinOpsv(val, r, type lType, type rType, type etype, op: string, pn, st) throws { - var e = makeDistArray((...r.tupShape), etype); - const nie = notImplementedError(pn,"%s %s %s".format(type2str(val.type),op,type2str(r.a.eltType))); - - type castType = mySafeCast(lType, rType); - - // The compiler complains that maybe etype is bool if it gets down below this - // without returning, so we have to kind of chunk this next piece off. - - // For similar reasons, everything else is kinda split off into its own thing. - // The compiler has no common sense about things (and that's not really its fault) - - if etype == bool { - - if boolOps.contains(op) { - - select op { - - when "<" { e = (val: castType) < (r.a: castType); } - when "<=" { e = (val: castType) <= (r.a: castType); } - when ">" { e = (val: castType) > (r.a: castType); } - when ">=" { e = (val: castType) >= (r.a: castType); } - when "==" { e = (val: castType) == (r.a: castType); } - when "!=" { e = (val: castType) != (r.a: castType); } - otherwise do return MsgTuple.error(nie); // Shouldn't happen - - } - return st.insert(new shared SymEntry(e)); - } - - if lType == bool && rType == bool { - if !doBoolBoolBitOp(op, e, r.a, val) { - return MsgTuple.error(nie); - } - return st.insert(new shared SymEntry(e)); - } - - return MsgTuple.error(nie); - - } - - else if lType == bool && rType == bool && etype == uint(8) { // Both bools is kinda weird - select op { - when "%" { e = (0: uint(8)); } // numpy has these as int(8), but Arkouda doesn't really support that type. - when "//" { e = (val & r.a): uint(8); } - when "**" { e = (!val & r.a): uint(8); } - when "<<" { e = (val: uint(8)) << (r.a: uint(8)); } - when ">>" { e = (val: uint(8)) >> (r.a: uint(8)); } - otherwise do return MsgTuple.error(nie); - // >>> and <<< could probably be implemented as int(8) or uint(8) things - } - return st.insert(new shared SymEntry(e)); - } - - else if etype == real(32) || etype == real(64) { - - select op { - when "*" { e = (val: etype * r.a: etype): etype; } - when "+" { e = (val: etype + r.a: etype): etype; } - when "-" { e = (val: etype - r.a: etype): etype; } - when "/" { e = (val: etype / r.a: etype): etype; } - when "%" { - ref ea = e; - ref ra = r.a; - [(ei,ri) in zip(ea,ra)] ei = modHelper(val: etype, ri: etype): etype; + ref va = vec.a; + if !swapOperands { + [(ei, li) in zip(ea, va)] if (0 <= val && val < numBits(etype)) then ei = ((li: etype) >> (val: etype)): etype; + } else { + [(ei, ri) in zip(ea, va)] if (0 <= ri && ri < numBits(etype)) then ei = ((val: etype) >> (ri: etype)): etype; + } } - when "//" { - ref ea = e; - ref ra = r.a; - [(ei,ri) in zip(ea,ra)] ei = floorDivisionHelper(val: etype, ri: etype): etype; + when "<<<" { + e = if !swapOperands then rotl(vec.a: etype, val: etype) + else rotl(val: etype, vec.a: etype); } - when "**" { - e = ((val: etype) ** (r.a: etype)): etype; + when ">>>" { + e = if !swapOperands then rotr(vec.a: etype, val: etype) + else rotr(val: etype, vec.a: etype); } otherwise do return MsgTuple.error(nie); } return st.insert(new shared SymEntry(e)); - } + } - else { - - select op { - when "|" { e = (val | r.a): etype; } - when "&" { e = (val & r.a): etype; } - when "*" { e = (val * r.a): etype; } - when "^" { e = (val ^ r.a): etype; } - when "+" { e = (val + r.a): etype; } - when "-" { e = (val - r.a): etype; } - when "/" { e = (val: etype) / (r.a: etype); } - when "%" { - ref ea = e; - ref ra = r.a; - [(ei,ri) in zip(ea,ra)] ei = if ri != 0 then val%ri else 0; - } - when "//" { - ref ea = e; - ref ra = r.a; - [(ei,ri) in zip(ea,ra)] ei = if ri != 0 then (val/ri): etype else 0: etype; - } - when "**" { - if || reduce (r.a<0) - then return MsgTuple.error("Attempt to exponentiate base of type Int or UInt to negative exponent"); - e = (val: etype) ** (r.a: etype); - } - when "<<" { - ref ea = e; - ref ra = r.a; - [(ei,ri) in zip(ea,ra)] if (0 <= ri && ri < numBits(etype)) then ei = ((val: etype) << (ri: etype)): etype; - } - when ">>" { - ref ea = e; - ref ra = r.a; - [(ei,ri) in zip(ea,ra)] if (0 <= ri && ri < numBits(etype)) then ei = ((val: etype) >> (ri: etype)): etype; - } - when "<<<" { e = rotl(val: etype, r.a: etype); } - when ">>>" { e = rotr(val: etype, r.a: etype); } - otherwise do return MsgTuple.error(nie); - } - return st.insert(new shared SymEntry(e)); - } + proc doBinOpvs(l, val, type lType, type rType, type etype, op: string, pn, st): MsgTuple throws { + return _doBinOpVLeftScalar(l, val, lType, rType, etype, op, pn, st, false); + } - return MsgTuple.error(nie); + proc doBinOpsv(val, r, type lType, type rType, type etype, op: string, pn, st): MsgTuple throws { + // Guarantee a vector on the left; scalar-left is handled by swapping operands in the scalar kernel. + return _doBinOpVLeftScalar(r, val, lType, rType, etype, op, pn, st, true); } proc doBigIntBinOpvv(l, r, op: string) throws { @@ -925,524 +852,443 @@ module BinOp } } - proc doBigIntBinOpvs(l, val, op: string) throws { - var max_bits = l.max_bits; + + + proc _doBigIntBinOpVLeftScalar(vec, val, op: string, param swapOperands: bool = false) throws { + // swapOperands == false: vec OP val + // swapOperands == true: val OP vec (but we still "own" vec's shape) + + // Match doBigIntBinOpvv's assumptions: max_bits is driven by the array entries. + var max_bits = vec.max_bits; var max_size = 1:bigint; var has_max_bits = max_bits != -1; if has_max_bits { max_size <<= max_bits; max_size -= 1; } - ref la = l.a; - var tmp = if l.etype == bigint then la else la:bigint; - // these cases are not mutually exclusive, - // so we have a flag to track if tmp is ever populated - var visted = false; + + ref veca = vec.a; + var tmp = if vec.etype == bigint then veca else veca:bigint; + + // scalar value as bigint (for bitwise / add / mul, etc.) + const ri: bigint = val:bigint; // had to create bigint specific BinOp procs which return // the distributed array because we need it at SymEntry creation time - if l.etype == bigint && val.type == bigint { - // first we try the ops that only work with - // both being bigint + if !swapOperands && vec.etype == bigint && val.type != bigint && smallOps.contains(op) { + // ops that only work with a left hand side of bigint and right hand side non-bigint + // Just bitshifts and exponentiation without local_max_size select op { - when "&" { - forall t in tmp with (var local_val = val, var local_max_size = max_size) { - t &= local_val; - if has_max_bits { - t &= local_max_size; - } - } - visted = true; + when "<<" { + forall t in tmp do + t = if has_max_bits && val >= max_bits then 0: bigint else t << val; } - when "|" { - forall t in tmp with (var local_val = val, var local_max_size = max_size) { - t |= local_val; - if has_max_bits { - t &= local_max_size; - } - } - visted = true; + when ">>" { + forall t in tmp do + t = if has_max_bits && val >= max_bits then 0: bigint else t >> val; } - when "^" { - forall t in tmp with (var local_val = val, var local_max_size = max_size) { - t ^= local_val; - if has_max_bits { - t &= local_max_size; + when "**" { + if val < 0 { // This could actually lead into real number territory + throw new Error("Attempt to exponentiate base of type BigInt to negative exponent"); + } + if has_max_bits { + forall t in tmp with (var local_max_size = max_size) { + powMod(t, t, val, local_max_size + 1); } } - visted = true; + else { + forall t in tmp do t = t ** val; + } } - when "/" { - forall t in tmp with (var local_val = val, var local_max_size = max_size) { - t /= local_val; - if has_max_bits { - t &= local_max_size; + } + } else if swapOperands && val.type == bigint && vec.etype != bigint && smallOps.contains(op) { + // ops that only work with a left hand side of bigint and right hand side non-bigint + // Just bitshifts and exponentiation without local_max_size + select op { + when "<<" { + forall (t, v) in zip(tmp, veca) do + t = if has_max_bits && v >= max_bits then 0: bigint else val << v; + } + when ">>" { + forall (t, v) in zip(tmp, veca) do + t = if has_max_bits && v >= max_bits then 0: bigint else val >> v; + } + when "**" { + if || reduce (tmp<0) { // This could actually lead into real number territory + throw new Error("Attempt to exponentiate base of type BigInt to negative exponent"); + } + if has_max_bits { + forall (t, v) in zip(tmp, veca) with (var local_max_size = max_size) { + powMod(t, val, v, local_max_size + 1); } } - visted = true; + else { + forall (t, v) in zip(tmp, veca) do t = val ** v; + } } } - } - if l.etype == bigint && (val.type == bigint || val.type == int || val.type == uint) { - // then we try the ops that only work with a - // left hand side of bigint - if val.type != bigint { - // can't shift a bigint by a bigint - select op { - when "<<" { - if has_max_bits && val >= max_bits { - forall t in tmp with (var local_zero = 0:bigint) { - t = local_zero; + } else { + select op { + when "&" { forall t in tmp do t &= ri; } + when "|" { forall t in tmp do t |= ri; } + when "^" { forall t in tmp do t ^= ri; } + when "+" { forall t in tmp do t += ri; } + when "*" { forall t in tmp do t *= ri; } + when "-" { + if !swapOperands { + forall t in tmp do t -= ri; + } else { + forall t in tmp do t = ri - t; + } + } + when "//" { + if !swapOperands { + forall t in tmp { + if ri != 0 { + var q: bigint; + // floor-style integer division, like Python's // + div(q, t, ri, roundingMode.down); + t = q; + } else { + // whatever semantics you want for division by zero: + t = 0: bigint; } } - else { - forall t in tmp with (var local_val = val, var local_max_size = max_size) { - t <<= local_val; - if has_max_bits { - t &= local_max_size; - } + } else { + forall t in tmp { + const denom: bigint = t; // <- cast bool/int/uint/etc to bigint + if denom != 0 { + var q: bigint; + // floor-style integer division, like Python's // + div(q, ri, denom, roundingMode.down); + t = q; + } else { + // whatever semantics you want for division by zero: + t = 0: bigint; } } - visted = true; } - when ">>" { - if has_max_bits && val >= max_bits { - forall t in tmp with (var local_zero = 0:bigint) { - t = local_zero; + } + when "%" { + if !swapOperands { + forall t in tmp { + if ri != 0 { + mod(t, t, ri); + } else { + t = 0: bigint; } } - else { - forall t in tmp with (var local_max_size = max_size) { - t >>= val; - if has_max_bits { - t &= local_max_size; - } + } else { + forall t in tmp { + if t != 0 { + mod(t, ri, t); + } else { + t = 0: bigint; } } - visted = true; } - when "<<<" { - if !has_max_bits { - throw new Error("Must set max_bits to rotl"); + } + when "**" { + if !swapOperands { + if ri < 0 { + throw new Error("Attempt to exponentiate base of type BigInt to negative exponent"); } - var botBits = la; - var modded_shift = if val.type == int then val % max_bits else val % max_bits:uint; - var shift_amt = if val.type == int then max_bits - modded_shift else max_bits:uint - modded_shift; - forall (t, bot_bits) in zip(tmp, botBits) with (var local_val = modded_shift, var local_shift_amt = shift_amt, var local_max_size = max_size) { - t <<= local_val; - bot_bits >>= local_shift_amt; - t += bot_bits; - t &= local_max_size; + if has_max_bits { + forall t in tmp with (var local_max_size = max_size) { + powMod(t, t, ri, local_max_size + 1); + } } - visted = true; - } - when ">>>" { - if !has_max_bits { - throw new Error("Must set max_bits to rotr"); + else { + throw new Error("Attempt to exponentiate base of type BigInt to BigInt without max_bits"); } - var topBits = la; - var modded_shift = if val.type == int then val % max_bits else val % max_bits:uint; - var shift_amt = if val.type == int then max_bits - modded_shift else max_bits:uint - modded_shift; - forall (t, tB) in zip(tmp, topBits) with (var local_val = modded_shift, var local_shift_amt = shift_amt, var local_max_size = max_size) { - t >>= local_val; - tB <<= local_shift_amt; - t += tB; - t &= local_max_size; + } else { + if || reduce (tmp<0) { + throw new Error("Attempt to exponentiate base of type BigInt to negative exponent"); } - visted = true; - } - } - } - select op { - when "//" { // floordiv - forall t in tmp with (var local_val = val, var local_max_size = max_size) { - if local_val != 0 { - t /= local_val; + if has_max_bits { + forall t in tmp with (var local_max_size = max_size) { + powMod(t, ri, t, local_max_size + 1); + } } else { - t = 0:bigint; - } - if has_max_bits { - t &= local_max_size; + throw new Error("Attempt to exponentiate base of type BigInt to BigInt without max_bits"); } } - visted = true; } - when "%" { // modulo " <- quote is workaround for syntax highlighter bug - // we only do in place mod when val != 0, tmp will be 0 in other locations - // we can't use ei = li % val because this can result in negatives - forall t in tmp with (var local_val = val, var local_max_size = max_size) { - if local_val != 0 { - mod(t, t, local_val); - } - else { - t = 0:bigint; + when "<<<" { + if !swapOperands { + if !has_max_bits { + throw new Error("Must set max_bits to rotl"); } - if has_max_bits { - t &= local_max_size; + var modded_shift = 0: bigint; + forall (t, bot_bits) in zip(tmp, veca) with (var loc_modded_shift = modded_shift) { + mod(loc_modded_shift, ri, max_bits); + t <<= loc_modded_shift: int; + const shift_amt = max_bits - loc_modded_shift: int; + t += bot_bits >> shift_amt; } - } - visted = true; - } - when "**" { - if val<0 { - throw new Error("Attempt to exponentiate base of type BigInt to negative exponent"); - } - if has_max_bits { - forall t in tmp with (var local_val = val, var local_max_size = max_size) { - powMod(t, t, local_val, local_max_size + 1); + } else { + if !has_max_bits { + throw new Error("Must set max_bits to rotl"); } - } - else { - forall t in tmp with (var local_val = val) { - t **= local_val:uint; + var botBits = val: bigint; + var modded_shift = 0: bigint; + forall (t, v) in zip(tmp, veca) with (var loc_modded_shift = modded_shift) { + mod(loc_modded_shift, v: bigint, max_bits); + t <<= loc_modded_shift: int; + var shift_amt = max_bits - loc_modded_shift: int; + t += botBits >> shift_amt; } } - visted = true; } - } - } - if (l.etype == bigint && val.type == bigint) || - (l.etype == bigint && (val.type == int || val.type == uint || val.type == bool)) || - (val.type == bigint && (l.etype == int || l.etype == uint || l.etype == bool)) { - select op { - when "+" { - forall t in tmp with (var local_val = val, var local_max_size = max_size) { - t += local_val; - if has_max_bits { - t &= local_max_size; + when ">>>" { + if !swapOperands { + if !has_max_bits { + throw new Error("Must set max_bits to rotr"); } - } - visted = true; - } - when "-" { - forall t in tmp with (var local_val = val, var local_max_size = max_size) { - t -= local_val; - if has_max_bits { - t &= local_max_size; + var modded_shift = 0: bigint; + forall (t, top_bits) in zip(tmp, veca) with (var loc_modded_shift = modded_shift) { + mod(loc_modded_shift, ri, max_bits); + t >>= loc_modded_shift: int; + const shift_amt = max_bits - loc_modded_shift: int; + t += top_bits << shift_amt; } - } - visted = true; - } - when "*" { - forall t in tmp with (var local_val = val, var local_max_size = max_size) { - t *= local_val; - if has_max_bits { - t &= local_max_size; + } else { + if !has_max_bits { + throw new Error("Must set max_bits to rotr"); + } + var topBits = val: bigint; + var modded_shift = 0: bigint; + forall (t, v) in zip(tmp, veca) with (var loc_modded_shift = modded_shift) { + mod(loc_modded_shift, v: bigint, max_bits); + t >>= loc_modded_shift: int; + var shift_amt = max_bits - loc_modded_shift: int; + t += topBits << shift_amt; } } - visted = true; } + otherwise do throw new Error("Unsupported operation: " + (if !swapOperands then (vec.etype:string +" "+ op +" "+ val.type:string) + else (val.type:string +" "+ op +" "+ vec.etype:string))); } } - if !visted { - throw new Error("Unsupported operation: " + l.etype:string +" "+ op +" "+ val.type:string); - } + + if has_max_bits then forall t in tmp with (const local_max_size = max_size) do t &= local_max_size; + return (tmp, max_bits); } + proc _doBigIntBinOpVLeftScalarBoolReturn(vec, val, op: string, param swapOperands: bool = false) throws { + ref va = vec.a; + var tmp = makeDistArray((...va.shape), bool); + select op { + when "<" { forall (t, xi) in zip(tmp, va) with (var local_val = val) do t = if !swapOperands then (xi < local_val) else (local_val < xi); } + when ">" { forall (t, xi) in zip(tmp, va) with (var local_val = val) do t = if !swapOperands then (xi > local_val) else (local_val > xi); } + when "<=" { forall (t, xi) in zip(tmp, va) with (var local_val = val) do t = if !swapOperands then (xi <= local_val) else (local_val <= xi); } + when ">=" { forall (t, xi) in zip(tmp, va) with (var local_val = val) do t = if !swapOperands then (xi >= local_val) else (local_val >= xi); } + when "==" { forall (t, xi) in zip(tmp, va) with (var local_val = val) do t = (xi == local_val); } + when "!=" { forall (t, xi) in zip(tmp, va) with (var local_val = val) do t = (xi != local_val); } + otherwise do throw new Error("Unsupported operation: " + (if !swapOperands then (vec.etype:string +" "+ op +" "+ val.type:string) + else (val.type:string +" "+ op +" "+ vec.etype:string))); + } + return tmp; + } + + proc doBigIntBinOpvs(l, val, op: string) throws { + return _doBigIntBinOpVLeftScalar(l, val, op, false); + } + + proc doBigIntBinOpsv(val, r, op: string) throws { + return _doBigIntBinOpVLeftScalar(r, val, op, true); + } + proc doBigIntBinOpvsBoolReturn(l, val, op: string) throws { - ref la = l.a; - var tmp = makeDistArray((...la.shape), bool); + return _doBigIntBinOpVLeftScalarBoolReturn(l, val, op, false); + } + + proc doBigIntBinOpsvBoolReturn(val, r, op: string) throws { + return _doBigIntBinOpVLeftScalarBoolReturn(r, val, op, true); + } + + proc doBigIntBinOpvsBoolReturnRealInput(const ref la: [?d] ?t1, ri: ?t2, op: string) throws + where ( (t1 == bigint && t2 == real(64)) || + (t1 == real(64) && t2 == bigint) ) + { + + var e = makeDistArray(d, bool); + ref ea = e; select op { - when "<" { - forall (t, li) in zip(tmp, la) with (var local_val = val) { - t = (li < local_val); + when "<" { + if t1 == bigint { + forall (ei, li) in zip(ea, la) do ei = ltBigReal(li, ri); + } else { // t2 == bigint + forall (ei, li) in zip(ea, la) do ei = gtBigReal(ri, li); // li ri>li } } - when ">" { - forall (t, li) in zip(tmp, la) with (var local_val = val) { - t = (li > local_val); + when ">" { + if t1 == bigint { + forall (ei, li) in zip(ea, la) do ei = gtBigReal(li, ri); + } else { + forall (ei, li) in zip(ea, la) do ei = ltBigReal(ri, li); } } when "<=" { - forall (t, li) in zip(tmp, la) with (var local_val = val) { - t = (li <= local_val); + if t1 == bigint { + forall (ei, li) in zip(ea, la) do ei = leBigReal(li, ri); + } else { + forall (ei, li) in zip(ea, la) do ei = geBigReal(ri, li); } } when ">=" { - forall (t, li) in zip(tmp, la) with (var local_val = val) { - t = (li >= local_val); + if t1 == bigint { + forall (ei, li) in zip(ea, la) do ei = geBigReal(li, ri); + } else { + forall (ei, li) in zip(ea, la) do ei = leBigReal(ri, li); } } when "==" { - forall (t, li) in zip(tmp, la) with (var local_val = val) { - t = (li == local_val); + if t1 == bigint { + forall (ei, li) in zip(ea, la) do ei = eqBigReal(li, ri); + } else { + forall (ei, li) in zip(ea, la) do ei = eqBigReal(ri, li); } } when "!=" { - forall (t, li) in zip(tmp, la) with (var local_val = val) { - t = (li != local_val); + if t1 == bigint { + forall (ei, li) in zip(ea, la) do ei = neBigReal(li, ri); + } else { + forall (ei, li) in zip(ea, la) do ei = neBigReal(ri, li); } } - otherwise { - // we should never reach this since we only enter this proc - // if boolOps.contains(op) - throw new Error("Unsupported operation: " +" "+ l.etype:string + op +" "+ val.type:string); - } + otherwise do + throw new Error("Unsupported operation: " + t1:string + " " + op + " " + t2:string); } - return tmp; + + return e; } - proc doBigIntBinOpsv(val, r, op: string) throws { - var max_bits = r.max_bits; - var max_size = 1:bigint; - var has_max_bits = max_bits != -1; - if has_max_bits { - max_size <<= max_bits; - max_size -= 1; - } - ref ra = r.a; - var tmp = makeDistArray((...ra.shape), bigint); - tmp = val:bigint; - // these cases are not mutually exclusive, - // so we have a flag to track if tmp is ever populated - var visted = false; + proc doBigIntBinOpvsBoolReturnRealInput(const ref la: [?d] ?t1, val: ?t2, op: string) throws + where ( (t1 != bigint || t2 != real(64)) && + (t1 != real(64) || t2 != bigint) ) + { + throw new Error("Unsupported operation: " + t1:string +" "+ op +" "+ t2:string); + } - // had to create bigint specific BinOp procs which return - // the distributed array because we need it at SymEntry creation time - if val.type == bigint && r.etype == bigint { - // first we try the ops that only work with - // both being bigint - select op { - when "&" { - forall (t, ri) in zip(tmp, ra) with (var local_max_size = max_size) { - t &= ri; - if has_max_bits { - t &= local_max_size; - } - } - visted = true; - } - when "|" { - forall (t, ri) in zip(tmp, ra) with (var local_max_size = max_size) { - t |= ri; - if has_max_bits { - t &= local_max_size; - } - } - visted = true; - } - when "^" { - forall (t, ri) in zip(tmp, ra) with (var local_max_size = max_size) { - t ^= ri; - if has_max_bits { - t &= local_max_size; - } - } - visted = true; - } - when "/" { - forall (t, ri) in zip(tmp, ra) with (var local_max_size = max_size) { - t /= ri; - if has_max_bits { - t &= local_max_size; - } - } - visted = true; - } - } - } - if val.type == bigint && (r.etype == bigint || r.etype == int || r.etype == uint) { - // then we try the ops that only work with a - // left hand side of bigint - if r.etype != bigint { - // can't shift a bigint by a bigint - select op { - when "<<" { - forall (t, ri) in zip(tmp, ra) with (var local_max_size = max_size) { - if has_max_bits { - if ri >= max_bits { - t = 0; - } - else { - t <<= ri; - t &= local_max_size; - } - } - else { - t <<= ri; - } - } - visted = true; - } - when ">>" { - forall (t, ri) in zip(tmp, ra) with (var local_max_size = max_size) { - if has_max_bits { - if ri >= max_bits { - t = 0; - } - else { - t >>= ri; - t &= local_max_size; - } - } - else { - t >>= ri; - } - } - visted = true; - } - when "<<<" { - if !has_max_bits { - throw new Error("Must set max_bits to rotl"); - } - var botBits = makeDistArray((...ra.shape), bigint); - botBits = val; - forall (t, ri, bot_bits) in zip(tmp, ra, botBits) with (var local_max_size = max_size) { - var modded_shift = if r.etype == int then ri % max_bits else ri % max_bits:uint; - t <<= modded_shift; - var shift_amt = if r.etype == int then max_bits - modded_shift else max_bits:uint - modded_shift; - bot_bits >>= shift_amt; - t += bot_bits; - t &= local_max_size; - } - visted = true; - } - when ">>>" { - if !has_max_bits { - throw new Error("Must set max_bits to rotr"); - } - var topBits = makeDistArray((...ra.shape), bigint); - topBits = val; - forall (t, ri, tB) in zip(tmp, ra, topBits) with (var local_max_size = max_size) { - var modded_shift = if r.etype == int then ri % max_bits else ri % max_bits:uint; - t >>= modded_shift; - var shift_amt = if r.etype == int then max_bits - modded_shift else max_bits:uint - modded_shift; - tB <<= shift_amt; - t += tB; - t &= local_max_size; - } - visted = true; - } - } + proc doBigIntBinOpvsRealReturn(l, val, op: string) throws { + select op { + when "+" { return l.a: real + val: real; } + when "-" { return l.a: real - val: real; } + when "*" { return l.a: real * val: real; } + when "/" { return l.a: real / val: real; } + when "**" { return l.a: real ** val: real; } + when "%" { + var e = makeDistArray((...l.tupShape), real); + ref ea = e; + ref la = l.a; + [(ei,li) in zip(ea,la)] ei = modHelper(li: real, val: real): real; + return e; } - select op { - when "//" { // floordiv - forall (t, ri) in zip(tmp, ra) with (var local_max_size = max_size) { - if ri != 0 { - t /= ri; - } - else { - t = 0:bigint; - } - if has_max_bits { - t &= local_max_size; - } - } - visted = true; - } - when "%" { // modulo " <- quote is workaround for syntax highlighter bug - forall (t, ri) in zip(tmp, ra) with (var local_max_size = max_size) { - if ri != 0 { - mod(t, t, ri); - } - else { - t = 0:bigint; - } - if has_max_bits { - t &= local_max_size; - } - } - visted = true; - } - when "**" { - if || reduce (ra<0) { - throw new Error("Attempt to exponentiate base of type BigInt to negative exponent"); - } - if has_max_bits { - forall (t, ri) in zip(tmp, ra) with (var local_max_size = max_size) { - powMod(t, t, ri, local_max_size + 1); - } - } - else { - forall (t, ri) in zip(tmp, ra) { - t **= ri:uint; - } - } - visted = true; - } + when "//" { + var e = makeDistArray((...l.tupShape), real); + ref ea = e; + ref la = l.a; + [(ei,li) in zip(ea,la)] ei = floorDivisionHelper(li: real, val: real): real; + return e; } - } - if (val.type == bigint && r.etype == bigint) || - (val.type == bigint && (r.etype == int || r.etype == uint || r.etype == bool)) || - (r.etype == bigint && (val.type == int || val.type == uint || val.type == bool)) { - select op { - when "+" { - forall (t, ri) in zip(tmp, ra) with (var local_max_size = max_size) { - t += ri; - if has_max_bits { - t &= local_max_size; - } - } - visted = true; - } - when "-" { - forall (t, ri) in zip(tmp, ra) with (var local_max_size = max_size) { - t -= ri; - if has_max_bits { - t &= local_max_size; - } - } - visted = true; - } - when "*" { - forall (t, ri) in zip(tmp, ra) with (var local_max_size = max_size) { - t *= ri; - if has_max_bits { - t &= local_max_size; - } - } - visted = true; - } + otherwise { + throw new Error("Unsupported operation: " + l.etype:string +" "+ op +" "+ val.type:string); } } - if !visted { - throw new Error("Unsupported operation: " + val.type:string +" "+ op +" "+ r.etype:string); - } - return (tmp, max_bits); } - proc doBigIntBinOpsvBoolReturn(val, r, op: string) throws { - ref ra = r.a; - var tmp = makeDistArray((...ra.shape), bool); + proc doBigIntBinOpsvBoolReturnRealInput(li: ?t1, const ref ra: [?d] ?t2, op: string) throws + where ( (t1 == bigint && t2 == real(64)) || + (t1 == real(64) && t2 == bigint) ) + { + + var e = makeDistArray(d, bool); + ref ea = e; select op { - when "<" { - forall (t, ri) in zip(tmp, ra) with (var local_val = val) { - t = (local_val < ri); + when "<" { + if t1 == bigint { + forall (ei, ri) in zip(ea, ra) do ei = ltBigReal(li, ri); + } else { // t2 == bigint + forall (ei, ri) in zip(ea, ra) do ei = gtBigReal(ri, li); // li ri>li } } - when ">" { - forall (t, ri) in zip(tmp, ra) with (var local_val = val) { - t = (local_val > ri); + when ">" { + if t1 == bigint { + forall (ei, ri) in zip(ea, ra) do ei = gtBigReal(li, ri); + } else { + forall (ei, ri) in zip(ea, ra) do ei = ltBigReal(ri, li); } } when "<=" { - forall (t, ri) in zip(tmp, ra) with (var local_val = val) { - t = (local_val <= ri); + if t1 == bigint { + forall (ei, ri) in zip(ea, ra) do ei = leBigReal(li, ri); + } else { + forall (ei, ri) in zip(ea, ra) do ei = geBigReal(ri, li); } } when ">=" { - forall (t, ri) in zip(tmp, ra) with (var local_val = val) { - t = (local_val >= ri); + if t1 == bigint { + forall (ei, ri) in zip(ea, ra) do ei = geBigReal(li, ri); + } else { + forall (ei, ri) in zip(ea, ra) do ei = leBigReal(ri, li); } } when "==" { - forall (t, ri) in zip(tmp, ra) with (var local_val = val) { - t = (local_val == ri); + if t1 == bigint { + forall (ei, ri) in zip(ea, ra) do ei = eqBigReal(li, ri); + } else { + forall (ei, ri) in zip(ea, ra) do ei = eqBigReal(ri, li); } } when "!=" { - forall (t, ri) in zip(tmp, ra) with (var local_val = val) { - t = (local_val != ri); + if t1 == bigint { + forall (ei, ri) in zip(ea, ra) do ei = neBigReal(li, ri); + } else { + forall (ei, ri) in zip(ea, ra) do ei = neBigReal(ri, li); } } + otherwise do + throw new Error("Unsupported operation: " + t1:string + " " + op + " " + t2:string); + } + + return e; + } + + proc doBigIntBinOpsvBoolReturnRealInput(li: ?t1, const ref ra: [?d] ?t2, op: string) throws + where ( (t1 != bigint || t2 != real(64)) && + (t1 != real(64) || t2 != bigint) ) + { + throw new Error("Unsupported operation: " + t1:string +" "+ op +" "+ t2:string); + } + + proc doBigIntBinOpsvRealReturn(l, r, op: string) throws { + select op { + when "+" { return l: real + r.a: real; } + when "-" { return l: real - r.a: real; } + when "*" { return l: real * r.a: real; } + when "/" { return l: real / r.a: real; } + when "**" { return l: real ** r.a: real; } + when "%" { + var e = makeDistArray((...r.tupShape), real); + ref ea = e; + ref ra = r.a; + [(ei,ri) in zip(ea,ra)] ei = modHelper(l: real, ri: real): real; + return e; + } + when "//" { + var e = makeDistArray((...r.tupShape), real); + ref ea = e; + ref ra = r.a; + [(ei,ri) in zip(ea,ra)] ei = floorDivisionHelper(l: real, ri: real): real; + return e; + } otherwise { - // we should never reach this since we only enter this proc - // if boolOps.contains(op) - throw new Error("Unsupported operation: " + val.type:string +" "+ op +" "+ r.etype:string); + throw new Error("Unsupported operation: " + l.type:string +" "+ op +" "+ r.etype:string); } } - return tmp; } -} + +} \ No newline at end of file diff --git a/src/OperatorMsg.chpl b/src/OperatorMsg.chpl index c43cbc95b38..1ca6467c460 100644 --- a/src/OperatorMsg.chpl +++ b/src/OperatorMsg.chpl @@ -148,24 +148,26 @@ module OperatorMsg "cmd: %? op: %? left pdarray: %? scalar: %?".format( cmd,op,st.attrib(msgArgs['a'].val), val)); - // This probably doesn't handle all normal bigint cases, but it handles a decent number. - // This, at least, can be expanded when BinOp.chpl is cleaned up - // It will be reasonably straightforward to clean up here. + // At this point it should handle almost every bigint case if (binop_dtype_a == bigint || binop_dtype_b == bigint) && - !isRealType(binop_dtype_a) && !isRealType(binop_dtype_b) + !isRealType(binop_dtype_a) && !isRealType(binop_dtype_b) && + op != '/' { if boolOps.contains(op) { - // call bigint specific func which returns distr bool array return st.insert(new shared SymEntry(doBigIntBinOpvsBoolReturn(l, val, op))); } // call bigint specific func which returns dist bigint array const (tmp, max_bits) = doBigIntBinOpvs(l, val, op); return st.insert(new shared SymEntry(tmp, max_bits)); - } else if (binop_dtype_a == bigint || binop_dtype_b == bigint) { - const errorMsg = unrecognizedTypeError(pn, "("+type2str(binop_dtype_a)+","+type2str(binop_dtype_b)+")"); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + } else if (binop_dtype_a == bigint || binop_dtype_b == bigint) && + (isRealType(binop_dtype_a) || isRealType(binop_dtype_b)) && + boolOps.contains(op) { + // call bigint specific func which returns distr bool array + return st.insert(new shared SymEntry(doBigIntBinOpvsBoolReturnRealInput(l.a, val, op))); + } + else if (binop_dtype_a == bigint || binop_dtype_b == bigint) { + return st.insert(new shared SymEntry(doBigIntBinOpvsRealReturn(l, val, op))); } if boolOps.contains(op) { @@ -205,6 +207,7 @@ module OperatorMsg } return doBinOpvs(l, val, binop_dtype_a, binop_dtype_b, returnType, op, pn, st); + } /* @@ -236,24 +239,26 @@ module OperatorMsg "cmd: %? op = %? scalar dtype = %? scalar = %? pdarray = %?".format( cmd,op,type2str(binop_dtype_b),msgArgs['value'].val,st.attrib(msgArgs['a'].val))); - // This probably doesn't handle all normal bigint cases, but it handles a decent number. - // This, at least, can be expanded when BinOp.chpl is cleaned up - // It will be reasonably straightforward to clean up here. + // At this point it should handle almost every bigint case if (binop_dtype_a == bigint || binop_dtype_b == bigint) && - !isRealType(binop_dtype_a) && !isRealType(binop_dtype_b) + !isRealType(binop_dtype_a) && !isRealType(binop_dtype_b) && + op != '/' { if boolOps.contains(op) { - // call bigint specific func which returns distr bool array return st.insert(new shared SymEntry(doBigIntBinOpsvBoolReturn(val, r, op))); } // call bigint specific func which returns dist bigint array const (tmp, max_bits) = doBigIntBinOpsv(val, r, op); return st.insert(new shared SymEntry(tmp, max_bits)); - } else if (binop_dtype_a == bigint || binop_dtype_b == bigint) { - const errorMsg = unrecognizedTypeError(pn, "("+type2str(binop_dtype_a)+","+type2str(binop_dtype_b)+")"); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + } else if (binop_dtype_a == bigint || binop_dtype_b == bigint) && + (isRealType(binop_dtype_a) || isRealType(binop_dtype_b)) && + boolOps.contains(op) { + // call bigint specific func which returns distr bool array + return st.insert(new shared SymEntry(doBigIntBinOpsvBoolReturnRealInput(val, r.a, op))); + } + else if (binop_dtype_a == bigint || binop_dtype_b == bigint) { + return st.insert(new shared SymEntry(doBigIntBinOpsvRealReturn(val, r, op))); } if boolOps.contains(op) { @@ -293,8 +298,82 @@ module OperatorMsg } return doBinOpsv(val, r, binop_dtype_a, binop_dtype_b, returnType, op, pn, st); + + } + + // --- NumPy-ish casting helpers for in-place ufunc semantics (casting='same_kind') --- + // These are because casting is a little different for things like += (vs. +) + + proc isArithInplaceOp(op: string): bool { + return op == "+=" || op == "-=" || op == "*=" || + op == "//=" || op == "%=" || op == "**="; } + proc isBitInplaceOp(op: string): bool { + return op == "&=" || op == "|=" || op == "^=" || + op == "<<=" || op == ">>="; + } + + proc numpyLikeOpeqGate(type lhsT, type rhsT, op: string): int { + const isArith = isArithInplaceOp(op); + const isBit = isBitInplaceOp(op); + if !isArith && !isBit then return 3; + + param kL = splitType(lhsT); + param kR = splitType(rhsT); + + // Float LHS: arithmetic ok, bitwise/shifts TypeError; float op= bigint => UFuncTypeError + if lhsT == real(64) { + if kR == 4 then return 2; + return if isBit then 1 else 0; + } + + // Bool LHS special-case + if lhsT == bool { + if rhsT != bool then return 2; + // bool op= bool: + if op == "-=" then return 1; + if op == "+=" || op == "*=" || op == "&=" || op == "|=" || op == "^=" then return 0; + return 2; // //= %= **= <<= >>= are UFuncTypeError in your matrix + } + + // Non-bigint LHS with bigint RHS => UFuncTypeError (cannot cast object/bigint result back) + if kL != 4 && kR == 4 { + return 2; + } + + // Bigint LHS rules + if kL == 4 { + // allow **= with int/uint/bool/bigint on rhs. That's integer-like. + // For float rhs: cannot keep bigint dtype => TypeError + if kR == 3 then return 1; + return 0; + } + + // Now we are in {int64,uint64,uint8} LHS cases. + // Mixed int64/uint64: + if lhsT == int(64) && rhsT == uint(64) { + return if isBit then 1 else 2; + } + if lhsT == uint(64) && rhsT == int(64) { + return if isBit then 1 else 2; + } + + // uint8 op= int64 => all UFuncTypeError + if lhsT == uint(8) && rhsT == int(64) { + return 2; + } + + // integer LHS with float RHS => arithmetic UFuncTypeError, bit TypeError + if kR == 3 { + return if isBit then 1 else 2; + } + + // Otherwise allow (covers i64 op= i64/u8/b; u64 op= u64/u8/b; u8 op= u8/u64/b) + return 0; + } + + /* Parse and respond to opeqvv message. vector op= vector @@ -325,6 +404,163 @@ module OperatorMsg "cmd: %s op: %s left pdarray: %s right pdarray: %s".format(cmd,op, st.attrib(msgArgs['a'].val),st.attrib(msgArgs['b'].val))); + const gate = numpyLikeOpeqGate(binop_dtype_a, binop_dtype_b, op); + param kL = splitType(binop_dtype_a); +param kR = splitType(binop_dtype_b); +ref la = l.a; + select gate { + when 0 { + +// ---- bool/bool special-case (NumPy quirks) ---- + +if kL == 0 && kR == 0 { + select op { + when "+=" { l.a = l.a | r.a; return MsgTuple.success(); } + when "*=" { l.a = l.a & r.a; return MsgTuple.success(); } + when "&=" { l.a &= r.a; return MsgTuple.success(); } + when "|=" { l.a |= r.a; return MsgTuple.success(); } + when "^=" { l.a ^= r.a; return MsgTuple.success(); } + when "-=" { return new MsgTuple("TypeError", MsgType.ERROR); } + otherwise { return new MsgTuple("TypeError", MsgType.ERROR); } + } +} + +// If we are instantiated with bool LHS, the only supported RHS is bool, +// and that case returned above. Prevent the compiler from typechecking +// the generic arithmetic/bitwise code for bool LHS instantiations. +if kL == 0 { + // matches numpyLikeOpeqGate for bool with non-bool RHS + return new MsgTuple("UFuncTypeError", MsgType.ERROR); +} + + // ---- general path (gate==0 and not bool/bool) ---- + // If instantiated with bigint LHS and real RHS, we don't support it (and we must +// prevent the compiler from typechecking casts from real->bigint). +if kL == 4 && kR == 3 { + return new MsgTuple("TypeError", MsgType.ERROR); +} +const ra = r.a; +var handled = true; + +// splitType: 0 bool, 1 uint, 2 int, 3 real, 4 bigint + +if isArithInplaceOp(op) { + + if op == "+=" { + la += (ra: binop_dtype_a); + + } else if op == "-=" { + la -= (ra: binop_dtype_a); + + } else if op == "*=" { + la *= (ra: binop_dtype_a); + + } else if op == "**=" { + if binop_dtype_a == int(64) { + if || reduce ((r.a: int(64)) < 0) { + return new MsgTuple( + "Attempt to exponentiate base of type Int64 to negative exponent", + MsgType.ERROR + ); + } + } + // la **= (ra: binop_dtype_a); + + if kL == 4 && l.max_bits != -1 { + const max_size = (1: bigint << l.max_bits); + forall (t, ri) in zip(la, ra) with (var local_max_size = max_size) { + powMod(t, t, ri, max_size); + } + } + else { + try { + forall (t, ri) in zip(la, ra) do t **= ri:binop_dtype_a; + } catch { + return new MsgTuple ( + "Exponentiation too large; use smaller values or set max_bits", + MsgType.ERROR + ); + } + } + + } else if op == "//=" { + + if kL == 3 && binop_dtype_a == real(64) { + // NumPy-like float floor-division + const rb = (r.a: real(64)); + [(li, ri) in zip(la, rb)] li = floorDivisionHelper(li, ri); + + } else { + // integer/bool/uint/bigint style, preserve your div-by-zero->0 behavior + ref la2 = l.a; + const rb = (r.a: binop_dtype_a); + [(li, ri) in zip(la2, rb)] li = if ri != 0 then li/ri else (0: binop_dtype_a); + } + + } else if op == "%=" { + + if kL == 3 && binop_dtype_a == real(64) { + // NumPy-like float modulo + const rb = (r.a: real(64)); + [(li, ri) in zip(la, rb)] li = modHelper(li, ri); + + } else { + // integer/bool/uint/bigint modulo, preserve div-by-zero->0 behavior if you want + ref la2 = l.a; + const rb = (r.a: binop_dtype_a); + [(li, ri) in zip(la2, rb)] li = if ri != 0 then li%ri else (0: binop_dtype_a); + } + + } else { + handled = false; + } + +} else if isBitInplaceOp(op) { + + // Bitwise + shifts must NOT compile for real LHS + if (kL == 0 || kL == 1 || kL == 2) { + select op { + when ">>=" { la >>= (ra: binop_dtype_a); } + when "<<=" { la <<= (ra: binop_dtype_a); } + when "&=" { la &= (ra: binop_dtype_a); } + when "|=" { la |= (ra: binop_dtype_a); } + when "^=" { la ^= (ra: binop_dtype_a); } + otherwise { handled = false; } + } + } else if (kL == 0 || kL == 1 || kL == 2 || kL == 4) { +select op { + when "&=" { la &= (ra: binop_dtype_a); } + when "|=" { la |= (ra: binop_dtype_a); } + when "^=" { la ^= (ra: binop_dtype_a); } + otherwise { handled = false; } + } + } else { + // real LHS should not reach here; gate should have blocked it + return new MsgTuple("TypeError", MsgType.ERROR); + } + +} else { + handled = false; +} + +if !handled then return MsgTuple.error(nie); + if kL == 4 { + const mask = (1: bigint << l.max_bits) - 1; + la &= mask; + } +return MsgTuple.success(); + + + } + when 1 { return new MsgTuple("TypeError", MsgType.ERROR); } + when 2 { return new MsgTuple("TypeError", MsgType.ERROR); } // Technically numpy views these + // as two different kinds of + // TypeError + otherwise { return MsgTuple.error(nie); } + } + + return MsgTuple.success(); + if binop_dtype_a == int && binop_dtype_b == int { select op { when "+=" { l.a += r.a; } From 06f4c482d344fa24d6ce9adb1a3c56ebff37457a Mon Sep 17 00:00:00 2001 From: Ryan Keck Date: Mon, 5 Jan 2026 14:46:29 -0500 Subject: [PATCH 2/4] Fixes alignment and adds /= --- src/OperatorMsg.chpl | 734 ++++++++++--------------------------------- 1 file changed, 160 insertions(+), 574 deletions(-) diff --git a/src/OperatorMsg.chpl b/src/OperatorMsg.chpl index 1ca6467c460..46e4f07ef3e 100644 --- a/src/OperatorMsg.chpl +++ b/src/OperatorMsg.chpl @@ -305,7 +305,7 @@ module OperatorMsg // These are because casting is a little different for things like += (vs. +) proc isArithInplaceOp(op: string): bool { - return op == "+=" || op == "-=" || op == "*=" || + return op == "+=" || op == "-=" || op == "*=" || op == "/=" || op == "//=" || op == "%=" || op == "**="; } @@ -322,6 +322,21 @@ module OperatorMsg param kL = splitType(lhsT); param kR = splitType(rhsT); + // --- True division (/=) special rule (NumPy in-place semantics) --- + // Only float LHS supports /= in-place in an array. + // (Integers/bool fail because result is float; bigint-as-bigint can't hold float.) + if op == "/=" { + if lhsT == real(64) { + // float LHS: allow unless RHS is bigint (you already treat that as UFuncTypeError) + if kR == 4 then return 2; + // RHS bool/int/uint/real are fine + return 0; + } else { + // int/uint/bool/bigint LHS: in-place true divide rejects (NumPy UFuncTypeError) + return 2; + } + } + // Float LHS: arithmetic ok, bitwise/shifts TypeError; float op= bigint => UFuncTypeError if lhsT == real(64) { if kR == 4 then return 2; @@ -393,609 +408,180 @@ module OperatorMsg type binop_dtype_b, param array_nd: int ): MsgTuple throws { - param pn = Reflection.getRoutineName(); + param pn = Reflection.getRoutineName(); - var l = st[msgArgs['a']]: borrowed SymEntry(binop_dtype_a, array_nd); - const r = st[msgArgs['b']]: borrowed SymEntry(binop_dtype_b, array_nd), - op = msgArgs['op'].toScalar(string), - nie = notImplementedError(pn,type2str(binop_dtype_a),op,type2str(binop_dtype_b)); + var l = st[msgArgs['a']]: borrowed SymEntry(binop_dtype_a, array_nd); + const r = st[msgArgs['b']]: borrowed SymEntry(binop_dtype_b, array_nd), + op = msgArgs['op'].toScalar(string), + nie = notImplementedError(pn,type2str(binop_dtype_a),op,type2str(binop_dtype_b)); - omLogger.debug(getModuleName(),getRoutineName(),getLineNumber(), - "cmd: %s op: %s left pdarray: %s right pdarray: %s".format(cmd,op, - st.attrib(msgArgs['a'].val),st.attrib(msgArgs['b'].val))); - - const gate = numpyLikeOpeqGate(binop_dtype_a, binop_dtype_b, op); - param kL = splitType(binop_dtype_a); -param kR = splitType(binop_dtype_b); -ref la = l.a; - select gate { - when 0 { - -// ---- bool/bool special-case (NumPy quirks) ---- - -if kL == 0 && kR == 0 { - select op { - when "+=" { l.a = l.a | r.a; return MsgTuple.success(); } - when "*=" { l.a = l.a & r.a; return MsgTuple.success(); } - when "&=" { l.a &= r.a; return MsgTuple.success(); } - when "|=" { l.a |= r.a; return MsgTuple.success(); } - when "^=" { l.a ^= r.a; return MsgTuple.success(); } - when "-=" { return new MsgTuple("TypeError", MsgType.ERROR); } - otherwise { return new MsgTuple("TypeError", MsgType.ERROR); } - } -} + omLogger.debug(getModuleName(),getRoutineName(),getLineNumber(), + "cmd: %s op: %s left pdarray: %s right pdarray: %s".format(cmd,op, + st.attrib(msgArgs['a'].val),st.attrib(msgArgs['b'].val))); -// If we are instantiated with bool LHS, the only supported RHS is bool, -// and that case returned above. Prevent the compiler from typechecking -// the generic arithmetic/bitwise code for bool LHS instantiations. -if kL == 0 { - // matches numpyLikeOpeqGate for bool with non-bool RHS - return new MsgTuple("UFuncTypeError", MsgType.ERROR); -} + const gate = numpyLikeOpeqGate(binop_dtype_a, binop_dtype_b, op); + param kL = splitType(binop_dtype_a); + param kR = splitType(binop_dtype_b); + ref la = l.a; + select gate { + when 0 { - // ---- general path (gate==0 and not bool/bool) ---- - // If instantiated with bigint LHS and real RHS, we don't support it (and we must -// prevent the compiler from typechecking casts from real->bigint). -if kL == 4 && kR == 3 { - return new MsgTuple("TypeError", MsgType.ERROR); -} -const ra = r.a; -var handled = true; + // ---- bool/bool special-case (NumPy quirks) ---- -// splitType: 0 bool, 1 uint, 2 int, 3 real, 4 bigint - -if isArithInplaceOp(op) { - - if op == "+=" { - la += (ra: binop_dtype_a); - - } else if op == "-=" { - la -= (ra: binop_dtype_a); - - } else if op == "*=" { - la *= (ra: binop_dtype_a); - - } else if op == "**=" { - if binop_dtype_a == int(64) { - if || reduce ((r.a: int(64)) < 0) { - return new MsgTuple( - "Attempt to exponentiate base of type Int64 to negative exponent", - MsgType.ERROR - ); - } - } - // la **= (ra: binop_dtype_a); - - if kL == 4 && l.max_bits != -1 { - const max_size = (1: bigint << l.max_bits); - forall (t, ri) in zip(la, ra) with (var local_max_size = max_size) { - powMod(t, t, ri, max_size); + if kL == 0 && kR == 0 { + select op { + when "+=" { l.a = l.a | r.a; return MsgTuple.success(); } + when "*=" { l.a = l.a & r.a; return MsgTuple.success(); } + when "&=" { l.a &= r.a; return MsgTuple.success(); } + when "|=" { l.a |= r.a; return MsgTuple.success(); } + when "^=" { l.a ^= r.a; return MsgTuple.success(); } + when "-=" { return new MsgTuple("TypeError", MsgType.ERROR); } + otherwise { return new MsgTuple("TypeError", MsgType.ERROR); } } } - else { - try { - forall (t, ri) in zip(la, ra) do t **= ri:binop_dtype_a; - } catch { - return new MsgTuple ( - "Exponentiation too large; use smaller values or set max_bits", - MsgType.ERROR - ); - } + + // If we are instantiated with bool LHS, the only supported RHS is bool, + // and that case returned above. Prevent the compiler from typechecking + // the generic arithmetic/bitwise code for bool LHS instantiations. + if kL == 0 { + // matches numpyLikeOpeqGate for bool with non-bool RHS + return new MsgTuple("TypeError", MsgType.ERROR); } - } else if op == "//=" { + // ---- general path (gate==0 and not bool/bool) ---- + // If instantiated with bigint LHS and real RHS, we don't support it (and we must + // prevent the compiler from typechecking casts from real->bigint). + if kL == 4 && kR == 3 { + return new MsgTuple("TypeError", MsgType.ERROR); + } + const ra = r.a; + var handled = true; - if kL == 3 && binop_dtype_a == real(64) { - // NumPy-like float floor-division - const rb = (r.a: real(64)); - [(li, ri) in zip(la, rb)] li = floorDivisionHelper(li, ri); + // splitType: 0 bool, 1 uint, 2 int, 3 real, 4 bigint - } else { - // integer/bool/uint/bigint style, preserve your div-by-zero->0 behavior - ref la2 = l.a; - const rb = (r.a: binop_dtype_a); - [(li, ri) in zip(la2, rb)] li = if ri != 0 then li/ri else (0: binop_dtype_a); - } + if isArithInplaceOp(op) { - } else if op == "%=" { + if op == "+=" { + la += (ra: binop_dtype_a); - if kL == 3 && binop_dtype_a == real(64) { - // NumPy-like float modulo - const rb = (r.a: real(64)); - [(li, ri) in zip(la, rb)] li = modHelper(li, ri); + } else if op == "-=" { + la -= (ra: binop_dtype_a); - } else { - // integer/bool/uint/bigint modulo, preserve div-by-zero->0 behavior if you want - ref la2 = l.a; - const rb = (r.a: binop_dtype_a); - [(li, ri) in zip(la2, rb)] li = if ri != 0 then li%ri else (0: binop_dtype_a); - } + } else if op == "*=" { + la *= (ra: binop_dtype_a); - } else { - handled = false; - } - -} else if isBitInplaceOp(op) { - - // Bitwise + shifts must NOT compile for real LHS - if (kL == 0 || kL == 1 || kL == 2) { - select op { - when ">>=" { la >>= (ra: binop_dtype_a); } - when "<<=" { la <<= (ra: binop_dtype_a); } - when "&=" { la &= (ra: binop_dtype_a); } - when "|=" { la |= (ra: binop_dtype_a); } - when "^=" { la ^= (ra: binop_dtype_a); } - otherwise { handled = false; } - } - } else if (kL == 0 || kL == 1 || kL == 2 || kL == 4) { -select op { - when "&=" { la &= (ra: binop_dtype_a); } - when "|=" { la |= (ra: binop_dtype_a); } - when "^=" { la ^= (ra: binop_dtype_a); } - otherwise { handled = false; } - } - } else { - // real LHS should not reach here; gate should have blocked it - return new MsgTuple("TypeError", MsgType.ERROR); - } + } else if op == "**=" { + if binop_dtype_a == int(64) { + if || reduce ((r.a: int(64)) < 0) { + return new MsgTuple( + "Attempt to exponentiate base of type Int64 to negative exponent", + MsgType.ERROR + ); + } + } + // la **= (ra: binop_dtype_a); + + if kL == 4 && l.max_bits != -1 { + const max_size = (1: bigint << l.max_bits); + forall (t, ri) in zip(la, ra) with (var local_max_size = max_size) { + powMod(t, t, ri, max_size); + } + } else { + try { + forall (t, ri) in zip(la, ra) do t **= ri:binop_dtype_a; + } catch { + return new MsgTuple ( + "Exponentiation too large; use smaller values or set max_bits", + MsgType.ERROR + ); + } + } -} else { - handled = false; -} + } else if op == "//=" { -if !handled then return MsgTuple.error(nie); - if kL == 4 { - const mask = (1: bigint << l.max_bits) - 1; - la &= mask; - } -return MsgTuple.success(); + if kL == 3 && binop_dtype_a == real(64) { + // NumPy-like float floor-division + const rb = (r.a: real(64)); + [(li, ri) in zip(la, rb)] li = floorDivisionHelper(li, ri); + } else { + // integer/bool/uint/bigint style, preserve div-by-zero->0 behavior + ref la2 = l.a; + const rb = (r.a: binop_dtype_a); + [(li, ri) in zip(la2, rb)] li = if ri != 0 then li/ri else (0: binop_dtype_a); + } - } - when 1 { return new MsgTuple("TypeError", MsgType.ERROR); } - when 2 { return new MsgTuple("TypeError", MsgType.ERROR); } // Technically numpy views these - // as two different kinds of - // TypeError - otherwise { return MsgTuple.error(nie); } - } + } else if op == "%=" { - return MsgTuple.success(); + if kL == 3 && binop_dtype_a == real(64) { + // NumPy-like float modulo + const rb = (r.a: real(64)); + [(li, ri) in zip(la, rb)] li = modHelper(li, ri); - if binop_dtype_a == int && binop_dtype_b == int { - select op { - when "+=" { l.a += r.a; } - when "-=" { l.a -= r.a; } - when "*=" { l.a *= r.a; } - when ">>=" { l.a >>= r.a;} - when "<<=" { l.a <<= r.a;} - when "//=" { - //l.a /= r.a; - ref la = l.a; - ref ra = r.a; - [(li,ri) in zip(la,ra)] li = if ri != 0 then li/ri else 0; - }//floordiv - when "%=" { - //l.a /= r.a; - ref la = l.a; - ref ra = r.a; - [(li,ri) in zip(la,ra)] li = if ri != 0 then li%ri else 0; - } - when "**=" { - if || reduce (r.a<0){ - var errorMsg = "Attempt to exponentiate base of type Int64 to negative exponent"; - return new MsgTuple(errorMsg, MsgType.ERROR); - } - else{ l.a **= r.a; } - } - otherwise do return MsgTuple.error(nie); - } - } - else if binop_dtype_a == int && binop_dtype_b == bool { - select op { - when "+=" {l.a += r.a:int;} - when "-=" {l.a -= r.a:int;} - when "*=" {l.a *= r.a:int;} - when ">>=" { l.a >>= r.a:int;} - when "<<=" { l.a <<= r.a:int;} - otherwise do return MsgTuple.error(nie); - } - } - else if binop_dtype_a == uint && binop_dtype_b == uint { - select op { - when "+=" { l.a += r.a; } - when "-=" { - l.a -= r.a; - } - when "*=" { l.a *= r.a; } - when "//=" { - //l.a /= r.a; - ref la = l.a; - ref ra = r.a; - [(li,ri) in zip(la,ra)] li = if ri != 0 then li/ri else 0; - }//floordiv - when "%=" { - //l.a /= r.a; - ref la = l.a; - ref ra = r.a; - [(li,ri) in zip(la,ra)] li = if ri != 0 then li%ri else 0; - } - when "**=" { - l.a **= r.a; - } - when ">>=" { l.a >>= r.a;} - when "<<=" { l.a <<= r.a;} - otherwise do return MsgTuple.error(nie); - } - } - else if binop_dtype_a == uint && binop_dtype_b == bool { - select op { - when "+=" {l.a += r.a:uint;} - when "-=" {l.a -= r.a:uint;} - when "*=" {l.a *= r.a:uint;} - when ">>=" { l.a >>= r.a:uint;} - when "<<=" { l.a <<= r.a:uint;} - otherwise do return MsgTuple.error(nie); - } - } - else if binop_dtype_a == real && binop_dtype_b == int { - select op { - when "+=" {l.a += r.a;} - when "-=" {l.a -= r.a;} - when "*=" {l.a *= r.a;} - when "/=" {l.a /= r.a:real;} //truediv - when "//=" { //floordiv - ref la = l.a; - ref ra = r.a; - [(li,ri) in zip(la,ra)] li = floorDivisionHelper(li, ri); - } - when "**=" { l.a **= r.a; } - when "%=" { - ref la = l.a; - ref ra = r.a; - [(li,ri) in zip(la,ra)] li = modHelper(li, ri); - } - otherwise do return MsgTuple.error(nie); - } - } - else if binop_dtype_a == real && binop_dtype_b == uint { - select op { - when "+=" {l.a += r.a;} - when "-=" {l.a -= r.a;} - when "*=" {l.a *= r.a;} - when "/=" {l.a /= r.a:real;} //truediv - when "//=" { //floordiv - ref la = l.a; - ref ra = r.a; - [(li,ri) in zip(la,ra)] li = floorDivisionHelper(li, ri); - } - when "**=" { l.a **= r.a; } - when "%=" { - ref la = l.a; - ref ra = r.a; - [(li,ri) in zip(la,ra)] li = modHelper(li, ri); - } - otherwise do return MsgTuple.error(nie); - } - } - else if binop_dtype_a == real && binop_dtype_b == real { - select op { - when "+=" {l.a += r.a;} - when "-=" {l.a -= r.a;} - when "*=" {l.a *= r.a;} - when "/=" {l.a /= r.a;}//truediv - when "//=" { //floordiv - ref la = l.a; - ref ra = r.a; - [(li,ri) in zip(la,ra)] li = floorDivisionHelper(li, ri); - } - when "**=" { l.a **= r.a; } - when "%=" { - ref la = l.a; - ref ra = r.a; - [(li,ri) in zip(la,ra)] li = modHelper(li, ri); - } - otherwise do return MsgTuple.error(nie); - } - } - else if binop_dtype_a == real && binop_dtype_b == bool { - select op { - when "+=" {l.a += r.a:real;} - when "-=" {l.a -= r.a:real;} - when "*=" {l.a *= r.a:real;} - otherwise do return MsgTuple.error(nie); - } - } - else if binop_dtype_a == bool && binop_dtype_b == bool { - select op { - when "|=" {l.a |= r.a;} - when "&=" {l.a &= r.a;} - when "^=" {l.a ^= r.a;} - when "+=" {l.a |= r.a;} - otherwise do return MsgTuple.error(nie); - } - } - else if binop_dtype_a == bigint && binop_dtype_b == int { - ref la = l.a; - ref ra = r.a; - var max_bits = l.max_bits; - var max_size = 1:bigint; - var has_max_bits = max_bits != -1; - if has_max_bits { - max_size <<= max_bits; - max_size -= 1; - } - select op { - when "+=" { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - li += ri; - if has_max_bits { - li &= local_max_size; - } - } + } else { + // integer/bool/uint/bigint modulo, preserve div-by-zero->0 behavior + ref la2 = l.a; + const rb = (r.a: binop_dtype_a); + [(li, ri) in zip(la2, rb)] li = if ri != 0 then li%ri else (0: binop_dtype_a); } - when "-=" { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - li -= ri; - if has_max_bits { - li &= local_max_size; - } - } - } - when "*=" { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - li *= ri; - if has_max_bits { - li &= local_max_size; - } - } - } - when "//=" { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - if ri != 0 { - li /= ri; - } - else { - li = 0:bigint; - } - if has_max_bits { - li &= local_max_size; - } - } - } - when "%=" { - // we can't use li %= ri because this can result in negatives - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - if ri != 0 { - mod(li, li, ri); - } - else { - li = 0:bigint; - } - if has_max_bits { - li &= local_max_size; - } - } - } - when "**=" { - if || reduce (ra<0) { - throw new Error("Attempt to exponentiate base of type BigInt to negative exponent"); - } - if has_max_bits { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - powMod(li, li, ri, local_max_size + 1); - } - } - else { - forall (li, ri) in zip(la, ra) { - li **= ri:uint; - } - } - } - otherwise do return MsgTuple.error(nie); - } - } - else if binop_dtype_a == bigint && binop_dtype_b == uint { - ref la = l.a; - ref ra = r.a; - var max_bits = l.max_bits; - var max_size = 1:bigint; - var has_max_bits = max_bits != -1; - if has_max_bits { - max_size <<= max_bits; - max_size -= 1; - } - select op { - when "+=" { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - li += ri; - if has_max_bits { - li &= local_max_size; - } - } - } - when "-=" { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - li -= ri; - if has_max_bits { - li &= local_max_size; - } - } - } - when "*=" { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - li *= ri; - if has_max_bits { - li &= local_max_size; - } - } - } - when "//=" { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - if ri != 0 { - li /= ri; - } - else { - li = 0:bigint; - } - if has_max_bits { - li &= local_max_size; - } - } - } - when "%=" { - // we can't use li %= ri because this can result in negatives - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - if ri != 0 { - mod(li, li, ri); - } - else { - li = 0:bigint; - } - if has_max_bits { - li &= local_max_size; - } - } - } - when "**=" { - if || reduce (ra<0) { - throw new Error("Attempt to exponentiate base of type BigInt to negative exponent"); - } - if has_max_bits { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - powMod(li, li, ri, local_max_size + 1); - } - } - else { - forall (li, ri) in zip(la, ra) { - li **= ri:uint; - } - } - } - otherwise do return MsgTuple.error(nie); - } - } - else if binop_dtype_a == bigint && binop_dtype_b == bool { - ref la = l.a; - var ra = r.a:bigint; - var max_bits = l.max_bits; - var max_size = 1:bigint; - var has_max_bits = max_bits != -1; - if has_max_bits { - max_size <<= max_bits; - max_size -= 1; - } - select op { - when "+=" { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - li += ri; - if has_max_bits { - li &= local_max_size; - } - } - } - when "-=" { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - li -= ri; - if has_max_bits { - li &= local_max_size; - } - } - } - when "*=" { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - li *= ri; - if has_max_bits { - li &= local_max_size; - } - } + + } else if op == "/=" { + if kL == 3 && binop_dtype_a == real(64) { + // Only float LHS should reach here due to the gate. + // NumPy behavior for float division by zero is inf/nan (with warnings). + const rb = (r.a: real(64)); + [(li, ri) in zip(la, rb)] li = li / ri; } - otherwise do return MsgTuple.error(nie); + } else { + handled = false; } - } - else if binop_dtype_a == bigint && binop_dtype_b == bigint { - ref la = l.a; - ref ra = r.a; - var max_bits = l.max_bits; - var max_size = 1:bigint; - var has_max_bits = max_bits != -1; - if has_max_bits { - max_size <<= max_bits; - max_size -= 1; - } - select op { - when "+=" { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - li += ri; - if has_max_bits { - li &= local_max_size; - } - } - } - when "-=" { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - li -= ri; - if has_max_bits { - li &= local_max_size; - } - } - } - when "*=" { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - li *= ri; - if has_max_bits { - li &= local_max_size; - } - } - } - when "//=" { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - if ri != 0 { - li /= ri; - } - else { - li = 0:bigint; - } - if has_max_bits { - li &= local_max_size; - } - } - } - when "%=" { - // we can't use li %= ri because this can result in negatives - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - if ri != 0 { - mod(li, li, ri); - } - else { - li = 0:bigint; - } - if has_max_bits { - li &= local_max_size; - } - } + + } else if isBitInplaceOp(op) { + + // Bitwise + shifts must NOT compile for real LHS + if (kL == 0 || kL == 1 || kL == 2) { + select op { + when ">>=" { la >>= (ra: binop_dtype_a); } + when "<<=" { la <<= (ra: binop_dtype_a); } + when "&=" { la &= (ra: binop_dtype_a); } + when "|=" { la |= (ra: binop_dtype_a); } + when "^=" { la ^= (ra: binop_dtype_a); } + otherwise { handled = false; } } - when "**=" { - if || reduce (ra<0) { - throw new Error("Attempt to exponentiate base of type BigInt to negative exponent"); - } - if has_max_bits { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - powMod(li, li, ri, local_max_size + 1); - } - } - else { - forall (li, ri) in zip(la, ra) { - li **= ri:uint; - } - } + } else if (kL == 0 || kL == 1 || kL == 2 || kL == 4) { + select op { + when "&=" { la &= (ra: binop_dtype_a); } + when "|=" { la |= (ra: binop_dtype_a); } + when "^=" { la ^= (ra: binop_dtype_a); } + otherwise { handled = false; } } - otherwise do return MsgTuple.error(nie); + } else { + // real LHS should not reach here; gate should have blocked it + return new MsgTuple("TypeError", MsgType.ERROR); } + } else { - return MsgTuple.error(nie); + handled = false; } - return MsgTuple.success(); + if !handled then return MsgTuple.error(nie); + if kL == 4 { + const mask = (1: bigint << l.max_bits) - 1; + la &= mask; + } + return MsgTuple.success(); + + + } + when 1 { return new MsgTuple("TypeError", MsgType.ERROR); } + when 2 { return new MsgTuple("TypeError", MsgType.ERROR); } // Technically numpy views these + // as two different kinds of + // TypeError + otherwise { return MsgTuple.error(nie); } + } + + return MsgTuple.success(); + } /* From cb0222af06ed888b90b2dcdc4532bb2ab27bc51e Mon Sep 17 00:00:00 2001 From: Ryan Keck Date: Wed, 7 Jan 2026 12:45:13 -0500 Subject: [PATCH 3/4] Made some progress --- src/BinOp.chpl | 2 +- src/OperatorMsg.chpl | 322 +++++++++++++++++++++++-------------------- 2 files changed, 177 insertions(+), 147 deletions(-) diff --git a/src/BinOp.chpl b/src/BinOp.chpl index 65c1c42612a..132fd6bfa9c 100644 --- a/src/BinOp.chpl +++ b/src/BinOp.chpl @@ -1291,4 +1291,4 @@ module BinOp } } -} \ No newline at end of file +} diff --git a/src/OperatorMsg.chpl b/src/OperatorMsg.chpl index 46e4f07ef3e..51d5b276c02 100644 --- a/src/OperatorMsg.chpl +++ b/src/OperatorMsg.chpl @@ -602,175 +602,205 @@ module OperatorMsg */ @arkouda.instantiateAndRegister proc opeqvs(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, - type binop_dtype_a, - param array_nd: int + type binop_dtype_a, + param array_nd: int ): MsgTuple throws { - param pn = Reflection.getRoutineName(); + param pn = Reflection.getRoutineName(); - // b is always the same type as a - type binop_dtype_b = binop_dtype_a; + // RHS is always the same type as LHS (typed scalar semantics) + type binop_dtype_b = binop_dtype_a; - var l = st[msgArgs['a']]: borrowed SymEntry(binop_dtype_a, array_nd); - const val = msgArgs['value'].toScalar(binop_dtype_b), - op = msgArgs['op'].toScalar(string), - nie = notImplementedError(pn,type2str(binop_dtype_a),op,type2str(binop_dtype_b)); + var l = st[msgArgs["a"]]: borrowed SymEntry(binop_dtype_a, array_nd); + const val = msgArgs["value"].toScalar(binop_dtype_b), + op = msgArgs["op"].toScalar(string), + nie = notImplementedError(pn, type2str(binop_dtype_a), op, type2str(binop_dtype_b)); - omLogger.debug(getModuleName(),getRoutineName(),getLineNumber(), - "op: %? pdarray: %? scalar: %?".format(op,st.attrib(msgArgs['a'].val),val)); + omLogger.debug(getModuleName(), getRoutineName(), getLineNumber(), + "op: %? pdarray: %? scalar: %?".format(op, st.attrib(msgArgs["a"].val), val)); - if binop_dtype_a == int && binop_dtype_b == int { - select op { - when "+=" { l.a += val; } - when "-=" { l.a -= val; } - when "*=" { l.a *= val; } - when ">>=" { l.a >>= val; } - when "<<=" { l.a <<= val; } - when "//=" { - if val != 0 {l.a /= val;} else {l.a = 0;} - }//floordiv - when "%=" { - if val != 0 {l.a %= val;} else {l.a = 0;} - } - when "**=" { - if val<0 { - var errorMsg = "Attempt to exponentiate base of type int64 to negative exponent"; - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(), - errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } - else{ l.a **= val; } + // Keep the same gate already built. Here lhsT == rhsT, so it will mostly: + // - block float bitwise/shift + // - enforce bool quirks (-=, etc.) + // - block int/uint /= (true-div) etc. + const gate = numpyLikeOpeqGate(binop_dtype_a, binop_dtype_b, op); + param kL = splitType(binop_dtype_a); // 0 bool, 1 uint, 2 int, 3 real, 4 bigint - } - otherwise do return MsgTuple.error(nie); - } - } - else if binop_dtype_a == uint && binop_dtype_b == uint { - select op { - when "+=" { l.a += val; } - when "-=" { - l.a -= val; - } - when "*=" { l.a *= val; } - when "//=" { - if val != 0 {l.a /= val;} else {l.a = 0;} - }//floordiv - when "%=" { - if val != 0 {l.a %= val;} else {l.a = 0;} - } - when "**=" { - l.a **= val; - } - when ">>=" { l.a >>= val; } - when "<<=" { l.a <<= val; } - otherwise do return MsgTuple.error(nie); - } - } - else if binop_dtype_a == bool && binop_dtype_b == bool { - select op { - when "+=" {l.a |= val;} - otherwise do return MsgTuple.error(nie); - } - } - else if binop_dtype_a == real && binop_dtype_b == real { + ref la = l.a; + + select gate { + when 0 { + + // ----------------------------- + // bool/bool special-case (NumPy quirks) + // ----------------------------- + if kL == 0 { + // Only bool scalar is possible due to binop_dtype_b == binop_dtype_a, + // but keep this explicit and mirror opeqvv. select op { - when "+=" {l.a += val;} - when "-=" {l.a -= val;} - when "*=" {l.a *= val;} - when "/=" {l.a /= val;}//truediv - when "//=" { //floordiv - ref la = l.a; - [li in la] li = floorDivisionHelper(li, val); - } - when "**=" { l.a **= val; } - when "%=" { - ref la = l.a; - [li in la] li = modHelper(li, val); - } - otherwise do return MsgTuple.error(nie); - } - } - else if binop_dtype_a == bigint && binop_dtype_b == bigint { - ref la = l.a; - var max_bits = l.max_bits; - var max_size = 1:bigint; - var has_max_bits = max_bits != -1; - if has_max_bits { - max_size <<= max_bits; - max_size -= 1; + when "+=" { la |= val; return MsgTuple.success(); } // bool += bool -> OR + when "*=" { la &= val; return MsgTuple.success(); } // bool *= bool -> AND + when "&=" { la &= val; return MsgTuple.success(); } + when "|=" { la |= val; return MsgTuple.success(); } + when "^=" { la ^= val; return MsgTuple.success(); } + when "-=" { return new MsgTuple("TypeError", MsgType.ERROR); } + otherwise { return new MsgTuple("TypeError", MsgType.ERROR); } } - select op { - when "+=" { - forall li in la with (var local_val = val, var local_max_size = max_size) { - li += local_val; - if has_max_bits { - li &= local_max_size; - } - } + } + + var handled = true; + + // ----------------------------- + // arithmetic inplace ops + // ----------------------------- + if isArithInplaceOp(op) { + + if op == "+=" { + la += val; + + } else if op == "-=" { + la -= val; + + } else if op == "*=" { + la *= val; + + } else if op == "/=" { + // With rhs==lhs, gate should only allow real(64) here. + if kL == 3 && binop_dtype_a == real(64) { + // NumPy-like: inf/nan on div-by-zero (warnings at Python layer, if any) + la /= val; + } else { + return new MsgTuple("TypeError", MsgType.ERROR); } - when "-=" { - forall li in la with (var local_val = val, var local_max_size = max_size) { - li -= local_val; - if has_max_bits { - li &= local_max_size; - } + + } else if op == "**=" { + + // int64 negative exponent check + if binop_dtype_a == int(64) { + if val: int(64) < 0 { + return new MsgTuple( + "Attempt to exponentiate base of type Int64 to negative exponent", + MsgType.ERROR + ); } } - when "*=" { - forall li in la with (var local_val = val, var local_max_size = max_size) { - li *= local_val; - if has_max_bits { - li &= local_max_size; - } + // bigint negative exponent check + if kL == 4 { + if val: bigint < 0 { + return new MsgTuple( + "Attempt to exponentiate base of type BigInt to negative exponent", + MsgType.ERROR + ); } - } - when "//=" { - forall li in la with (var local_val = val, var local_max_size = max_size) { - if local_val != 0 { - li /= local_val; - } - else { - li = 0:bigint; + + // preserve max_bits semantics + if l.max_bits != -1 { + const max_size = (1: bigint << l.max_bits); + forall t in la with (var local_val = val: bigint, + var local_max_size = max_size) { + powMod(t, t, local_val, local_max_size); } - if has_max_bits { - li &= local_max_size; + } else { + try { + forall t in la with (var local_val = val: bigint) { + // existing code uses uint exponent when no max_bits + t **= local_val; + } + } catch { + return new MsgTuple( + "Exponentiation too large; use smaller values or set max_bits", + MsgType.ERROR + ); } } + + } else { + // non-bigint normal exponentiation + la **= val; } - when "%=" { - // we can't use li %= val because this can result in negatives - forall li in la with (var local_val = val, var local_max_size = max_size) { - if local_val != 0 { - mod(li, li, local_val); - } - else { - li = 0:bigint; - } - if has_max_bits { - li &= local_max_size; - } + + } else if op == "//=" { + + if kL == 3 && binop_dtype_a == real(64) { + // NumPy-like float floor-division + [li in la] li = floorDivisionHelper(li, val: real(64)); + } else { + // int/uint/bigint style, preserve div-by-zero->0 behavior + // (including your bigint behavior; mask handled afterward) + if val != 0 { + la /= val; + } else { + la = 0: binop_dtype_a; } } - when "**=" { - if val<0 { - throw new Error("Attempt to exponentiate base of type BigInt to negative exponent"); - } - if has_max_bits { - forall li in la with (var local_val = val, var local_max_size = max_size) { - powMod(li, li, local_val, local_max_size + 1); + + } else if op == "%=" { + + if kL == 3 && binop_dtype_a == real(64) { + // NumPy-like float modulo + [li in la] li = modHelper(li, val: real(64)); + } else if kL == 4 { + // Bigint modulo: avoid li %= val (can go negative). Match your old behavior. + forall li in la with (var local_val = val: bigint) { + if local_val != 0 { + mod(li, li, local_val); + } else { + li = 0: bigint; } } - else { - forall li in la with (var local_val = val) { - li **= local_val:uint; - } + } else { + // int/uint modulo, preserve div-by-zero->0 + if val != 0 { + la %= val; + } else { + la = 0: binop_dtype_a; } } - otherwise do return MsgTuple.error(nie); + + } else { + handled = false; } + + // ----------------------------- + // bitwise inplace ops + // ----------------------------- + } else if isBitInplaceOp(op) { + + // Gate should have rejected real here. Keep defensive check. + if kL == 3 { + return new MsgTuple("TypeError", MsgType.ERROR); + } + + // bool handled above; here we are int/uint/bigint + select op { + when ">>=" { la >>= val; } + when "<<=" { la <<= val; } + when "&=" { la &= val; } + when "|=" { la |= val; } + when "^=" { la ^= val; } + otherwise { handled = false; } + } + + } else { + handled = false; + } + + if !handled then return MsgTuple.error(nie); + + // bigint post-mask (keep consistent with opeqvv) + if kL == 4 && l.max_bits != -1 { + const mask = (1: bigint << l.max_bits) - 1; + la &= mask; + } + + return MsgTuple.success(); } - else { - return MsgTuple.error(nie); - } - return MsgTuple.success(); + + when 1 { return new MsgTuple("TypeError", MsgType.ERROR); } + when 2 { return new MsgTuple("TypeError", MsgType.ERROR); } // same return type, finer-grain later if desired + otherwise { return MsgTuple.error(nie); } + } + + return MsgTuple.success(); } + } From 22b115973d16f429b987c0584f05cfd4145e0cd2 Mon Sep 17 00:00:00 2001 From: Ryan Keck Date: Tue, 13 Jan 2026 12:02:00 -0500 Subject: [PATCH 4/4] I think this fixes most issues. I've been looking into unit tests and I don't think I will end up writing super complete ones, but I'll revisit these files later and clean them up better. --- src/BinOp.chpl | 4 +- src/OperatorMsg.chpl | 98 +++++++++++++++++++++++++++++++++++--------- 2 files changed, 82 insertions(+), 20 deletions(-) diff --git a/src/BinOp.chpl b/src/BinOp.chpl index 132fd6bfa9c..7997ff90ad7 100644 --- a/src/BinOp.chpl +++ b/src/BinOp.chpl @@ -18,7 +18,7 @@ module BinOp const omLogger = new Logger(logLevel, logChannel); proc splitType(type dtype) param : int { - // 0 -> bool, 1 -> uint, 2 -> int, 3 -> real + // 0 -> bool, 1 -> uint, 2 -> int, 3 -> real, 4 -> bigint if dtype == bool then return 0; else if dtype == uint(8) then return 1; @@ -37,6 +37,8 @@ module BinOp } proc mySafeCast(type dtype1, type dtype2) type { + // Since splitType can wind up with 4, just cautiously guarding numBits(bigint) + if dtype1==bigint || dtype2==bigint then return bigint; param typeKind1 = splitType(dtype1); param bitSize1 = if dtype1 == bool then 8 else numBits(dtype1); param typeKind2 = splitType(dtype2); diff --git a/src/OperatorMsg.chpl b/src/OperatorMsg.chpl index 51d5b276c02..c12f0634ea7 100644 --- a/src/OperatorMsg.chpl +++ b/src/OperatorMsg.chpl @@ -538,8 +538,48 @@ module OperatorMsg } else if isBitInplaceOp(op) { + if kL == 4 { + + if op == ">>=" || op == "<<=" { + + // If RHS isn't already bigint (possible in vv), cast to bigint *only if allowed by gate*. + // But in practice, the gate should only let through integral-ish RHS for shifts. + const rb = (r.a: bigint); + + // Validate shift counts: non-negative and fits in int + const maxShift = max(int): bigint; + + // Fast reject with reductions so you don't partially mutate la + if || reduce (rb < 0: bigint) { + return new MsgTuple("ValueError: negative shift count", MsgType.ERROR); + } + if || reduce (rb > maxShift) { + return new MsgTuple("ValueError: shift count too large", MsgType.ERROR); + } + + // Apply shift (elementwise) + if op == ">>=" { + forall (li, ri) in zip(la, rb) { + li >>= (ri:int); + } + } else { + forall (li, ri) in zip(la, rb) { + li <<= (ri:int); + } + } + + } else { + select op { + when "&=" { la &= (ra: binop_dtype_a); } + when "|=" { la |= (ra: binop_dtype_a); } + when "^=" { la ^= (ra: binop_dtype_a); } + otherwise { handled = false; } + } + } + } + // Bitwise + shifts must NOT compile for real LHS - if (kL == 0 || kL == 1 || kL == 2) { + else if (kL == 0 || kL == 1 || kL == 2) { select op { when ">>=" { la >>= (ra: binop_dtype_a); } when "<<=" { la <<= (ra: binop_dtype_a); } @@ -548,13 +588,6 @@ module OperatorMsg when "^=" { la ^= (ra: binop_dtype_a); } otherwise { handled = false; } } - } else if (kL == 0 || kL == 1 || kL == 2 || kL == 4) { - select op { - when "&=" { la &= (ra: binop_dtype_a); } - when "|=" { la |= (ra: binop_dtype_a); } - when "^=" { la ^= (ra: binop_dtype_a); } - otherwise { handled = false; } - } } else { // real LHS should not reach here; gate should have blocked it return new MsgTuple("TypeError", MsgType.ERROR); @@ -565,7 +598,7 @@ module OperatorMsg } if !handled then return MsgTuple.error(nie); - if kL == 4 { + if kL == 4 && l.max_bits != -1 { const mask = (1: bigint << l.max_bits) - 1; la &= mask; } @@ -725,7 +758,6 @@ module OperatorMsg [li in la] li = floorDivisionHelper(li, val: real(64)); } else { // int/uint/bigint style, preserve div-by-zero->0 behavior - // (including your bigint behavior; mask handled afterward) if val != 0 { la /= val; } else { @@ -739,7 +771,7 @@ module OperatorMsg // NumPy-like float modulo [li in la] li = modHelper(li, val: real(64)); } else if kL == 4 { - // Bigint modulo: avoid li %= val (can go negative). Match your old behavior. + // Bigint modulo: avoid li %= val (can go negative). forall li in la with (var local_val = val: bigint) { if local_val != 0 { mod(li, li, local_val); @@ -770,14 +802,42 @@ module OperatorMsg return new MsgTuple("TypeError", MsgType.ERROR); } - // bool handled above; here we are int/uint/bigint - select op { - when ">>=" { la >>= val; } - when "<<=" { la <<= val; } - when "&=" { la &= val; } - when "|=" { la |= val; } - when "^=" { la ^= val; } - otherwise { handled = false; } + if kL == 4 && (op == ">>=" || op == "<<=") { + // Convert bigint -> int shift count (reject negative / too large) + if val < 0: bigint then + return new MsgTuple("ValueError: negative shift count", MsgType.ERROR); + + // pick a bound you consider safe; at minimum, must fit in int + const maxShift = max(int): bigint; + if val > maxShift then + return new MsgTuple("ValueError: shift count too large", MsgType.ERROR); + + const sh = val:int; + + if op == ">>=" then la >>= sh; + else la <<= sh; + + } else if kL == 4 { + + select op { + when "&=" { la &= val; } + when "|=" { la |= val; } + when "^=" { la ^= val; } + otherwise { handled = false; } + } + + } else { + + // bool/bigint handled above; here we are int/uint + select op { + when ">>=" { la >>= val; } + when "<<=" { la <<= val; } + when "&=" { la &= val; } + when "|=" { la |= val; } + when "^=" { la ^= val; } + otherwise { handled = false; } + } + } } else {