Skip to content

Commit 85cb344

Browse files
bjoernhaeuserjulien-duponchelle
authored andcommitted
Add support for TINYINT(1) to bool() mapping
1 parent 4733c9d commit 85cb344

File tree

6 files changed

+91
-56
lines changed

6 files changed

+91
-56
lines changed

pymysqlreplication/binlogstream.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@ class BinLogStreamReader(object):
1212
'''Connect to replication stream and read event'''
1313

1414
def __init__(self, connection_settings={}, resume_stream=False, blocking=False, only_events=None, server_id=255):
15-
'''
15+
"""
1616
resume_stream: Start for latest event of binlog or from older available event
1717
blocking: Read on stream is blocking
1818
only_events: Array of allowed events
19-
'''
19+
"""
2020
self.__connection_settings = connection_settings
2121
self.__connection_settings['charset'] = 'utf8'
2222

pymysqlreplication/column.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
from .constants import FIELD_TYPE
33
from pymysql.util import byte2int, int2byte
44

5+
56
class Column(object):
6-
'''Definition of a column'''
7+
"""Definition of a column"""
78

89
def __init__(self, column_type, column_schema, packet):
910
self.type = column_type
@@ -12,6 +13,7 @@ def __init__(self, column_type, column_schema, packet):
1213
self.character_set_name = column_schema["CHARACTER_SET_NAME"]
1314
self.comment = column_schema["COLUMN_COMMENT"]
1415
self.unsigned = False
16+
self.type_is_bool = False
1517

1618
if column_schema["COLUMN_TYPE"].find("unsigned") != -1:
1719
self.unsigned = True
@@ -41,10 +43,11 @@ def __init__(self, column_type, column_schema, packet):
4143
self.fsp = packet.read_uint8()
4244
elif self.type == FIELD_TYPE.TIME2:
4345
self.fsp = packet.read_uint8()
44-
46+
elif self.type == FIELD_TYPE.TINY and column_schema["COLUMN_TYPE"] == "tinyint(1)":
47+
self.type_is_bool = True
4548

4649
def __read_string_metadata(self, packet, column_schema):
47-
metadata = (packet.read_uint8() << 8) + packet.read_uint8()
50+
metadata = (packet.read_uint8() << 8) + packet.read_uint8()
4851
real_type = metadata >> 8
4952
if real_type == FIELD_TYPE.SET or real_type == FIELD_TYPE.ENUM:
5053
self.type = real_type

pymysqlreplication/packet.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,15 +107,15 @@ def read_length_coded_binary(self):
107107
"""
108108
c = byte2int(self.read(1))
109109
if c == NULL_COLUMN:
110-
return None
110+
return None
111111
if c < UNSIGNED_CHAR_COLUMN:
112-
return c
112+
return c
113113
elif c == UNSIGNED_SHORT_COLUMN:
114114
return self.unpack_uint16(self.read(UNSIGNED_SHORT_LENGTH))
115115
elif c == UNSIGNED_INT24_COLUMN:
116-
return self.unpack_int24(self.read(UNSIGNED_INT24_LENGTH))
116+
return self.unpack_int24(self.read(UNSIGNED_INT24_LENGTH))
117117
elif c == UNSIGNED_INT64_COLUMN:
118-
return self.unpack_int64(self.read(UNSIGNED_INT64_LENGTH))
118+
return self.unpack_int64(self.read(UNSIGNED_INT64_LENGTH))
119119

