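The object_detection module in TensorFlow
Change the current directory to the models folder. Run every command below from this path, otherwise you will hit a long series of errors.
The version used here is the commit "Fix ML Engine Dashboard link (#1599)".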
Installation
The official steps, adapted where necessary.
Add Libraries to PYTHONPATH
On Linux, this corresponds to adding the slim folder to the environment variable:
# From tensorflow/models/
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
This adds the slim folder directly to the Python path (recommended). For PyCharm, see http://blog.csdn.net/wh357589873/article/details/53204024
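Alternatively, edit the imports by hand (not recommended):
- In trainer.py, change from deployment import model_deploy to from slim.deployment import model_deploy (deployment is in the slim folder).
- In the *_feature_extractor.py files under the object_detection\models folder, change from nets to from slim.nets (nets is in the slim folder).
- In inception_utils.py and resnet_utils.py under the slim/nets folder, change from nets to from slim.nets (nets is in the slim folder).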
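The same effect as the export line can be had from inside a Python session, which may be handy in an IDE where the shell environment is not inherited. The following is only a sketch: the helper name add_models_to_sys_path is hypothetical, and it assumes the process is started from the models folder as described above.

```python
import os
import sys


def add_models_to_sys_path(models_dir):
    """Make sure models_dir and models_dir/slim are on sys.path."""
    models_dir = os.path.abspath(models_dir)
    wanted = [models_dir, os.path.join(models_dir, "slim")]
    for path in reversed(wanted):
        if path not in sys.path:
            sys.path.insert(0, path)
    return wanted


if __name__ == "__main__":
    # Assumption: the current directory is the models folder.
    for path in add_models_to_sys_path(os.getcwd()):
        print(path)
```

Unlike the hand-edited imports, this leaves the repository files untouched, so later updates of the code do not conflict with local changes.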
Protobuf Compilation
# From tensorflow/models/
protoc object_detection/protos/*.proto --python_out=.
This uses protoc to generate the pb2 file for every .proto file in the protos folder.
To generate only string_int_label_map_pb2.py with protoc:
protoc object_detection/protos/string_int_label_map.proto --python_out=.
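Note: protoc is not included in the repository and has to be downloaded separately. Download version 3.0 or later so that the generated code works with Python 3.
Download: http://repo1.maven.org/maven2/com/google/protobuf/protoc/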
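Whether the installed protoc meets the 3.0 requirement can be checked from the text that protoc --version prints (for example "libprotoc 3.3.0"). A minimal sketch, with hypothetical helper names parse_protoc_version and protoc_is_recent_enough:

```python
import re
import subprocess


def parse_protoc_version(version_output):
    """Extract (major, minor, patch) from text such as 'libprotoc 3.3.0'."""
    match = re.search(r"(\d+)\.(\d+)(?:\.(\d+))?", version_output)
    if match is None:
        raise ValueError("no version number in %r" % version_output)
    return tuple(int(part or 0) for part in match.groups())


def protoc_is_recent_enough(version_output, minimum=(3, 0, 0)):
    """Return True when the reported version is at least `minimum`."""
    return parse_protoc_version(version_output) >= minimum


if __name__ == "__main__":
    try:
        output = subprocess.check_output(["protoc", "--version"]).decode("utf-8")
    except (OSError, subprocess.CalledProcessError):
        print("protoc could not be run; is it on PATH?")
    else:
        print(output.strip(), "->", protoc_is_recent_enough(output))
```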
The following imports rely on files that protoc generates in the protos folder:
from object_detection.protos import input_reader_pb2
from object_detection.protos import model_pb2
from object_detection.protos import pipeline_pb2
from object_detection.protos import train_pb2
Generating model_pb2 in turn requires ssd_pb2 and faster_rcnn_pb2, and generating ssd_pb2 requires the pb2 files listed below. (These are the messages printed when protoc is run from inside the object_detection folder; they also show which files ssd_pb2.py and faster_rcnn_pb2.py depend on.)
object_detection/protos/anchor_generator.proto: File not found.
object_detection/protos/box_coder.proto: File not found.
object_detection/protos/box_predictor.proto: File not found.
object_detection/protos/hyperparams.proto: File not found.
object_detection/protos/image_resizer.proto: File not found.
object_detection/protos/matcher.proto: File not found.
object_detection/protos/losses.proto: File not found.
object_detection/protos/post_processing.proto: File not found.
object_detection/protos/region_similarity_calculator.proto: File not found.
Generating faster_rcnn_pb2 requires the following pb2 files:
object_detection/protos/anchor_generator.proto: File not found.
object_detection/protos/box_predictor.proto: File not found.
object_detection/protos/hyperparams.proto: File not found.
object_detection/protos/image_resizer.proto: File not found.
object_detection/protos/losses.proto: File not found.
object_detection/protos/post_processing.proto: File not found.
Generate the pb2 files above one by one. Some of them have dependencies of their own:
- anchor_generator_pb2.py depends on grid_anchor_generator_pb2 and ssd_anchor_generator_pb2
- box_coder_pb2.py depends on faster_rcnn_box_coder_pb2, mean_stddev_box_coder_pb2 and square_box_coder_pb2
- matcher_pb2.py depends on argmax_matcher_pb2 and bipartite_matcher_pb2
- pipeline_pb2.py depends on eval_pb2
- train_pb2.py depends on optimizer_pb2
In short, generate the pb2 file for every .proto file in the protos folder.
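On a shell that does not expand the *.proto wildcard (the Windows command prompt, for instance), the file list can be built in Python instead. This is a sketch under the assumption that protoc is on PATH and that the script runs from the models folder; build_protoc_command is a hypothetical helper, and the only protoc option it uses is the --python_out flag shown above.

```python
import glob
import os
import subprocess


def build_protoc_command(proto_dir, protoc="protoc"):
    """Return the protoc argument list covering every .proto file in proto_dir."""
    proto_files = sorted(glob.glob(os.path.join(proto_dir, "*.proto")))
    return [protoc] + proto_files + ["--python_out=."]


if __name__ == "__main__":
    command = build_protoc_command(os.path.join("object_detection", "protos"))
    if len(command) > 2:
        # One invocation compiles all files, so dependencies between the
        # .proto files do not have to be resolved by hand.
        subprocess.check_call(command)
    else:
        print("No .proto files found; run this from the models folder.")
```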
Testing the Installation
python object_detection/builders/model_builder_test.py
Running model_builder_test.py produces the following output:
----------------------------------------------------------------------
Ran 6 tests in 0.003s
OK
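If the test fails with an import error instead, it can help to see which generated modules Python cannot find. A rough sketch (Python 3 only, since it uses importlib.util.find_spec); missing_modules is a hypothetical helper, and the module names are the four pb2 imports mentioned earlier:

```python
import importlib.util


def missing_modules(module_names):
    """Return the names from module_names that cannot be located for import."""
    missing = []
    for name in module_names:
        try:
            spec = importlib.util.find_spec(name)
        except (ImportError, ValueError):
            # Raised when a parent package of a dotted name is itself missing.
            spec = None
        if spec is None:
            missing.append(name)
    return missing


if __name__ == "__main__":
    wanted = ["object_detection.protos.%s_pb2" % stem
              for stem in ("input_reader", "model", "pipeline", "train")]
    print("missing:", missing_modules(wanted))
```

An empty list only means the files were found; it does not prove that they were generated by a suitable protoc version.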
Configuring an object detection pipeline
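Overview
The configuration file is split into five parts:
- model: defines what type of model will be trained (e.g. meta-architecture, feature extractor).
- train_config: decides which parameters are used to train the model (e.g. SGD parameters, input preprocessing, and feature extractor initialization values).
- eval_config: determines which metrics are reported for evaluation (currently only the PASCAL VOC metrics are supported).
- train_input_config: defines which dataset the model is trained on.
- eval_input_config: defines which dataset the model is evaluated on. This should normally differ from the training dataset.
The configuration file is structured as follows:
model {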
See the config files in the samples/configs folder, for example:
# Faster R-CNN with Resnet-101 (v1), configured for Pascal VOC Dataset.
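To confirm that a config file really has the five parts listed above, the names of its top-level blocks can be extracted with a few lines of Python. This is a sketch, not a protobuf parser: it assumes that braces never appear inside quoted strings, and the helper name top_level_sections is hypothetical. In the sample configs the two input blocks are named train_input_reader and eval_input_reader, which is what the example text below uses.

```python
import re


def top_level_sections(config_text):
    """List the names of the blocks opened at brace depth 0, in file order."""
    sections = []
    depth = 0
    for line in config_text.splitlines():
        line = line.split("#", 1)[0]  # drop comments
        if depth == 0:
            match = re.match(r"\s*(\w+)\s*:?\s*\{", line)
            if match:
                sections.append(match.group(1))
        depth += line.count("{") - line.count("}")
    return sections


if __name__ == "__main__":
    example = """
    model { faster_rcnn { num_classes: 20 } }
    train_config: { batch_size: 1 }
    train_input_reader: { }
    eval_config: { }
    eval_input_reader: { }
    """
    print(top_level_sections(example))
```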
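Model parameter initialization (pre-trained models)
Although optional, it is strongly recommended to start from an existing object detection checkpoint. Training an object detector from scratch can take days. To speed up training, reuse the feature extractor parameters from an existing object classification or detection checkpoint. train_config provides two fields for specifying an existing checkpoint: fine_tune_checkpoint and from_detection_checkpoint. fine_tune_checkpoint should be a path to an existing checkpoint (e.g. "/usr/home/username/checkpoint/model.ckpt-#####"). from_detection_checkpoint is a boolean; if it is false, the checkpoint is assumed to come from an object classification model. Note that starting from a detection checkpoint usually gives a faster training job than starting from a classification checkpoint. The original page links to a list of the provided checkpoints.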
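The two fields can be written into train_config as two lines of text. The sketch below only formats those lines; checkpoint_fields is a hypothetical helper, and the path is the placeholder from the text, not a real file.

```python
def checkpoint_fields(checkpoint_path, from_detection_checkpoint):
    """Return the two train_config lines that select a pre-trained checkpoint."""
    flag = "true" if from_detection_checkpoint else "false"
    return (
        'fine_tune_checkpoint: "%s"\n' % checkpoint_path
        + "from_detection_checkpoint: %s\n" % flag
    )


if __name__ == "__main__":
    # Placeholder path taken from the text; replace it with a real checkpoint.
    print(checkpoint_fields("/usr/home/username/checkpoint/model.ckpt-#####", True))
```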
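Input preprocessing
The data_augmentation_options field in train_config specifies how the training data is modified. This field is optional.
SGD parameters
The remaining parameters in train_config are hyperparameters for gradient descent. Note that the best learning rates provided in these configuration files may depend on the specifics of the training setup (e.g. number of iterations, GPU type).
Configuring the evaluator
Evaluation is currently fixed to the metrics defined by the PASCAL VOC challenge. The parameters in eval_config are set to reasonable defaults and usually do not need to be configured.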
Preparing Inputs
Generating the PASCAL VOC TFRecord files
create_pascal_tf_record.py and label_map_util.py both contain an encoding/decoding error. The one in create_pascal_tf_record.py was fixed in #1614. For label_map_util.py in the utils folder, change line 104 from
label_map_string = fid.read()
to
label_map_string = fid.read().decode('utf-8')
(The original post shows a screenshot of this change.)
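The fix presumably works because the file object returns bytes under Python 3, while the text-format parser expects text. The sketch below illustrates the same idea with the built-in open() rather than tf.gfile, which is an assumption made to keep it runnable without TensorFlow; read_text is a hypothetical helper.

```python
import os
import tempfile


def read_text(path, encoding="utf-8"):
    """Read a file as bytes and hand back decoded text."""
    with open(path, "rb") as fid:
        data = fid.read()  # bytes, like the value the post decodes
    return data.decode(encoding)


if __name__ == "__main__":
    # Write a tiny label map so the example has something to read.
    handle, path = tempfile.mkstemp(suffix=".pbtxt")
    with os.fdopen(handle, "wb") as fid:
        fid.write(b"item {\n  id: 1\n  name: 'aeroplane'\n}\n")
    print(read_text(path))
    os.remove(path)
```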
Then run create_pascal_tf_record.py. Taking the record used for PASCAL VOC 2012 training as an example, the arguments are:
--data_dir=F:/Database/VOC/VOCtrainval_11-May-2012/VOCdevkit --year=VOC2012 --set=train --output_path=pascal_train.record
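Because the working directory has to be the models folder, one option is to open the models folder in PyCharm and set these arguments under Run > Edit Configurations.
To generate the record needed for validation: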
--data_dir=F:/Database/VOC/VOCtrainval_11-May-2012/VOCdevkit --year=VOC2012 --set=val --output_path=pascal_val.record
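The training and validation runs differ only in the --set and --output_path values, so both argument lists can be produced from one helper. A small sketch; pascal_record_args is a hypothetical name, and the four flags are exactly the ones used above.

```python
def pascal_record_args(data_dir, year, set_name, output_path):
    """Build the argument list passed to create_pascal_tf_record.py."""
    return [
        "--data_dir=%s" % data_dir,
        "--year=%s" % year,
        "--set=%s" % set_name,
        "--output_path=%s" % output_path,
    ]


if __name__ == "__main__":
    data_dir = "F:/Database/VOC/VOCtrainval_11-May-2012/VOCdevkit"
    for set_name in ("train", "val"):
        args = pascal_record_args(
            data_dir, "VOC2012", set_name, "pascal_%s.record" % set_name)
        print(" ".join(args))
```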
This produces two TFRecord files, pascal_train.record and pascal_val.record, in the tensorflow/models/object_detection directory.
The label map for the PASCAL VOC dataset can be found at data/pascal_label_map.pbtxt.
Train for VOC
Encoding fix
- In the get_configs_from_pipeline_file() function in train.py, change text_format.Merge(f.read(), pipeline_config) to text_format.Merge(f.read().decode('utf-8'), pipeline_config)
Python 2 to Python 3 changes
Apply the following changes:
https://github.com/tensorflow/models/pull/1593/files
Training arguments
--logtostderr --pipeline_config_path=./my_model/ssd_inception_v2_head.config --train_dir=F:/models/bus
--logtostderr --pipeline_config_path=./my_model/faster_rcnn_resnet101_voc07.config --train_dir=F:/models/passenger_head/rfcn_resnet50
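Note: for the first training run, the line from_detection_checkpoint: true in faster_rcnn_resnet101_voc07.config has to be commented out or set to false, otherwise an error is raised. After that, the parameter can be used to resume training. If fine_tune_checkpoint is commented out as well, no pre-trained model is used.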
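Commenting the field out can also be scripted, which avoids keeping two hand-edited copies of the config. A minimal sketch that works on the config text only; comment_out_field is a hypothetical helper and it assumes each field sits on its own line, as in the sample configs.

```python
import re


def comment_out_field(config_text, field_name):
    """Prefix every line that sets field_name with '# '; other lines are unchanged."""
    pattern = re.compile(
        r"^([ \t]*)(%s[ \t]*:.*)$" % re.escape(field_name), re.MULTILINE)
    return pattern.sub(r"\1# \2", config_text)


if __name__ == "__main__":
    config = (
        "train_config: {\n"
        '  fine_tune_checkpoint: "model.ckpt"\n'
        "  from_detection_checkpoint: true\n"
        "}\n"
    )
    print(comment_out_field(config, "from_detection_checkpoint"))
```

Lines that are already commented out are left alone, because the pattern requires the field name to be the first non-blank text on the line.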
eval
https://github.com/tensorflow/models/pull/1758/commits/e9606bc69ae9e8a401db1cf5920b24d8408b0c02
--logtostderr --eval_dir=F:/log --pipeline_config_path=./my_model/ssd_inception_v2_head.config --checkpoint_dir=F:\models\bus\model.ckpt-369
Test
For this part, see object_detection_tutorial.ipynb (opened in PyCharm).
Use export_inference_graph.py to convert the trained model to a .pb model.
The arguments are as follows:
--input_type image_tensor \
export_inference_graph.py also needs an encoding fix. Change line 93 from
text_format.Merge(f.read(), pipeline_config)
to
text_format.Merge(f.read().decode('utf-8'), pipeline_config)
Training on our own dataset
First, generate the TFRecord file for our dataset. The code is as follows:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import hashlib
import io

import PIL.Image
import tensorflow as tf

from object_detection.utils import dataset_util


def read_label_file(label_file_path):
    """Read one label file whose lines are 'class c_x c_y w h' (normalized)."""
    objects = []
    with open(label_file_path) as label_file:
        raw_lines = [line.strip() for line in label_file.readlines()]
    for raw_line in raw_lines:
        if not raw_line:
            continue  # skip blank lines
        class_num, c_x, c_y, w, h = [float(e) for e in raw_line.split()]
        # Convert center/size to corner coordinates.
        x1 = (c_x - w / 2)
        y1 = (c_y - h / 2)
        x2 = (c_x + w / 2)
        y2 = (c_y + h / 2)
        # Clip the box to the [0, 1] range.
        x1 = max(x1, 0)
        y1 = max(y1, 0)
        x2 = min(x2, 1)
        y2 = min(y2, 1)
        class_num = int(class_num)
        objects.append([class_num, x1, y1, x2, y2])
    return objects


def main():
    # image_idx = 0
    image_list_path = r"F:\Database\data_set\train\train_bk.txt"
    # image_list_path = r"F:\Database\data_set\validate\val.txt"
    # Raw strings keep "\t" in these Windows paths from being read as a tab.
    writer = tf.python_io.TFRecordWriter(r"F:\tensorflow\tfrecord\train.record")
    # writer = tf.python_io.TFRecordWriter(r"F:\tensorflow\tfrecord\val.record")
    with open(image_list_path, "r") as file:
        image_list = [line.strip().split() for line in file.readlines()]
    for img_path in image_list:
        # Start from empty lists for every image so that boxes from earlier
        # images do not end up in this example.
        xmin = []
        ymin = []
        xmax = []
        ymax = []
        classes = []
        classes_text = []
        truncated = []
        poses = []
        difficult_obj = []
        # print(img_path[0])
        with tf.gfile.GFile(img_path[0], 'rb') as fid:
            encoded_jpg = fid.read()
        encoded_jpg_io = io.BytesIO(encoded_jpg)
        image = PIL.Image.open(encoded_jpg_io)
        # image = PIL.Image.open(img_path[0])
        if image.format != 'JPEG':
            raise ValueError('Image format not JPEG')
        key = hashlib.sha256(encoded_jpg).hexdigest()
        width = image.width
        height = image.height
        # print(width, height)
        label_path = img_path[0].replace("images", "labels").replace("jpg", "txt")
        objects = read_label_file(label_path)
        # print(len(objects))
        for obj in objects:
            # obj is [class_num, x1, y1, x2, y2]
            xmin.append(obj[1])
            ymin.append(obj[2])
            xmax.append(obj[3])
            ymax.append(obj[4])
            classes_text.append('head'.encode('utf8'))
            classes.append(obj[0] + 1)  # class ids start from 1
            difficult_obj.append(0)
            truncated.append(1)
            poses.append('Unspecified'.encode('utf8'))
        example = tf.train.Example(features=tf.train.Features(feature={
            'image/height': dataset_util.int64_feature(height),
            'image/width': dataset_util.int64_feature(width),
            'image/filename': dataset_util.bytes_feature(
                img_path[0].strip().split('/')[-1].encode('utf8')),
            'image/source_id': dataset_util.bytes_feature(
                img_path[0].strip().split('/')[-1].encode('utf8')),
            'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
            'image/encoded': dataset_util.bytes_feature(encoded_jpg),
            'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
            'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
            'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
            'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
            'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
            'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
            'image/object/class/label': dataset_util.int64_list_feature(classes),
            'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
            'image/object/truncated': dataset_util.int64_list_feature(truncated),
            'image/object/view': dataset_util.bytes_list_feature(poses),
        }))
        # image_idx += 1
        # if image_idx == 1:
        #     print(example)
        writer.write(example.SerializeToString())
    writer.close()


if __name__ == '__main__':
    main()
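The box conversion inside read_label_file can be tried on its own, without TensorFlow or any image files. The sketch below repeats the same arithmetic in two small functions; the names center_to_corners and parse_label_line are hypothetical, and the sample label lines are made up for illustration.

```python
def center_to_corners(c_x, c_y, w, h):
    """Convert a normalized (center x, center y, width, height) box to
    (xmin, ymin, xmax, ymax), clipped to the [0, 1] range."""
    x1 = max(c_x - w / 2.0, 0.0)
    y1 = max(c_y - h / 2.0, 0.0)
    x2 = min(c_x + w / 2.0, 1.0)
    y2 = min(c_y + h / 2.0, 1.0)
    return x1, y1, x2, y2


def parse_label_line(raw_line):
    """Parse one 'class c_x c_y w h' line into [class_id, xmin, ymin, xmax, ymax]."""
    class_num, c_x, c_y, w, h = [float(e) for e in raw_line.split()]
    return [int(class_num)] + list(center_to_corners(c_x, c_y, w, h))


if __name__ == "__main__":
    print(parse_label_line("0 0.5 0.5 0.2 0.4"))   # box fully inside the image
    print(parse_label_line("0 0.05 0.5 0.2 0.4"))  # box clipped at the left edge
```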
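Note: class ids have to start from 1. Our annotations use class 0, which is why the classes.append(...) call adds 1 to the class number read from the label file. In addition, compared with the contents of the VOC TFRecord file, our dataset has no values for the remaining fields (difficult, truncated, view), so they are all set to the same constants, following the first tf_example output for VOC.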
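A label map in the same format as data/pascal_label_map.pbtxt is needed to go with these ids. The post does not show the one used for this dataset, so the following is an assumption: a sketch that writes one item per class, with ids starting at 1 to match the + 1 in the code. label_map_text is a hypothetical helper.

```python
def label_map_text(class_names):
    """Build label map text in which the first class gets id 1."""
    items = []
    for index, name in enumerate(class_names):
        items.append("item {\n  id: %d\n  name: '%s'\n}\n" % (index + 1, name))
    return "\n".join(items)


if __name__ == "__main__":
    # The single class used in this post.
    print(label_map_text(["head"]))
```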