openpilot/tinygrad_repo/test/unit/test_uop_vmin_vmax.py
Vehicle Researcher c5d5c5d1f3 openpilot v0.10.1 release
date: 2025-10-24T00:30:59
master commit: 405631baf9685e171a0dd19547cb763f1b163d18
2025-10-24 00:31:03 -07:00

353 lines
12 KiB
Python

import unittest, math
from tinygrad.uop.ops import UOp, Ops
from tinygrad.dtype import dtypes, Invalid
class TestVminVmaxProperties(unittest.TestCase):
  """Value bounds (vmin/vmax) inferred for scalar arithmetic, comparisons, bitwise ops, casts and SPECIAL."""

  def _assert_bounds(self, u, lo, hi):
    # u.vmin must equal lo and u.vmax must equal hi
    self.assertEqual(u.vmin, lo)
    self.assertEqual(u.vmax, hi)

  def test_vmin_vmax_constant(self):
    # a scalar constant is bounded on both sides by its own value
    self._assert_bounds(UOp.const(dtypes.int32, 42), 42, 42)

  def test_vmin_vmax_cmpne(self):
    # != against a known constant yields a bool whose bounds are fully determined
    c = UOp.const(dtypes.int32, 42)
    for rhs, expected in ((42, False), (43, True), (41, True)):
      self._assert_bounds(c != rhs, expected, expected)

  def test_vmin_vmax_addition_with_variable(self):
    # adding a constant shifts both ends of x in [10, 20]
    v = UOp.variable('x', 10, 20)
    self._assert_bounds(v + 5, 15, 25)

  def test_vmin_vmax_subtraction_with_variable(self):
    v = UOp.variable('x', 10, 20)
    # x - 5 shifts the range down
    self._assert_bounds(v - 5, 5, 15)
    # 5 - x mirrors the range, so its ends swap
    self._assert_bounds(5 - v, -15, -5)

  def test_vmin_vmax_and_with_variable(self):
    v = UOp.variable('x', 10, 20)
    # expected bounds here are [0, min(x.vmax, mask)]
    self._assert_bounds(v & 5, 0, 5)
    self._assert_bounds(v & 15, 0, 15)
    # every x in [10, 20] is below 32, so x & 32 is always 0; [0, 20] is a loose bound that could be tightened
    self._assert_bounds(v & 32, 0, 20)

  def test_vmin_vmax_multiplication_with_variable(self):
    # a positive factor scales both ends of x in [-3, 4]
    v = UOp.variable('x', -3, 4)
    self._assert_bounds(v * 2, -6, 8)

  def test_vmin_vmax_variable_inside_special(self):
    # a SPECIAL whose source is a DEFINE_VAR in [1, 10] is expected to range over [0, 9]
    bound = UOp(Ops.DEFINE_VAR, dtypes.int, arg=('i', 1, 10))
    special = UOp(Ops.SPECIAL, dtypes.int, arg='gidx0', src=(bound,))
    self._assert_bounds(special, 0, 9)

  def test_vmin_vmax_multiplication_0_inf(self):
    # 0.0 multiplied by a float loaded from a global buffer (a value with no known bounds)
    zero = UOp.const(dtypes.float, 0.0)
    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
    loaded = UOp.load(buf, UOp.const(dtypes.int, 0), dtype=dtypes.float)
    # TODO: these should be 0, but definitely should not be nan
    self._assert_bounds(zero * loaded, -math.inf, math.inf)

  def test_vmin_vmax_with_negative_multiplication(self):
    # multiplying by a negative constant swaps which end becomes the minimum
    self._assert_bounds(UOp.variable('x', 2, 5) * -3, -15, -6)

  def test_vmin_vmax_with_negative_multiplication2(self):
    # same, for a range that straddles zero
    self._assert_bounds(UOp.variable('x', -2, 5) * -3, -15, 6)

  def test_vmin_vmax_nested_min_max(self):
    # maximum(5) followed by minimum(8) clamps x in [0, 10] to [5, 8]
    self._assert_bounds(UOp.variable('x', 0, 10).maximum(5).minimum(8), 5, 8)

  def test_vmin_vmax_where(self):
    # x < 5 can go either way for x in [0, 10], so the result spans both branches
    x = UOp.variable('x', 0, 10)
    y = UOp.variable('y', 1, 11)
    z = UOp.variable('z', 2, 12)
    self._assert_bounds((x < 5).where(y, z), 1, 12)

  def test_vmin_vmax_shl(self):
    # a left shift scales both ends of a non-negative range
    self._assert_bounds(UOp.variable('x', 0, 10) << 5, 0, 10 << 5)

  def test_vmin_vmax_shr(self):
    # a right shift scales both ends of a non-negative range down
    self._assert_bounds(UOp.variable('x', 0, 10) >> 2, 0, 10 >> 2)

  def test_vmin_vmax_cast(self):
    v = UOp.variable('x', -10, 10, dtypes.int)
    # int -> float keeps the range
    self._assert_bounds(v.cast(dtypes.float), -10, 10)
    # int -> bool can be either truth value
    self._assert_bounds(v.cast(dtypes.bool), False, True)
    # the int range includes negatives, so int -> uint is expected to cover the whole uint range
    self._assert_bounds(v.cast(dtypes.uint), dtypes.min(dtypes.uint), dtypes.max(dtypes.uint))

  def test_vmin_vmax_invalid(self):
    # the Invalid sentinel must not report a single-point range
    inv = UOp.invalid()
    self.assertNotEqual(inv.vmin, inv.vmax)

  def test_vmin_vmax_invalid_vconst(self):
    # Invalid lanes push the bounds strictly past those of the valid lanes (0 and 4)
    vc = UOp.const(dtypes.index.vec(4), (0, 4, Invalid, Invalid))
    self.assertLess(vc.vmin, 0)
    self.assertGreater(vc.vmax, 4)
