Skip to content

Commit 6fe0010

Browse files
committed
PYTHON-1650 Always increment txnNumber before starting a retryable write
1 parent dea14be commit 6fe0010

File tree

4 files changed

+69
-10
lines changed

4 files changed

+69
-10
lines changed

pymongo/bulk.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def __init__(self, collection, ordered, bypass_document_validation):
160160
self.uses_array_filters = False
161161
self.is_retryable = True
162162
self.retrying = False
163+
self.started_retryable_write = False
163164
# Extra state so that we know where to pick up on a retry attempt.
164165
self.current_run = None
165166

@@ -275,6 +276,11 @@ def _execute_command(self, generator, write_concern, session,
275276

276277
while run.idx_offset < len(run.ops):
277278
if session:
279+
# Start a new retryable write unless one was already
280+
# started for this command.
281+
if retryable and not self.started_retryable_write:
282+
session._start_retryable_write()
283+
self.started_retryable_write = True
278284
session._apply_to(cmd, retryable, ReadPreference.PRIMARY)
279285
sock_info.send_cluster_time(cmd, session, client)
280286
check_keys = run.op_type == _INSERT
@@ -300,6 +306,8 @@ def _execute_command(self, generator, write_concern, session,
300306
_merge_command(run, full_result, run.idx_offset, result)
301307
# We're no longer in a retry once a command succeeds.
302308
self.retrying = False
309+
self.started_retryable_write = False
310+
303311
if self.ordered and "writeErrors" in result:
304312
break
305313
run.idx_offset += len(to_send)

pymongo/client_session.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ def start_transaction(self, read_concern=None, write_concern=None,
367367
self._transaction.opts = TransactionOptions(
368368
read_concern, write_concern, read_preference)
369369
self._transaction.state = _TxnState.STARTING
370-
self._server_session._transaction_id += 1
370+
self._start_retryable_write()
371371
self._transaction.transaction_id = self._server_session.transaction_id
372372
return _TransactionContext(self)
373373

@@ -544,7 +544,6 @@ def _apply_to(self, command, is_retryable, read_preference):
544544
self._transaction.state = _TxnState.NONE
545545

546546
if is_retryable:
547-
self._server_session._transaction_id += 1
548547
command['txnNumber'] = self._server_session.transaction_id
549548
return
550549

@@ -574,9 +573,9 @@ def _apply_to(self, command, is_retryable, read_preference):
574573
command['txnNumber'] = self._server_session.transaction_id
575574
command['autocommit'] = False
576575

577-
def _retry_transaction_id(self):
576+
def _start_retryable_write(self):
578577
self._check_ended()
579-
self._server_session.retry_transaction_id()
578+
self._server_session.inc_transaction_id()
580579

581580

582581
class _ServerSession(object):
@@ -597,8 +596,8 @@ def transaction_id(self):
597596
"""Positive 64-bit integer."""
598597
return Int64(self._transaction_id)
599598

600-
def retry_transaction_id(self):
601-
self._transaction_id -= 1
599+
def inc_transaction_id(self):
600+
self._transaction_id += 1
602601

603602

604603
class _ServerSessionPool(collections.deque):

pymongo/mongo_client.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,6 +1197,14 @@ def _retry_with_session(self, retryable, func, session, bulk):
11971197

11981198
def is_retrying():
11991199
return bulk.retrying if bulk else retrying
1200+
# Increment the transaction id up front to ensure any retry attempt
1201+
# will use the proper txnNumber, even if server or socket selection
1202+
# fails before the command can be sent.
1203+
if retryable:
1204+
session._start_retryable_write()
1205+
if bulk:
1206+
bulk.started_retryable_write = True
1207+
12001208
while True:
12011209
try:
12021210
server = self._get_topology().select_server(
@@ -1211,9 +1219,6 @@ def is_retrying():
12111219
# not support sessions raise the last error.
12121220
raise last_error
12131221
retryable = False
1214-
if is_retrying():
1215-
# Reset the transaction id and retry the operation.
1216-
session._retry_transaction_id()
12171222
return func(session, sock_info, retryable)
12181223
except ServerSelectionTimeoutError:
12191224
if is_retrying():

test/test_retryable_writes.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""Test retryable writes."""
1616

17+
import copy
1718
import json
1819
import os
1920
import sys
@@ -169,7 +170,7 @@ def create_tests():
169170
create_tests()
170171

171172

172-
def retryable_single_statement_ops(coll):
173+
def _retryable_single_statement_ops(coll):
173174
return [
174175
(coll.bulk_write, [[InsertOne({}), InsertOne({})]], {}),
175176
(coll.bulk_write, [[InsertOne({}),
@@ -188,6 +189,11 @@ def retryable_single_statement_ops(coll):
188189
(coll.find_one_and_replace, [{}, {'a': 3}], {}),
189190
(coll.find_one_and_update, [{}, {'$set': {'a': 1}}], {}),
190191
(coll.find_one_and_delete, [{}, {}], {}),
192+
]
193+
194+
195+
def retryable_single_statement_ops(coll):
196+
return _retryable_single_statement_ops(coll) + [
191197
# Deprecated methods.
192198
# Insert with single or multiple documents.
193199
(coll.insert, [{}], {}),
@@ -500,5 +506,46 @@ def test_batch_splitting_retry_fails(self):
500506
self.assertEqual(coll.find_one(projection={'_id': True}), {'_id': 1})
501507

502508

509+
# TODO: Make this a real integration test where we stepdown the primary.
510+
class TestRetryableWritesTxnNumber(IgnoreDeprecationsTest):
511+
@client_context.require_version_min(3, 6)
512+
@client_context.require_replica_set
513+
def test_increment_transaction_id_without_sending_command(self):
514+
"""Test that the txnNumber field is properly incremented, even when
515+
the first attempt fails before sending the command.
516+
"""
517+
listener = OvertCommandListener()
518+
client = rs_or_single_client(
519+
retryWrites=True, event_listeners=[listener])
520+
topology = client._topology
521+
select_server = topology.select_server
522+
523+
def raise_connection_err_select_server(*args, **kwargs):
524+
# Raise ConnectionFailure on the first attempt and perform
525+
# normal selection on the retry attempt.
526+
topology.select_server = select_server
527+
raise ConnectionFailure('Connection refused')
528+
529+
for method, args, kwargs in _retryable_single_statement_ops(
530+
client.db.retryable_write_test):
531+
listener.results.clear()
532+
topology.select_server = raise_connection_err_select_server
533+
with client.start_session() as session:
534+
kwargs = copy.deepcopy(kwargs)
535+
kwargs['session'] = session
536+
msg = '%s(*%r, **%r)' % (method.__name__, args, kwargs)
537+
initial_txn_id = session._server_session.transaction_id
538+
539+
# Each operation should fail on the first attempt and succeed
540+
# on the second.
541+
method(*args, **kwargs)
542+
self.assertEqual(len(listener.results['started']), 1, msg)
543+
retry_cmd = listener.results['started'][0].command
544+
sent_txn_id = retry_cmd['txnNumber']
545+
final_txn_id = session._server_session.transaction_id
546+
self.assertEqual(Int64(initial_txn_id + 1), sent_txn_id, msg)
547+
self.assertEqual(sent_txn_id, final_txn_id, msg)
548+
549+
503550
if __name__ == '__main__':
504551
unittest.main()

0 commit comments

Comments
 (0)