@@ -34,6 +34,11 @@ cdef class CoreProtocol:
3434
3535 self ._reset_result()
3636
37+  cpdef is_in_transaction(self ):
38+  #  PQTRANS_INTRANS = idle, within transaction block
39+  #  PQTRANS_INERROR = idle, within failed transaction
40+  return  self .xact_status in  (PQTRANS_INTRANS, PQTRANS_INERROR)
41+ 
3742 cdef _read_server_messages(self ):
3843 cdef:
3944 char  mtype
@@ -263,27 +268,16 @@ cdef class CoreProtocol:
263268 elif  mtype ==  b' Z' 
264269 #  ReadyForQuery
265270 self ._parse_msg_ready_for_query()
266-  if  self .result_type ==  RESULT_FAILED:
267-  self ._push_result()
268-  else :
269-  try :
270-  buf =  < WriteBuffer> next(self ._execute_iter)
271-  except  StopIteration :
272-  self ._push_result()
273-  except  Exception  as  e:
274-  self .result_type =  RESULT_FAILED
275-  self .result =  e
276-  self ._push_result()
277-  else :
278-  #  Next iteration over the executemany() arg sequence
279-  self ._send_bind_message(
280-  self ._execute_portal_name, self ._execute_stmt_name,
281-  buf, 0 )
271+  self ._push_result()
282272
283273 elif  mtype ==  b' I' 
284274 #  EmptyQueryResponse
285275 self .buffer.discard_message()
286276
277+  elif  mtype ==  b' 1' 
278+  #  ParseComplete
279+  self .buffer.discard_message()
280+ 
287281 cdef _process__bind(self , char  mtype):
288282 if  mtype ==  b' E' 
289283 #  ErrorResponse
@@ -780,6 +774,17 @@ cdef class CoreProtocol:
780774 if  self .con_status !=  CONNECTION_OK:
781775 raise  apg_exc.InternalClientError(' not connected' 
782776
777+  cdef WriteBuffer _build_parse_message(self , str  stmt_name, str  query):
778+  cdef WriteBuffer buf
779+ 
780+  buf =  WriteBuffer.new_message(b' P' 
781+  buf.write_str(stmt_name, self .encoding)
782+  buf.write_str(query, self .encoding)
783+  buf.write_int16(0 )
784+ 
785+  buf.end_message()
786+  return  buf
787+ 
783788 cdef WriteBuffer _build_bind_message(self , str  portal_name,
784789 str  stmt_name,
785790 WriteBuffer bind_data):
@@ -795,6 +800,25 @@ cdef class CoreProtocol:
795800 buf.end_message()
796801 return  buf
797802
803+  cdef WriteBuffer _build_empty_bind_data(self ):
804+  cdef WriteBuffer buf
805+  buf =  WriteBuffer.new()
806+  buf.write_int16(0 ) #  The number of parameter format codes
807+  buf.write_int16(0 ) #  The number of parameter values
808+  buf.write_int16(0 ) #  The number of result-column format codes
809+  return  buf
810+ 
811+  cdef WriteBuffer _build_execute_message(self , str  portal_name,
812+  int32_t limit):
813+  cdef WriteBuffer buf
814+ 
815+  buf =  WriteBuffer.new_message(b' E' 
816+  buf.write_str(portal_name, self .encoding) #  name of the portal
817+  buf.write_int32(limit) #  number of rows to return; 0 - all
818+ 
819+  buf.end_message()
820+  return  buf
821+ 
798822 #  API for subclasses
799823
800824 cdef _connect(self ):
@@ -845,12 +869,7 @@ cdef class CoreProtocol:
845869 self ._ensure_connected()
846870 self ._set_state(PROTOCOL_PREPARE)
847871
848-  buf =  WriteBuffer.new_message(b' P' 
849-  buf.write_str(stmt_name, self .encoding)
850-  buf.write_str(query, self .encoding)
851-  buf.write_int16(0 )
852-  buf.end_message()
853-  packet =  buf
872+  packet =  self ._build_parse_message(stmt_name, query)
854873
855874 buf =  WriteBuffer.new_message(b' D' 
856875 buf.write_byte(b' S' 
@@ -872,10 +891,7 @@ cdef class CoreProtocol:
872891 buf =  self ._build_bind_message(portal_name, stmt_name, bind_data)
873892 packet =  buf
874893
875-  buf =  WriteBuffer.new_message(b' E' 
876-  buf.write_str(portal_name, self .encoding) #  name of the portal
877-  buf.write_int32(limit) #  number of rows to return; 0 - all
878-  buf.end_message()
894+  buf =  self ._build_execute_message(portal_name, limit)
879895 packet.write_buffer(buf)
880896
881897 packet.write_bytes(SYNC_MESSAGE)
@@ -894,11 +910,8 @@ cdef class CoreProtocol:
894910
895911 self ._send_bind_message(portal_name, stmt_name, bind_data, limit)
896912
897-  cdef _bind_execute_many(self , str  portal_name, str  stmt_name,
898-  object  bind_data):
899- 
900-  cdef WriteBuffer buf
901- 
913+  cdef bint _bind_execute_many(self , str  portal_name, str  stmt_name,
914+  object  bind_data):
902915 self ._ensure_connected()
903916 self ._set_state(PROTOCOL_BIND_EXECUTE_MANY)
904917
@@ -907,17 +920,88 @@ cdef class CoreProtocol:
907920 self ._execute_iter =  bind_data
908921 self ._execute_portal_name =  portal_name
909922 self ._execute_stmt_name =  stmt_name
923+  return  self ._bind_execute_many_more(True )
910924
911-  try :
912-  buf =  < WriteBuffer> next(bind_data)
913-  except  StopIteration :
914-  self ._push_result()
915-  except  Exception  as  e:
916-  self .result_type =  RESULT_FAILED
917-  self .result =  e
918-  self ._push_result()
919-  else :
920-  self ._send_bind_message(portal_name, stmt_name, buf, 0 )
925+  cdef bint _bind_execute_many_more(self , bint first = False ):
926+  cdef:
927+  WriteBuffer packet
928+  WriteBuffer buf
929+  list  buffers =  []
930+ 
931+  #  as we keep sending, the server may return an error early
932+  if  self .result_type ==  RESULT_FAILED:
933+  self ._write(SYNC_MESSAGE)
934+  return  False 
935+ 
936+  #  collect up to four 32KB buffers to send
937+  #  https://github.com/MagicStack/asyncpg/pull/289#issuecomment-391215051
938+  while  len (buffers) <  _EXECUTE_MANY_BUF_NUM:
939+  packet =  WriteBuffer.new()
940+ 
941+  #  fill one 32KB buffer
942+  while  packet.len() <  _EXECUTE_MANY_BUF_SIZE:
943+  try :
944+  #  grab one item from the input
945+  buf =  < WriteBuffer> next(self ._execute_iter)
946+ 
947+  #  reached the end of the input
948+  except  StopIteration :
949+  if  first:
950+  #  if we never send anything, simply set the result
951+  self ._push_result()
952+  else :
953+  #  otherwise, append SYNC and send the buffers
954+  packet.write_bytes(SYNC_MESSAGE)
955+  buffers.append(packet)
956+  self ._writelines(buffers)
957+  return  False 
958+ 
959+  #  error in input, give up the buffers and cleanup
960+  except  Exception  as  ex:
961+  self .result_type =  RESULT_FAILED
962+  self .result =  ex
963+  if  first:
964+  self ._push_result()
965+  elif  self .is_in_transaction():
966+  #  we're in an explicit transaction, just SYNC
967+  self ._write(SYNC_MESSAGE)
968+  else :
969+  #  In an implicit transaction, if `ignore_till_sync`,
970+  #  `ROLLBACK` will be ignored and `Sync` will restore
971+  #  the state; or the transaction will be rolled back
972+  #  with a warning saying that there was no transaction,
973+  #  but rollback is done anyway, so we could safely
974+  #  ignore this warning.
975+  #  GOTCHA: simple query message will be ignored if
976+  #  `ignore_till_sync` is set.
977+  buf =  self ._build_parse_message(' ' ' ROLLBACK' 
978+  buf.write_buffer(self ._build_bind_message(
979+  ' ' ' ' self ._build_empty_bind_data()))
980+  buf.write_buffer(self ._build_execute_message(' ' 0 ))
981+  buf.write_bytes(SYNC_MESSAGE)
982+  self ._write(buf)
983+  return  False 
984+ 
985+  #  all good, write to the buffer
986+  first =  False 
987+  packet.write_buffer(
988+  self ._build_bind_message(
989+  self ._execute_portal_name,
990+  self ._execute_stmt_name,
991+  buf,
992+  )
993+  )
994+  packet.write_buffer(
995+  self ._build_execute_message(self ._execute_portal_name, 0 ,
996+  )
997+  )
998+ 
999+  #  collected one buffer
1000+  buffers.append(packet)
1001+ 
1002+  #  write to the wire, and signal the caller for more to send
1003+  self ._writelines(buffers)
1004+  return  True 
9211005
9221006 cdef _execute(self , str  portal_name, int32_t limit):
9231007 cdef WriteBuffer buf
@@ -927,10 +1011,7 @@ cdef class CoreProtocol:
9271011
9281012 self .result =  []
9291013
930-  buf =  WriteBuffer.new_message(b' E' 
931-  buf.write_str(portal_name, self .encoding) #  name of the portal
932-  buf.write_int32(limit) #  number of rows to return; 0 - all
933-  buf.end_message()
1014+  buf =  self ._build_execute_message(portal_name, limit)
9341015
9351016 buf.write_bytes(SYNC_MESSAGE)
9361017
@@ -1013,6 +1094,9 @@ cdef class CoreProtocol:
10131094 cdef _write(self , buf):
10141095 raise  NotImplementedError 
10151096
1097+  cdef _writelines(self , list  buffers):
1098+  raise  NotImplementedError 
1099+ 
10161100 cdef _decode_row(self , const char *  buf, ssize_t buf_len):
10171101 pass 
10181102
0 commit comments