本文主要介绍如何使用 Estimator 结构训练 TensorFlow 模型,希望能为大家解决编程问题提供一定的参考,需要的开发者可以一起来学习。
一、使用estimator训练模型的流程
1、构建model_fn
def my_metric_fn(labels, predictions):
    """Return the evaluation metrics dict: accuracy of predictions vs. labels."""
    accuracy = tf.metrics.accuracy(labels, predictions)
    return {'accuracy': accuracy}


def model_fn(features, labels, mode, params):
    """Model function skeleton handed to the Estimator.

    A model function must take exactly these four parameters.

    :param features: input feature data
    :param labels: input label data
    :param mode: train, evaluate or predict
    :param params: hyperparameters, i.e. the params passed in by the Estimator
    :return: a TPUEstimatorSpec object
    """
    # NOTE(review): `predictions`, `total_loss` and `scaffold_fn` are not
    # defined anywhere in this excerpt; presumably the omitted model-building
    # code produces them -- confirm before running.
    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,  # "train" or "eval" or "predict"
        loss=total_loss,  # described as double-typed in the original notes
        # (metric function, list of tensors passed to it)
        eval_metrics=(my_metric_fn, [labels, predictions]),
        scaffold_fn=scaffold_fn,  # None or a function
    )
2、定义estimator
# Run configuration: where to write checkpoints, how often to save them,
# the random seed, and the TPU-specific settings.
run_config = tf.contrib.tpu.RunConfig(
    cluster=tpu_cluster_resolver,
    master=FLAGS.master,
    model_dir=FLAGS.output_dir,
    save_checkpoints_steps=FLAGS.save_checkpoints_steps,
    keep_checkpoint_max=FLAGS.keep_checkpoint_max,
    tf_random_seed=FLAGS.random_seed,
    tpu_config=tf.contrib.tpu.TPUConfig(
        # The checkpoint interval is reused as the TPU loop length.
        iterations_per_loop=FLAGS.save_checkpoints_steps,
        num_shards=FLAGS.num_tpu_cores,
        per_host_input_for_training=is_per_host,
    ),
)

# Custom estimator built from the model function and the run configuration.
estimator = tf.contrib.tpu.TPUEstimator(
    use_tpu=FLAGS.use_tpu,
    model_fn=model_fn,  # the model function
    config=run_config,  # the run configuration object
    train_batch_size=FLAGS.train_batch_size,
    eval_batch_size=FLAGS.eval_batch_size,
    predict_batch_size=FLAGS.predict_batch_size,
)
3、训练模型
def train_input_fn(params):
    """Build the training dataset from the TFRecord file.

    Reads `input_file`, repeats and shuffles the records when `is_training`
    is true, then decodes and batches them in one fused step.

    NOTE(review): `input_file`, `is_training`, `name_to_features`,
    `drop_remainder` and `_decode_record` are not defined in this excerpt;
    presumably they come from an enclosing scope -- confirm.

    :param params: dict that must contain the key "batch_size"
    :return: the batched dataset
    """
    batch_size = params["batch_size"]

    def _parse(record):
        # Decode one serialized record into its feature tensors.
        return _decode_record(record, name_to_features)

    dataset = tf.data.TFRecordDataset(input_file)
    if is_training:
        # Repeat, then shuffle with a small buffer and a fresh random seed.
        dataset = dataset.repeat().shuffle(
            buffer_size=100, seed=random.randint(1, 10000))

    batching = tf.data.experimental.map_and_batch(
        _parse,
        batch_size=batch_size,
        drop_remainder=drop_remainder)
    return dataset.apply(batching)


# Train until the global step reaches `next_checkpoint`.
estimator.train(input_fn=train_input_fn, max_steps=next_checkpoint)
4、验证模型
def eval_input_fn(params):
    """Build the evaluation dataset from the TFRecord file.

    Partial code; read it for the overall structure only.

    NOTE(review): `input_file`, `is_training`, `name_to_features`,
    `drop_remainder` and `_decode_record` are not defined in this excerpt;
    presumably they come from an enclosing scope -- confirm.

    :param params: dict that must contain the key "batch_size"
    :return: the batched dataset
    """
    batch_size = params["batch_size"]
    d = tf.data.TFRecordDataset(input_file)
    if is_training:
        d = d.repeat()
        d = d.shuffle(buffer_size=100, seed=random.randint(1, 10000))
    d = d.apply(
        tf.data.experimental.map_and_batch(
            lambda record: _decode_record(record, name_to_features),
            batch_size=batch_size,
            drop_remainder=drop_remainder))
    return d


# evaluate() returns a dict mapping metric name -> value.
result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
for key in sorted(result.keys()):
    log_info = " %s = %s" % (key, str(result[key]))
    # The message used to be built and then dropped; emit it so the
    # evaluation metrics are actually visible.
    tf.logging.info(log_info)
5、测试模型
def predict_input_fn(params):
    """Build the prediction dataset from the TFRecord file.

    Partial code; read it for the overall structure only.

    NOTE(review): `input_file`, `is_training`, `name_to_features`,
    `drop_remainder` and `_decode_record` are not defined in this excerpt;
    presumably they come from an enclosing scope -- confirm.

    :param params: dict that must contain the key "batch_size"
    :return: the batched dataset
    """
    batch_size = params["batch_size"]
    d = tf.data.TFRecordDataset(input_file)
    if is_training:
        d = d.repeat()
        d = d.shuffle(buffer_size=100, seed=random.randint(1, 10000))
    d = d.apply(
        tf.data.experimental.map_and_batch(
            lambda record: _decode_record(record, name_to_features),
            batch_size=batch_size,
            drop_remainder=drop_remainder))
    return d


# predict() returns a generator that yields one prediction per example, not
# a dict, so calling .keys() on it directly raises AttributeError.
result = estimator.predict(input_fn=predict_input_fn)
for prediction in result:
    # Each item is expected to be a dict of output name -> value, since
    # TPUEstimator requires dict predictions.
    for key in sorted(prediction.keys()):
        log_info = " %s = %s" % (key, str(prediction[key]))
        # Emit the message so the predictions are actually visible.
        tf.logging.info(log_info)
二、使用estimator训练模型的样例
这篇关于使用 Estimator 结构训练 TensorFlow 模型的文章就介绍到这里,希望我们推荐的文章对各位程序员有所帮助!