class TestVminVmaxDivMod(unittest.TestCase):
  """Value bounds inferred for integer division (//) and modulo (%).

  The expected values encode truncation toward zero (C semantics), not Python's
  floor semantics: e.g. 10 // -3 is expected to contribute -3, and a negative
  dividend % 3 is expected to land in [-2, 0].
  """

  def _assert_bounds(self, u, lo, hi):
    # u.vmin must equal lo and u.vmax must equal hi
    self.assertEqual(u.vmin, lo)
    self.assertEqual(u.vmax, hi)

  def test_vmin_vmax_division_positive(self):
    # dividing x in [10, 20] by a positive constant
    self._assert_bounds(UOp.variable('x', 10, 20) // 2, 5, 10)

  def test_vmin_vmax_division_negative(self):
    # (dividend range, negative divisor, expected vmin, expected vmax)
    cases = (
      ((10, 20), -2, -10, -5), ((10, 20), -3, -6, -3),    # dividend always positive
      ((-20, -10), -2, 5, 10), ((-20, -10), -3, 3, 6),    # dividend always negative
      ((-10, 10), -2, -5, 5), ((-10, 10), -3, -3, 3),     # dividend range crosses 0
    )
    for (lo, hi), divisor, want_min, want_max in cases:
      self._assert_bounds(UOp.variable('x', lo, hi) // divisor, want_min, want_max)

  def test_vmin_vmax_div_symbolic(self):
    x = UOp.variable('x', 1, 10)
    y = UOp.variable('y', 3, 5)
    # variable dividend over a variable divisor, for every sign combination
    self._assert_bounds(x // y, 0, 3)
    self._assert_bounds((-x) // y, -3, 0)
    self._assert_bounds(x // (-y), -3, 0)
    self._assert_bounds((-x) // (-y), 0, 3)
    # constant dividend over a variable divisor, for every sign combination
    self._assert_bounds(100 // y, 20, 33)
    self._assert_bounds((-100) // y, -33, -20)
    self._assert_bounds(100 // (-y), -33, -20)
    self._assert_bounds((-100) // (-y), 20, 33)

  def test_vmin_vmax_mod_positive(self):
    # the result follows the sign of the dividend and stays below the divisor in magnitude
    self._assert_bounds(UOp.variable('positive', 10, 20) % 3, 0, 2)
    self._assert_bounds(UOp.variable('negative', -20, -10) % 3, -2, 0)
    self._assert_bounds(UOp.variable('mixed', -20, 20) % 3, -2, 2)

  def test_vmin_vmax_mod_negative(self):
    # a negative divisor yields the same ranges as the positive one above
    self._assert_bounds(UOp.variable('positive', 10, 20) % -3, 0, 2)
    self._assert_bounds(UOp.variable('negative', -20, -10) % -3, -2, 0)
    self._assert_bounds(UOp.variable('mixed', -20, 20) % -3, -2, 2)
class TestVminVmaxVConst(unittest.TestCase):
  """Value bounds of vector constants and of a lane extracted from a vector with gep."""

  def _assert_bounds(self, u, lo, hi):
    # u.vmin must equal lo and u.vmax must equal hi
    self.assertEqual(u.vmin, lo)
    self.assertEqual(u.vmax, hi)

  def test_vmin_vmax_vconst_single_element(self):
    # a one-lane vector is bounded by its only element
    self._assert_bounds(UOp.const(dtypes.int32.vec(1), (42,)), 42, 42)

  def test_vmin_vmax_vconst_multiple_elements(self):
    # bounds are the smallest and largest lanes
    self._assert_bounds(UOp.const(dtypes.int32.vec(4), (10, 20, -5, 7)), -5, 20)

  def test_vmin_vmax_vconst_all_equal(self):
    # identical lanes collapse to a single point
    self._assert_bounds(UOp.const(dtypes.int32.vec(3), (7, 7, 7)), 7, 7)

  def test_vmin_vmax_vconst_with_negative_values(self):
    # all-negative lanes
    self._assert_bounds(UOp.const(dtypes.int32.vec(4), (-10, -20, -5, -15)), -20, -5)

  def test_vmin_vmax_vconst_with_floats(self):
    # float lanes
    self._assert_bounds(UOp.const(dtypes.float32.vec(3), (1.5, -3.2, 0.0)), -3.2, 1.5)

  def test_vmin_vmax_vconst_with_bools(self):
    # bool lanes must come back as the bool singletons themselves, hence assertIs
    flags = UOp.const(dtypes.bool.vec(3), (True, False, False))
    self.assertIs(flags.vmin, False)
    self.assertIs(flags.vmax, True)

  def test_vmin_vmax_vector_with_gep(self):
    # lane 0 of (int2 load from a global buffer) // 32; the expected bounds are
    # -2**31 // 32 and (2**31 - 1) // 32, i.e. the int32 range divided by 32
    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
    loaded = UOp(Ops.LOAD, dtypes.int.vec(2), (buf.index(UOp.const(dtypes.int, 0)),))
    lane0 = (loaded // 32).gep(0)
    self._assert_bounds(lane0, -67108864, 67108863)
class TestConstFactor(unittest.TestCase):
  """Integer factor reported by UOp.const_factor() for constants, sums, products and divisions."""

  def test_const_factor_constant(self):
    # a constant reports itself
    self.assertEqual(UOp.const(dtypes.int32, 42).const_factor(), 42)

  def test_const_factor_addition(self):
    # a sum of constants reports the gcd of its operands: gcd(30, 12) = 6
    lhs, rhs = UOp.const(dtypes.int32, 30), UOp.const(dtypes.int32, 12)
    self.assertEqual((lhs + rhs).const_factor(), 6)

  def test_const_factor_multiplication(self):
    # a product of two constants reports one operand (the first, 5), not the product 35
    lhs, rhs = UOp.const(dtypes.int32, 5), UOp.const(dtypes.int32, 7)
    self.assertEqual((lhs * rhs).const_factor(), 5)

  def test_const_factor_with_variable(self):
    # variable * constant reports the constant
    self.assertEqual((UOp.variable('x', 10, 20) * 3).const_factor(), 3)

  def test_const_factor_division(self):
    # a division by 4 reports a factor of 1
    self.assertEqual((UOp.variable('x', 10, 20) // 4).const_factor(), 1)

  def test_const_factor_multiplication_of_var_and_const(self):
    # variable * constant reports the constant
    self.assertEqual((UOp.variable('x', 6, 18) * 4).const_factor(), 4)

  @unittest.skip("broken")
  def test_const_factor_multiplication_of_consts_and_vars(self):
    # chained multiplications should combine their constants: 3 * 5 = 15
    product = (UOp.variable('x', 10, 20) * 3) * 5
    self.assertEqual(product.const_factor(), 15)
class TestDivides(unittest.TestCase):
  """UOp.divides(d): a quotient UOp when d is known to divide the expression, otherwise None."""

  def test_divides_constant_exact(self):
    # 42 / 7 = 6
    quotient = UOp.const(dtypes.int32, 42).divides(7)
    self.assertIsNotNone(quotient)
    self.assertEqual(quotient.const_factor(), 6)

  def test_divides_constant_inexact(self):
    # 42 is not a multiple of 5
    self.assertIsNone(UOp.const(dtypes.int32, 42).divides(5))

  @unittest.skip("broken")
  def test_divides_variable_and_constant(self):
    # (x * 6) / 6 should give back x itself
    x = UOp.variable('x', 10, 20)
    quotient = (x * 6).divides(6)
    self.assertIsNotNone(quotient)
    self.assertEqual(quotient, x)

  def test_divides_complex_expression(self):
    # (x * 6 + 18) / 6 = x + 3, whose const_factor is 1
    quotient = ((UOp.variable('x', 10, 20) * 6) + 18).divides(6)
    self.assertIsNotNone(quotient)
    self.assertEqual(quotient.const_factor(), 1)

  def test_divides_with_inexact_factors(self):
    # x * 4 is not known to be a multiple of 3
    self.assertIsNone((UOp.variable('x', 15, 45) * 4).divides(3))
if __name__ == '__main__': unittest.main()  # run every TestCase in this module when executed as a script