120120
def read_length_coded_string(self):
121121
"""Read a 'Length Coded String' from the data buffer.

pymysqlreplication/row_event.py

Lines changed: 47 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .constants import BINLOG
99
from .column import Column
1010

11+
1112
class RowsEvent(BinLogEvent):
1213
def __init__(self, from_packet, event_size, table_map, ctl_connection):
1314
super(RowsEvent, self).__init__(from_packet, event_size, table_map, ctl_connection)
@@ -19,8 +20,8 @@ def __init__(self, from_packet, event_size, table_map, ctl_connection):
1920

2021
#Event V2
2122
if self.event_type == BINLOG.WRITE_ROWS_EVENT or \
22-
self.event_type == BINLOG.DELETE_ROWS_EVENT or \
23-
self.event_type == BINLOG.UPDATE_ROWS_EVENT:
23+
self.event_type == BINLOG.DELETE_ROWS_EVENT or \
24+
self.event_type == BINLOG.UPDATE_ROWS_EVENT:
2425
self.extra_data_length = struct.unpack('<H', self.packet.read(2))[0]
2526
self.extra_data = self.packet.read(self.extra_data_length / 8)
2627

@@ -54,6 +55,9 @@ def _read_column_data(self, null_bitmap):
5455
values[name] = struct.unpack("<B", self.packet.read(1))[0]
5556
else:
5657
values[name] = struct.unpack("<b", self.packet.read(1))[0]
58+
59+
if column.type_is_bool:
60+
values[name] = bool(values[name])
5761
elif column.type == FIELD_TYPE.SHORT:
5862
if unsigned:
5963
values[name] = struct.unpack("<H", self.packet.read(2))[0]
@@ -94,10 +98,11 @@ def _read_column_data(self, null_bitmap):
9498
# For new date format: http://dev.mysql.com/doc/internals/en/date-and-time-data-type-representation.html
9599
elif column.type == FIELD_TYPE.DATETIME2:
96100
values[name] = self.__read_datetime2(column)
97-
elif column.type == FIELD_TYPE.TIME2:
101+
elif column.type == FIELD_TYPE.TIME2:
98102
values[name] = self.__read_time2(column)
99103
elif column.type == FIELD_TYPE.TIMESTAMP2:
100-
values[name] = self.__add_fsp_to_time(datetime.datetime.fromtimestamp(self.packet.read_int_be_by_size(4)), column)
104+
values[name] = self.__add_fsp_to_time(
105+
datetime.datetime.fromtimestamp(self.packet.read_int_be_by_size(4)), column)
101106
elif column.type == FIELD_TYPE.LONGLONG:
102107
if unsigned:
103108
values[name] = self.packet.read_uint64()
@@ -110,7 +115,7 @@ def _read_column_data(self, null_bitmap):
110115
elif column.type == FIELD_TYPE.SET:
111116
#We read set columns as a bitmap telling us which options are enabled
112117
bit_mask = self.packet.read_uint_by_size(column.size)
113-
values[name] = {val for idx,val in enumerate(column.set_values) if bit_mask & 2**idx} or None
118+
values[name] = {val for idx, val in enumerate(column.set_values) if bit_mask & 2 ** idx} or None
114119

115120
elif column.type == FIELD_TYPE.BIT:
116121
values[name] = self.__read_bit(column)
@@ -134,9 +139,9 @@ def __add_fsp_to_time(self, time, column):
134139
if read > 0:
135140
microsecond = self.packet.read_int_be_by_size(read)
136141
if column.fsp % 2:
137-
time = time.replace(microsecond = int(microsecond / 10))
142+
time = time.replace(microsecond=int(microsecond / 10))
138143
else:
139-
time = time.replace(microsecond = microsecond)
144+
time = time.replace(microsecond=microsecond)
140145
return time
141146

142147
def __read_string(self, size, column):
@@ -171,9 +176,9 @@ def __read_bit(self, column):
171176
def __read_time(self):
172177
time = self.packet.read_uint24()
173178
date = datetime.time(
174-
hour = int(time / 10000),
175-
minute = int((time % 10000) / 100),
176-
second = int(time % 100))
179+
hour=int(time / 10000),
180+
minute=int((time % 10000) / 100),
181+
second=int(time % 100))
177182
return date
178183

179184
def __read_time2(self, column):
@@ -187,9 +192,9 @@ def __read_time2(self, column):
187192
24 bits = 3 bytes'''
188193
data = self.packet.read_int_be_by_size(3)
189194
t = datetime.time(
190-
hour = self.__read_binary_slice(data, 2, 10, 24),
191-
minute = self.__read_binary_slice(data, 12, 6, 24),
192-
second = self.__read_binary_slice(data, 18, 6, 24))
195+
hour=self.__read_binary_slice(data, 2, 10, 24),
196+
minute=self.__read_binary_slice(data, 12, 6, 24),
197+
second=self.__read_binary_slice(data, 18, 6, 24))
193198
return self.__add_fsp_to_time(t, column)
194199

195200
def __read_date(self):
@@ -198,9 +203,9 @@ def __read_date(self):
198203
return None
199204

200205
date = datetime.date(
201-
year = (time & ((1 << 15) - 1) << 9) >> 9,
202-
month = (time & ((1 << 4) - 1) << 5) >> 5,
203-
day = (time & ((1 << 5) - 1))
206+
year=(time & ((1 << 15) - 1) << 9) >> 9,
207+
month=(time & ((1 << 4) - 1) << 5) >> 5,
208+
day=(time & ((1 << 5) - 1))
204209
)
205210
return date
206211

@@ -219,12 +224,12 @@ def __read_datetime(self):
219224
return None
220225

221226
date = datetime.datetime(
222-
year = year,
223-
month = month,
224-
day = day,
225-
hour = int(time / 10000),
226-
minute = int((time % 10000) / 100),
227-
second = int(time % 100))
227+
year=year,
228+
month=month,
229+
day=day,
230+
hour=int(time / 10000),
231+
minute=int((time % 10000) / 100),
232+
second=int(time % 100))
228233
return date
229234

230235
def __read_datetime2(self, column):
@@ -241,12 +246,12 @@ def __read_datetime2(self, column):
241246
year_month = self.__read_binary_slice(data, 1, 17, 40)
242247
try:
243248
t = datetime.datetime(
244-
year = int(year_month / 13),
245-
month = year_month % 13,
246-
day = self.__read_binary_slice(data, 18, 5, 40),
247-
hour = self.__read_binary_slice(data, 23, 5, 40),
248-
minute = self.__read_binary_slice(data, 28, 6, 40),
249-
second = self.__read_binary_slice(data, 34, 6, 40))
249+
year=int(year_month / 13),
250+
month=year_month % 13,
251+
day=self.__read_binary_slice(data, 18, 5, 40),
252+
hour=self.__read_binary_slice(data, 23, 5, 40),
253+
minute=self.__read_binary_slice(data, 28, 6, 40),
254+
second=self.__read_binary_slice(data, 34, 6, 40))
250255
except ValueError:
251256
return None
252257
return self.__add_fsp_to_time(t, column)
@@ -278,7 +283,6 @@ def __read_new_decimal(self, column):
278283
res = "-"
279284
self.packet.unread(struct.pack('<B', value ^ 0x80))
280285

281-
282286
size = compressed_bytes[comp_integral]
283287
if size > 0:
284288
value = self.packet.read_int_be_by_size(size) ^ mask
@@ -316,7 +320,7 @@ def __read_binary_slice(self, binary, start, size, data_length):
316320
def _dump(self):
317321
super(RowsEvent, self)._dump()
318322
print("Table: %s.%s" % (self.schema, self.table))
319-
print("Affected columns: %d" % (self.number_of_columns))
323+
print("Affected columns: %d" % self.number_of_columns)
320324
print("Changed rows: %d" % (len(self.rows)))
321325

322326
def _fetch_rows(self):
@@ -332,7 +336,8 @@ def __getattr__(self, name):
332336

333337

334338
class DeleteRowsEvent(RowsEvent):
335-
'''This evenement is trigger when a row in database is removed'''
339+
"""This event is trigger when a row in the database is removed"""
340+
336341
def __init__(self, from_packet, event_size, table_map, ctl_connection):
337342
super(DeleteRowsEvent, self).__init__(from_packet, event_size, table_map, ctl_connection)
338343
self.columns_present_bitmap = self.packet.read((self.number_of_columns + 7) / 8)
@@ -354,7 +359,8 @@ def _dump(self):
354359

355360

356361
class WriteRowsEvent(RowsEvent):
357-
'''This evenement is trigger when a row in database is added'''
362+
"""This event is triggered when a row in database is added"""
363+
358364
def __init__(self, from_packet, event_size, table_map, ctl_connection):
359365
super(WriteRowsEvent, self).__init__(from_packet, event_size, table_map, ctl_connection)
360366
self.columns_present_bitmap = self.packet.read((self.number_of_columns + 7) / 8)
@@ -376,9 +382,10 @@ def _dump(self):
376382

377383

378384
class UpdateRowsEvent(RowsEvent):
379-
'''This evenement is trigger when a row in database change'''
385+
"""This event is triggered when a row in the database is changed"""
386+
380387
def __init__(self, from_packet, event_size, table_map, ctl_connection):
381-
super(UpdateRowsEvent,self).__init__(from_packet, event_size, table_map, ctl_connection)
388+
super(UpdateRowsEvent, self).__init__(from_packet, event_size, table_map, ctl_connection)
382389
#Body
383390
self.columns_present_bitmap = self.packet.read((self.number_of_columns + 7) / 8)
384391
self.columns_present_bitmap2 = self.packet.read((self.number_of_columns + 7) / 8)
@@ -395,7 +402,7 @@ def _fetch_one_row(self):
395402

396403
def _dump(self):
397404
super(UpdateRowsEvent, self)._dump()
398-
print("Affected columns: %d" % (self.number_of_columns))
405+
print("Affected columns: %d" % self.number_of_columns)
399406
print("Values:")
400407
for row in self.rows:
401408
print("--")
@@ -407,19 +414,19 @@ class TableMapEvent(BinLogEvent):
407414
'''This evenement describe the structure of a table.
408415
It's send before a change append on a table.
409416
A end user of the lib should have no usage of this'''
417+
410418
def __init__(self, from_packet, event_size, table_map, ctl_connection):
411419
super(TableMapEvent, self).__init__(from_packet, event_size, table_map, ctl_connection)
412420

413421
# Post-Header
414422
self.table_id = self._read_table_id()
415423
self.flags = struct.unpack('<H', self.packet.read(2))[0]
416424

417-
418425
# Payload
419-
self.schema_length = byte2int(self.packet.read(1))
426+
self.schema_length = byte2int(self.packet.read(1))
420427
self.schema = self.packet.read(self.schema_length).decode()
421428
self.packet.advance(1)
422-
self.table_length = byte2int(self.packet.read(1))
429+
self.table_length = byte2int(self.packet.read(1))
423430
self.table = self.packet.read(self.table_length).decode()
424431
self.packet.advance(1)
425432
self.column_count = self.packet.read_length_coded_binary()
@@ -429,7 +436,7 @@ def __init__(self, from_packet, event_size, table_map, ctl_connection):
429436
if self.table_id in table_map:
430437
self.column_schemas = table_map[self.table_id].column_schemas
431438
else:
432-
self.column_schemas = self.__get_table_informations(self.schema, self.table)
439+
self.column_schemas = self.__get_table_information(self.schema, self.table)
433440

434441
#Read columns meta data
435442
column_types = list(self.packet.read(self.column_count))
@@ -444,7 +451,7 @@ def __init__(self, from_packet, event_size, table_map, ctl_connection):
444451
# TODO: get this informations instead of trashing data
445452
# n NULL-bitmask, length: (column-length * 8) / 7
446453

447-
def __get_table_informations(self, schema, table):
454+
def __get_table_information(self, schema, table):
448455
cur = self._ctl_connection.cursor()
449456
cur.execute("""SELECT * FROM columns WHERE table_schema = %s AND table_name = %s""", (schema, table))
450457
return cur.fetchall()

pymysqlreplication/tests/base.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
import copy
44
from pymysqlreplication import BinLogStreamReader
55

6+
67
class PyMySQLReplicationTestCase(unittest.TestCase):
78
'''Test the module. Be carefull it will reset your MySQL server'''
8-
database = {"host":"localhost",
9-
"user":"root",
10-
"passwd":"",
11-
"use_unicode": True,
12-
"charset": "utf8",
13-
"db": "pymysqlreplication_test"
9+
database = {"host": "localhost",
10+
"user": "root",
11+
"passwd": "",
12+
"use_unicode": True,
13+
"charset": "utf8",
14+
"db": "pymysqlreplication_test"
1415
}
1516

1617
def setUp(self):
@@ -57,5 +58,5 @@ def resetBinLog(self):
5758
self.execute("RESET MASTER")
5859
if self.stream is not None:
5960
self.stream.close()
60-
self.stream = BinLogStreamReader(connection_settings = self.database)
61+
self.stream = BinLogStreamReader(connection_settings=self.database)
6162

pymysqlreplication/tests/test_data_type.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,30 @@ def test_tiny(self):
109109
self.assertEqual(event.rows[0]["values"]["id"], 255)
110110
self.assertEqual(event.rows[0]["values"]["test"], -128)
111111

112+
def test_tiny_maps_to_boolean_true(self):
113+
create_query = "CREATE TABLE test (id TINYINT UNSIGNED NOT NULL, test BOOLEAN)"
114+
insert_query = "INSERT INTO test VALUES(1, TRUE)"
115+
event = self.create_and_insert_value(create_query, insert_query)
116+
self.assertEqual(event.rows[0]["values"]["id"], 1)
117+
self.assertEqual(type(event.rows[0]["values"]["test"]), type(True))
118+
self.assertEqual(event.rows[0]["values"]["test"], True)
119+
120+
def test_tiny_maps_to_boolean_false(self):
121+
create_query = "CREATE TABLE test (id TINYINT UNSIGNED NOT NULL, test BOOLEAN)"
122+
insert_query = "INSERT INTO test VALUES(1, FALSE)"
123+
event = self.create_and_insert_value(create_query, insert_query)
124+
self.assertEqual(event.rows[0]["values"]["id"], 1)
125+
self.assertEqual(type(event.rows[0]["values"]["test"]), type(True))
126+
self.assertEqual(event.rows[0]["values"]["test"], False)
127+
128+
def test_tiny_maps_to_none(self):
129+
create_query = "CREATE TABLE test (id TINYINT UNSIGNED NOT NULL, test BOOLEAN)"
130+
insert_query = "INSERT INTO test VALUES(1, NULL)"
131+
event = self.create_and_insert_value(create_query, insert_query)
132+
self.assertEqual(event.rows[0]["values"]["id"], 1)
133+
self.assertEqual(type(event.rows[0]["values"]["test"]), type(None))
134+
self.assertEqual(event.rows[0]["values"]["test"], None)
135+
112136
def test_short(self):
113137
create_query = "CREATE TABLE test (id SMALLINT UNSIGNED NOT NULL, test SMALLINT)"
114138
insert_query = "INSERT INTO test VALUES(65535, -32768)"

0 commit comments

Comments
 (0)