Skip to content

Commit 57365d8

Browse files
stinos authored and dpgeorge committed
py/objarray: Prohibit comparison of mismatching types.
Array equality is defined as each element being equal, but to keep code size down MicroPython implements a binary comparison. This can only be used correctly for elements with the same binary layout, though, so raise a NotImplementedError when comparing types for which the binary comparison would yield incorrect results: types with different sizes, and floating-point numbers, because nan != nan.
1 parent 6affcb0 commit 57365d8

File tree

4 files changed

+54
-1
lines changed

4 files changed

+54
-1
lines changed

py/objarray.c

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,16 @@ STATIC mp_obj_t array_unary_op(mp_unary_op_t op, mp_obj_t o_in) {
258258
}
259259
}
260260

261+
STATIC int typecode_for_comparison(int typecode) {
262+
if (typecode == BYTEARRAY_TYPECODE) {
263+
typecode = 'B';
264+
}
265+
if (typecode <= 'Z') {
266+
typecode += 32; // to lowercase
267+
}
268+
return typecode;
269+
}
270+
261271
STATIC mp_obj_t array_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
262272
mp_obj_array_t *lhs = MP_OBJ_TO_PTR(lhs_in);
263273
switch (op) {
@@ -319,7 +329,20 @@ STATIC mp_obj_t array_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs
319329
if (!mp_get_buffer(rhs_in, &rhs_bufinfo, MP_BUFFER_READ)) {
320330
return mp_const_false;
321331
}
322-
return mp_obj_new_bool(mp_seq_cmp_bytes(op, lhs_bufinfo.buf, lhs_bufinfo.len, rhs_bufinfo.buf, rhs_bufinfo.len));
332+
// mp_seq_cmp_bytes is used so only compatible representations can be correctly compared.
333+
// The type doesn't matter: array/bytearray/str/bytes all have the same buffer layout, so
334+
// just check if the typecodes are compatible; for testing equality the types should have the
335+
// same code except for signedness, and not be floating point because nan never equals nan.
336+
// Note that typecode_for_comparison always returns lowercase letters to save code size.
337+
// No need for (& TYPECODE_MASK) here: xxx_get_buffer already takes care of that.
338+
const int lhs_code = typecode_for_comparison(lhs_bufinfo.typecode);
339+
const int rhs_code = typecode_for_comparison(rhs_bufinfo.typecode);
340+
if (lhs_code == rhs_code && lhs_code != 'f' && lhs_code != 'd') {
341+
return mp_obj_new_bool(mp_seq_cmp_bytes(op, lhs_bufinfo.buf, lhs_bufinfo.len, rhs_bufinfo.buf, rhs_bufinfo.len));
342+
}
343+
// mp_obj_equal_not_equal treats returning MP_OBJ_NULL as 'fall back to pointer comparison'
344+
// for MP_BINARY_OP_EQUAL but that is incompatible with CPython.
345+
mp_raise_NotImplementedError(NULL);
323346
}
324347

325348
default:

tests/basics/array1.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,23 @@
4141
# equality (CPython requires both sides are array)
4242
print(bytes(array.array('b', [0x61, 0x62, 0x63])) == b'abc')
4343
print(array.array('b', [0x61, 0x62, 0x63]) == b'abc')
44+
print(array.array('B', [0x61, 0x62, 0x63]) == b'abc')
4445
print(array.array('b', [0x61, 0x62, 0x63]) != b'abc')
4546
print(array.array('b', [0x61, 0x62, 0x63]) == b'xyz')
4647
print(array.array('b', [0x61, 0x62, 0x63]) != b'xyz')
4748
print(b'abc' == array.array('b', [0x61, 0x62, 0x63]))
49+
print(b'abc' == array.array('B', [0x61, 0x62, 0x63]))
4850
print(b'abc' != array.array('b', [0x61, 0x62, 0x63]))
4951
print(b'xyz' == array.array('b', [0x61, 0x62, 0x63]))
5052
print(b'xyz' != array.array('b', [0x61, 0x62, 0x63]))
5153

54+
# Each integer typecode must compare equal both to itself and to its
# upper-case counterpart (same element size, different signedness).
compatible_typecodes = []
for t in ["b", "h", "i", "l", "q"]:
    compatible_typecodes.extend([(t, t), (t, t.upper())])
for a, b in compatible_typecodes:
    print(array.array(a, [1, 2]) == array.array(b, [1, 2]))
60+
5261
class X(array.array):
5362
pass
5463

tests/basics/array_micropython.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,15 @@
1717
a = array.array('P')
1818
a.append(1)
1919
print(a[0])
20+
21+
# Comparing arrays whose typecodes differ, or whose elements are floating
# point, is not implemented: every such pair must raise NotImplementedError.
typecodes = ["b", "h", "i", "l", "q", "P", "O", "S", "f", "d"]
for a in typecodes:
    for b in typecodes:
        # Identical non-float typecodes are comparable, so only the
        # remaining pairs are exercised here.
        if a != b or a in ["f", "d"]:
            try:
                array.array(a) == array.array(b)
            except NotImplementedError:
                pass
            else:
                print('FAIL')
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""
2+
categories: Modules,array
3+
description: Comparison between different typecodes not supported
4+
cause: Code size
5+
workaround: Compare individual elements
6+
"""
7+
import array
8+
9+
# The two arrays hold equal values but use different typecodes ("b" vs "i");
# per the description above, MicroPython does not support this comparison.
array.array("b", [1, 2]) == array.array("i", [1, 2])

0 commit comments

Comments
 (0)