@@ -22,13 +22,13 @@ cdef class CoreProtocol:
2222 self .xact_status = PQTRANS_IDLE
2323 self .encoding = ' utf-8'
2424
25- # executemany support data
26- self ._execute_iter = None
27- self ._execute_portal_name = None
28- self ._execute_stmt_name = None
29-
3025 self ._reset_result()
3126
27+ cpdef is_in_transaction(self ):
28+ # PQTRANS_INTRANS = idle, within transaction block
29+ # PQTRANS_INERROR = idle, within failed transaction
30+ return self .xact_status in (PQTRANS_INTRANS, PQTRANS_INERROR)
31+
3232 cdef _read_server_messages(self ):
3333 cdef:
3434 char mtype
@@ -253,22 +253,7 @@ cdef class CoreProtocol:
253253 elif mtype == b' Z' :
254254 # ReadyForQuery
255255 self ._parse_msg_ready_for_query()
256- if self .result_type == RESULT_FAILED:
257- self ._push_result()
258- else :
259- try :
260- buf = < WriteBuffer> next(self ._execute_iter)
261- except StopIteration :
262- self ._push_result()
263- except Exception as e:
264- self .result_type = RESULT_FAILED
265- self .result = e
266- self ._push_result()
267- else :
268- # Next iteration over the executemany() arg sequence
269- self ._send_bind_message(
270- self ._execute_portal_name, self ._execute_stmt_name,
271- buf, 0 )
256+ self ._push_result()
272257
273258 elif mtype == b' I' :
274259 # EmptyQueryResponse
@@ -687,6 +672,17 @@ cdef class CoreProtocol:
687672 if self .con_status != CONNECTION_OK:
688673 raise apg_exc.InternalClientError(' not connected' )
689674
675+ cdef WriteBuffer _build_parse_message(self , str stmt_name, str query):
676+ cdef WriteBuffer buf
677+
678+ buf = WriteBuffer.new_message(b' P' )
679+ buf.write_str(stmt_name, self .encoding)
680+ buf.write_str(query, self .encoding)
681+ buf.write_int16(0 )
682+
683+ buf.end_message()
684+ return buf
685+
690686 cdef WriteBuffer _build_bind_message(self , str portal_name,
691687 str stmt_name,
692688 WriteBuffer bind_data):
@@ -702,6 +698,25 @@ cdef class CoreProtocol:
702698 buf.end_message()
703699 return buf
704700
701+ cdef WriteBuffer _build_empty_bind_data(self ):
702+ cdef WriteBuffer buf
703+ buf = WriteBuffer.new()
704+ buf.write_int16(0 ) # The number of parameter format codes
705+ buf.write_int16(0 ) # The number of parameter values
706+ buf.write_int16(0 ) # The number of result-column format codes
707+ return buf
708+
709+ cdef WriteBuffer _build_execute_message(self , str portal_name,
710+ int32_t limit):
711+ cdef WriteBuffer buf
712+
713+ buf = WriteBuffer.new_message(b' E' )
714+ buf.write_str(portal_name, self .encoding) # name of the portal
715+ buf.write_int32(limit) # number of rows to return; 0 - all
716+
717+ buf.end_message()
718+ return buf
719+
705720 # API for subclasses
706721
707722 cdef _connect(self ):
@@ -752,12 +767,7 @@ cdef class CoreProtocol:
752767 self ._ensure_connected()
753768 self ._set_state(PROTOCOL_PREPARE)
754769
755- buf = WriteBuffer.new_message(b' P' )
756- buf.write_str(stmt_name, self .encoding)
757- buf.write_str(query, self .encoding)
758- buf.write_int16(0 )
759- buf.end_message()
760- packet = buf
770+ packet = self ._build_parse_message(stmt_name, query)
761771
762772 buf = WriteBuffer.new_message(b' D' )
763773 buf.write_byte(b' S' )
@@ -779,10 +789,7 @@ cdef class CoreProtocol:
779789 buf = self ._build_bind_message(portal_name, stmt_name, bind_data)
780790 packet = buf
781791
782- buf = WriteBuffer.new_message(b' E' )
783- buf.write_str(portal_name, self .encoding) # name of the portal
784- buf.write_int32(limit) # number of rows to return; 0 - all
785- buf.end_message()
792+ buf = self ._build_execute_message(portal_name, limit)
786793 packet.write_buffer(buf)
787794
788795 packet.write_bytes(SYNC_MESSAGE)
@@ -801,30 +808,75 @@ cdef class CoreProtocol:
801808
802809 self ._send_bind_message(portal_name, stmt_name, bind_data, limit)
803810
804- cdef _bind_execute_many(self , str portal_name, str stmt_name,
805- object bind_data):
806-
807- cdef WriteBuffer buf
808-
811+ cdef _execute_many_init(self ):
809812 self ._ensure_connected()
810813 self ._set_state(PROTOCOL_BIND_EXECUTE_MANY)
811814
812815 self .result = None
813816 self ._discard_data = True
814- self ._execute_iter = bind_data
815- self ._execute_portal_name = portal_name
816- self ._execute_stmt_name = stmt_name
817817
818- try :
819- buf = < WriteBuffer> next(bind_data)
820- except StopIteration :
821- self ._push_result()
822- except Exception as e:
823- self .result_type = RESULT_FAILED
824- self .result = e
818+ cdef _execute_many_writelines(self , str portal_name, str stmt_name,
819+ object bind_data):
820+ cdef:
821+ WriteBuffer packet
822+ WriteBuffer buf
823+ list buffers = []
824+
825+ if self .result_type == RESULT_FAILED:
826+ raise StopIteration (False )
827+
828+ while len (buffers) < _EXECUTE_MANY_BUF_NUM:
829+ packet = WriteBuffer.new()
830+
831+ while packet.len() < _EXECUTE_MANY_BUF_SIZE:
832+ try :
833+ buf = < WriteBuffer> next(bind_data)
834+ except StopIteration :
835+ if packet.len() > 0 :
836+ buffers.append(packet)
837+ if len (buffers) > 0 :
838+ self ._writelines(buffers)
839+ raise StopIteration (True )
840+ else :
841+ raise StopIteration (False )
842+ except Exception as ex:
843+ raise StopIteration (ex)
844+ packet.write_buffer(
845+ self ._build_bind_message(portal_name, stmt_name, buf))
846+ packet.write_buffer(
847+ self ._build_execute_message(portal_name, 0 ))
848+ buffers.append(packet)
849+ self ._writelines(buffers)
850+
851+ cdef _execute_many_done(self , bint data_sent):
852+ if data_sent:
853+ self ._write(SYNC_MESSAGE)
854+ else :
825855 self ._push_result()
856+
857+ cdef _execute_many_fail(self , object error):
858+ cdef WriteBuffer buf
859+
860+ self .result_type = RESULT_FAILED
861+ self .result = error
862+
863+ # We shall rollback in an implicit transaction to prevent partial
864+ # commit, while do nothing in an explicit transaction and leaving the
865+ # error to the user
866+ if self .is_in_transaction():
867+ self ._execute_many_done(True )
826868 else :
827- self ._send_bind_message(portal_name, stmt_name, buf, 0 )
869+ # Here if the implicit transaction is in `ignore_till_sync` mode,
870+ # the `ROLLBACK` will be ignored and `Sync` will restore the state;
871+ # or else the implicit transaction will be rolled back with a
872+ # warning saying that there was no transaction, but rollback is
873+ # done anyway, so we could ignore this warning.
874+ buf = self ._build_parse_message(' ' , ' ROLLBACK' )
875+ buf.write_buffer(self ._build_bind_message(
876+ ' ' , ' ' , self ._build_empty_bind_data()))
877+ buf.write_buffer(self ._build_execute_message(' ' , 0 ))
878+ buf.write_bytes(SYNC_MESSAGE)
879+ self ._write(buf)
828880
829881 cdef _execute(self , str portal_name, int32_t limit):
830882 cdef WriteBuffer buf
@@ -834,10 +886,7 @@ cdef class CoreProtocol:
834886
835887 self .result = []
836888
837- buf = WriteBuffer.new_message(b' E' )
838- buf.write_str(portal_name, self .encoding) # name of the portal
839- buf.write_int32(limit) # number of rows to return; 0 - all
840- buf.end_message()
889+ buf = self ._build_execute_message(portal_name, limit)
841890
842891 buf.write_bytes(SYNC_MESSAGE)
843892
@@ -920,6 +969,9 @@ cdef class CoreProtocol:
920969 cdef _write(self , buf):
921970 raise NotImplementedError
922971
972+ cdef _writelines(self , list buffers):
973+ raise NotImplementedError
974+
923975 cdef _decode_row(self , const char * buf, ssize_t buf_len):
924976 pass
925977
0 commit comments