@@ -776,227 +776,6 @@ ne_type quant_params_to_type(const quant_params& params) {
776776 }
777777 return NE_TYPE_F32;
778778}
779- size_t jblas_quantize (const float * f32ptr, void * dstpr, const quant_params params, int n, int k) {
780- using CompType = jblas::prologue::weight_comp::gemm::WeightCompType;
781- auto cd = jblas::utils::parallel::CpuDevice::getInstance ();
782- jblas::prologue::PackedWeight* packedw = NULL ;
783- auto type = CompType::S4_F32;
784- if (params.bits == 4 ) {
785- if (params.scale_dtype == " bf16" ) {
786- type = CompType::S4_Bf16;
787- } else {
788- type = CompType::S4_F32;
789- }
790- } else if (params.bits == 8 ) {
791- type = CompType::S8_F32;
792- } else {
793- return 0 ;
794- }
795- cd->setThreads (params.nthread );
796- if (params.bits == 4 ) {
797- if (params.compute_type == " int8" ) {
798- using GemmKernel = jblas::wrapper::gemm_default::weight_comp::avx512_vnni::GemmKernelDynamicQuantS4KBlock;
799- static GemmKernel kernel;
800- assert (cd->AVX512F ());
801- packedw = kernel.getWeightPtr ()->compressWeightTranspose (n, k, f32ptr, k, params.block_size , type);
802- } else if (params.compute_type == " fp32" ) {
803- using GemmKernel = jblas::wrapper::gemm_default::weight_comp::avx512f::GemmKernelS4KBlock;
804- static GemmKernel kernel;
805- assert (cd->AVX512F ());
806- packedw = kernel.getWeightPtr ()->compressWeightTranspose (n, k, f32ptr, k, params.block_size , type);
807- }
808- } else if (params.bits == 8 ) {
809- // TODO add 8bit quantization
810- }
811- assert (packedw != 0 );
812- auto size = packedw->getSerializedSize ();
813- packedw->serializeToBuffer (dstpr);
814- delete packedw;
815- return size;
816- }
817-
818- bool ne_common_quantize_0 (std::ifstream& finp, std::ofstream& fout, const quant_params params,
819- const std::vector<std::string>& to_quant, const std::vector<std::string>& to_skip) {
820- ne_type qtype = quant_params_to_type (params);
821- if (!ne_is_quantized (qtype)) {
822- fprintf (stderr, " %s: invalid quantization type %d (%s)\n " , __func__, qtype, ne_type_name (qtype));
823- return false ;
824- }
825-
826- size_t total_size_org = 0 ;
827- size_t total_size_new = 0 ;
828-
829- std::vector<float > work;
830-
831- std::vector<uint8_t > data_u8;
832- std::vector<ne_fp16_t > data_f16;
833- std::vector<float > data_f32;
834-
835- std::vector<int64_t > hist_all (1 << 4 , 0 );
836-
837- while (true ) {
838- int32_t n_dims;
839- int32_t length;
840- int32_t ttype;
841-
842- finp.read (reinterpret_cast <char *>(&n_dims), sizeof (n_dims));
843- finp.read (reinterpret_cast <char *>(&length), sizeof (length));
844- finp.read (reinterpret_cast <char *>(&ttype), sizeof (ttype));
845-
846- if (finp.eof ()) {
847- break ;
848- }
849-
850- int32_t nelements = 1 ;
851- int32_t ne[4 ] = {1 , 1 , 1 , 1 };
852- for (int i = 0 ; i < n_dims; ++i) {
853- finp.read (reinterpret_cast <char *>(&ne[i]), sizeof (ne[i]));
854- nelements *= ne[i];
855- }
856-
857- std::string name (length, 0 );
858- finp.read (&name[0 ], length);
859-
860- printf (" %64s - [%5d, %5d, %5d], type = %6s " , name.data (), ne[0 ], ne[1 ], ne[2 ], ne_type_name ((ne_type)ttype));
861-
862- bool quantize = false ;
863-
864- // check if we should quantize this tensor
865- for (const auto & s : to_quant) {
866- if (std::regex_match (name, std::regex (s))) {
867- quantize = true ;
868- break ;
869- }
870- }
871-
872- // check if we should skip this tensor
873- for (const auto & s : to_skip) {
874- if (std::regex_match (name, std::regex (s))) {
875- quantize = false ;
876- break ;
877- }
878- }
879-
880- // quantize only 2D tensors
881- quantize &= (n_dims == 2 );
882-
883- if (quantize) {
884- if (ttype != NE_TYPE_F32 && ttype != NE_TYPE_F16) {
885- fprintf (stderr, " %s: unsupported ttype %d (%s) for integer quantization\n " , __func__, ttype,
886- ne_type_name ((ne_type)ttype));
887- return false ;
888- }
889-
890- if (ttype == NE_TYPE_F16) {
891- data_f16.resize (nelements);
892- finp.read (reinterpret_cast <char *>(data_f16.data ()), nelements * sizeof (ne_fp16_t ));
893- data_f32.resize (nelements);
894- for (int i = 0 ; i < nelements; ++i) {
895- data_f32[i] = ne_fp16_to_fp32 (data_f16[i]);
896- }
897- } else {
898- data_f32.resize (nelements);
899- finp.read (reinterpret_cast <char *>(data_f32.data ()), nelements * sizeof (float ));
900- }
901-
902- ttype = qtype;
903- } else {
904- const int bpe = (ttype == 0 ) ? sizeof (float ) : sizeof (uint16_t );
905-
906- data_u8.resize (nelements * bpe);
907- finp.read (reinterpret_cast <char *>(data_u8.data ()), nelements * bpe);
908- }
909-
910- fout.write (reinterpret_cast <char *>(&n_dims), sizeof (n_dims));
911- fout.write (reinterpret_cast <char *>(&length), sizeof (length));
912- fout.write (reinterpret_cast <char *>(&ttype), sizeof (ttype));
913- for (int i = 0 ; i < n_dims; ++i) {
914- fout.write (reinterpret_cast <char *>(&ne[i]), sizeof (ne[i]));
915- }
916- fout.write (&name[0 ], length);
917-
918- if (quantize) {
919- work.resize (nelements); // for quantization
920-
921- size_t cur_size = 0 ;
922- std::vector<int64_t > hist_cur (1 << 4 , 0 );
923-
924- switch ((ne_type)ttype) {
925- case NE_TYPE_Q4_0: {
926- cur_size = ne_quantize_q4_0 (data_f32.data (), work.data (), nelements, ne[0 ], hist_cur.data ());
927- } break ;
928- case NE_TYPE_Q4_1: {
929- cur_size = ne_quantize_q4_1 (data_f32.data (), work.data (), nelements, ne[0 ], hist_cur.data ());
930- } break ;
931- case NE_TYPE_Q5_0: {
932- cur_size = ne_quantize_q5_0 (data_f32.data (), work.data (), nelements, ne[0 ], hist_cur.data ());
933- } break ;
934- case NE_TYPE_Q5_1: {
935- cur_size = ne_quantize_q5_1 (data_f32.data (), work.data (), nelements, ne[0 ], hist_cur.data ());
936- } break ;
937- case NE_TYPE_Q8_0: {
938- cur_size = ne_quantize_q8_0 (data_f32.data (), work.data (), nelements, ne[0 ], hist_cur.data ());
939- } break ;
940- case NE_TYPE_JBLAS: {
941- cur_size = jblas_quantize (data_f32.data (), work.data (), params, ne[1 ], ne[0 ]);
942- if (cur_size == 0 ) {
943- fprintf (stderr, " %s: unsupported jblas quantization parameters %d %s %s\n " , __func__, params.bits ,
944- params.alg .c_str (), params.compute_type .c_str ());
945- return false ;
946- }
947- } break ;
948- case NE_TYPE_F16:
949- case NE_TYPE_I8:
950- case NE_TYPE_I16:
951- case NE_TYPE_I32:
952- case NE_TYPE_Q8_1:
953- case NE_TYPE_COUNT: {
954- fprintf (stderr, " %s: unsupported quantization type %d (%s)\n " , __func__, ttype, ne_type_name ((ne_type)ttype));
955- return false ;
956- }
957- }
958-
959- fout.write (reinterpret_cast <char *>(work.data ()), cur_size);
960- total_size_new += cur_size;
961-
962- printf (" size = %8.2f MB -> %8.2f MB | hist: " , nelements * sizeof (float ) / 1024.0 / 1024.0 ,
963- cur_size / 1024.0 / 1024.0 );
964- for (int i = 0 ; i < (int )hist_cur.size (); ++i) {
965- hist_all[i] += hist_cur[i];
966- }
967-
968- for (int i = 0 ; i < (int )hist_cur.size (); ++i) {
969- printf (" %5.3f " , hist_cur[i] / (float )nelements);
970- }
971- printf (" \n " );
972- } else {
973- printf (" size = %8.3f MB\n " , data_u8.size () / 1024.0 / 1024.0 );
974- fout.write (reinterpret_cast <char *>(data_u8.data ()), data_u8.size ());
975- total_size_new += data_u8.size ();
976- }
977-
978- total_size_org += nelements * sizeof (float );
979- }
980-
981- printf (" %s: model size = %8.2f MB\n " , __func__, total_size_org / 1024.0 / 1024.0 );
982- printf (" %s: quant size = %8.2f MB | qtype = %d (%s)\n " , __func__, total_size_new / 1024.0 / 1024.0 , qtype,
983- ne_type_name (qtype));
984-
985- {
986- int64_t sum_all = 0 ;
987- for (int i = 0 ; i < (int )hist_all.size (); ++i) {
988- sum_all += hist_all[i];
989- }
990-
991- printf (" %s: hist: " , __func__);
992- for (int i = 0 ; i < (int )hist_all.size (); ++i) {
993- printf (" %5.3f " , hist_all[i] / (float )sum_all);
994- }
995- printf (" \n " );
996- }
997-
998- return true ;
999- }
1000779
1001780void console_init (console_state& con_st) {
1002781#if defined(_WIN32)
0 commit comments