温馨提示×

温馨提示×

您好,登录后才能下订单哦!

密码登录×
登录注册×
其他方式登录
点击 登录注册 即表示同意《亿速云用户服务条款》

python的tf.train.batch函数怎么用

发布时间:2022-05-05 09:29:49 来源:亿速云 阅读:192 作者:iii 栏目:开发技术

这篇文章主要介绍“python的tf.train.batch函数怎么用”的相关知识,小编通过实际案例向大家展示操作过程,操作方法简单快捷,实用性强,希望这篇“python的tf.train.batch函数怎么用”文章能帮助大家解决问题。

tf.train.batch函数

tf.train.batch(     tensors,     batch_size,     num_threads=1,     capacity=32,     enqueue_many=False,     shapes=None,     dynamic_pad=False,     allow_smaller_final_batch=False,     shared_name=None,     name=None )

其中:

1、tensors:利用slice_input_producer获得的数据组合。

2、batch_size:设置每次从队列中获取出队数据的数量。

3、num_threads:用来控制线程的数量,如果其值不唯一,由于线程执行的特性,数据获取可能变成乱序。

4、capacity:一个整数,用来设置队列中元素的最大数量

5、allow_samller_final_batch:当其为True时,如果队列中的样本数量小于batch_size,出队的数量会以最终遗留下来的样本进行出队;当其为False时,小于batch_size的样本不会做出队处理。

6、name:名字

测试代码

1、allow_samller_final_batch=True

import pandas as pd import numpy as np import tensorflow as tf # 生成数据 def generate_data():     num = 18     label = np.arange(num)     return label # 获取数据 def get_batch_data():     label = generate_data()     input_queue = tf.train.slice_input_producer([label], shuffle=False,num_epochs=2)     label_batch = tf.train.batch(input_queue, batch_size=5, num_threads=1, capacity=64,allow_smaller_final_batch=True)     return label_batch # 数据组 label = get_batch_data() sess = tf.Session() # 初始化变量 sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) # 初始化batch训练的参数 coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess,coord) try:     while not coord.should_stop():         # 自动获取下一组数据         l = sess.run(label)         print(l) except tf.errors.OutOfRangeError:     print('Done training') finally:     coord.request_stop() coord.join(threads) sess.close()

运行结果为:

[0 1 2 3 4]
[5 6 7 8 9]
[10 11 12 13 14]
[15 16 17  0  1]
[2 3 4 5 6]
[ 7  8  9 10 11]
[12 13 14 15 16]
[17]
Done training

2、allow_samller_final_batch=False

相比allow_samller_final_batch=True,输出结果少了[17]

import pandas as pd import numpy as np import tensorflow as tf # 生成数据 def generate_data():     num = 18     label = np.arange(num)     return label # 获取数据 def get_batch_data():     label = generate_data()     input_queue = tf.train.slice_input_producer([label], shuffle=False,num_epochs=2)     label_batch = tf.train.batch(input_queue, batch_size=5, num_threads=1, capacity=64,allow_smaller_final_batch=False)     return label_batch # 数据组 label = get_batch_data() sess = tf.Session() # 初始化变量 sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) # 初始化batch训练的参数 coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess,coord) try:     while not coord.should_stop():         # 自动获取下一组数据         l = sess.run(label)         print(l) except tf.errors.OutOfRangeError:     print('Done training') finally:     coord.request_stop() coord.join(threads) sess.close()

运行结果为:

[0 1 2 3 4]
[5 6 7 8 9]
[10 11 12 13 14]
[15 16 17  0  1]
[2 3 4 5 6]
[ 7  8  9 10 11]
[12 13 14 15 16]
Done training

关于“python的tf.train.batch函数怎么用”的内容就介绍到这里了,感谢大家的阅读。如果想了解更多行业相关的知识,可以关注亿速云行业资讯频道,小编每天都会为大家更新不同的知识点。

向AI问一下细节

免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。

AI