caffe如何解析数据库
一、convert_imageset
使用caffe中的convert_imageset工具可以将原始图片转换成LevelDB或者Lmdb格式。转换方法如下:
$ convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME
-
ROOTFLOLDER
: 为图像集的根目录,必须带/
结尾 -
LISTFILE
: 文本文件,记录图像路径(该路径是ROOTFOLDER的相对路径)和标注,举例如下:train/321.jpg 0 train/352.jpg 1 train/337.jpg 2 train/345.jpg 3 train/339.jpg 4
-
DB_NAME
: 生成的数据库的名称,可以是相对路径 -
[FLAGS]
为可选参数,如下:gray
:bool类型,默认为false,如果设置为true,则代表将图像当做灰度图像来处理,否则当做彩色图像来处理shuffle
:bool类型,默认为false,如果设置为true,则代表将图像集中的图像的顺序随机打乱backend
:string类型,可取的值的集合为{“lmdb”, “leveldb”},默认为”lmdb”,代表采用何种形式来存储转换后的数据resize_width
:int32的类型,默认值为0,如果为非0值,则代表图像的宽度将被resize成resize_widthresize_height
:int32的类型,默认值为0,如果为非0值,则代表图像的高度将被resize成resize_heightcheck_size
:bool类型,默认值为false,如果该值为true,则在处理数据的时候将检查每一条数据的大小是否相同encoded
:bool类型,默认值为false,如果为true,代表将存储编码后的图像,具体采用的编码方式由参数encode_type指定encode_type
:string类型,默认值为”“,用于指定用何种编码方式存储编码后的图像,取值为编码方式的后缀(如png、jpg等等)
举例:
$ ./build/tools/convert_imageset \
--shuffle \
--resize_height=256 \
--resize_width=256 \
/home/xxx/caffe/data/mynet/ \
examples/mynet/train.txt \
examples/mynet/mynet_train_lmdb
二、Datum
1、定义
在caffe.proto
中的定义如下:
message Datum {
optional int32 channels = 1;
optional int32 height = 2;
optional int32 width = 3;
// the actual image data, in bytes
optional bytes data = 4;
optional int32 label = 5;
// Optionally, the datum could also hold float data.
repeated float float_data = 6;
// If true data contains an encoded image that need to be decoded
optional bool encoded = 7 [default = false];
}
其中channels/height/width分别对应C/H/W;保存的数据可以是byte类型(data),或者是float类型(float_data);label对应标签号,通常从0开始;encoded表示数据是否被译码(如jpg/png等等),如果有则需要解码后使用,默认是false,表示纯粹的3维数据。
2、初始化
一张图片的过程如下(经过精简,默认存入未译码数据):
bool ReadImageToDatum(const string& filename, ...) {
cv::Mat cv_img = ReadImageToCVMat(filename, height, width, is_color);
datum->set_channels(cv_img.channels());
datum->set_height(cv_img.rows);
datum->set_width(cv_img.cols);
datum->clear_data();
datum->clear_float_data();
datum->set_encoded(false);
int datum_channels = datum->channels();
int datum_height = datum->height();
int datum_width = datum->width();
int datum_size = datum_channels * datum_height * datum_width;
std::string buffer(datum_size, ' ');
for (int h = 0; h < datum_height; ++h) {
const uchar* ptr = cv_img.ptr<uchar>(h);
int img_index = 0;
for (int w = 0; w < datum_width; ++w) {
for (int c = 0; c < datum_channels; ++c) {
int datum_index = (c * datum_height + h) * datum_width + w;
buffer[datum_index] = static_cast<char>(ptr[img_index++]);
}
}
}
datum->set_data(buffer);
datum->set_label(label);
return true;
}
三、数据库的生成
数据库存入的是Datum序列化后的数据,以及索引。精简后的代码如下:
int main(...) {
scoped_ptr<db::DB> db(db::GetDB(FLAGS_backend)); // lmdb 或者 leveldb
db->Open(argv[3], db::NEW); // 创建数据库
scoped_ptr<db::Transaction> txn(db->NewTransaction());
std::ifstream infile(argv[2]); // 打开LISTFILE文件
while (std::getline(infile, line)) { // 读取文件中的图片路径和标签
pos = line.find_last_of(' ');
label = atoi(line.substr(pos + 1).c_str());
lines.push_back(std::make_pair(line.substr(0, pos), label));
}
for (int line_id = 0; line_id < lines.size(); ++line_id) { // 每个图片处理
ReadImageToDatum(root_folder + lines[line_id].first, ...); // 将1张图片转换成Datum
string key_str = "0000XXXX_" + lines[line_id].first; //索引,XXXX表示line_id
datum.SerializeToString(&out); //Datum序列化
txn->Put(key_str, out); //存入数据库
}
txn->Commit(); // 数据库写入磁盘
return 0;
}
四、数据库的解析
解析数据库在DataLayer中实现。精简后的代码如下:
// 只读方式打开数据库
db_.reset(db::GetDB(param.data_param().backend()));
db_->Open(param.data_param().source(), db::READ);
cursor_.reset(db_->NewCursor());
// 按batch数量load Datum
Datum datum;
for (int item_id = 0; item_id < batch_size; ++item_id) {
datum.ParseFromString(cursor_->value()); // 反序列化成Datum
// 按偏移存入Datum转换后的数据,到Blob<Dtype>中
int offset = batch->data_.offset(item_id);
Dtype* top_data = batch->data_.mutable_cpu_data();
this->transformed_data_.set_cpu_data(top_data + offset);
this->data_transformer_->Transform(datum, &(this->transformed_data_));
// 下一个
cursor_->Next();
offset_++;
}
其中TransForm代码,精简如下:
const string& data = datum.data();
const int datum_channels = datum.channels();
const int datum_height = datum.height();
const int datum_width = datum.width();
Dtype datum_element;
int top_index, data_index;
for (int c = 0; c < datum_channels; ++c) {
for (int h = 0; h < height; ++h) {
for (int w = 0; w < width; ++w) {
data_index = (c * datum_height + h_off + h) * datum_width + w_off + w;
top_index = (c * height + h) * width + w;
if (has_uint8) { // 根据类型读取数据
datum_element = datum.data[data_index];
} else {
datum_element = datum.float_data(data_index);
}
if (has_mean_file) { // 如果有mean文件,要减去mean数据
transformed_data[top_index] =
(datum_element - mean[data_index]) * scale;
} else { //对数据进行scale处理
transformed_data[top_index] = datum_element * scale;
}
}}}