From 86271f583db8c31d679e2cf9180a879f46222309 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Vieille?= Date: Thu, 23 Aug 2018 18:52:45 +0200 Subject: [PATCH] features for horizontal - vertical --- src/features/build_features.py | 0 src/features/build_tfrecords.py | 23 +++++++++++++++++++---- 2 files changed, 19 insertions(+), 4 deletions(-) delete mode 100644 src/features/build_features.py diff --git a/src/features/build_features.py b/src/features/build_features.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/features/build_tfrecords.py b/src/features/build_tfrecords.py index 1320e68..f5c4310 100644 --- a/src/features/build_tfrecords.py +++ b/src/features/build_tfrecords.py @@ -25,8 +25,16 @@ def main(): refs.rename(columns = lambda x: "label" + str(x) if x != 'image' else x, inplace = True) refs = refs[refs['image'].isin(imgs_list)] + # try to learn vertical / horizontal + refs = refs.query('label1 == 1 | label2 == 1') + + refs = refs.assign(train = pd.Series(np.random.binomial(n = 1, p = 0.6, size = len(refs.index))).values) + + refs.reset_index(inplace = True) + # initiate tf file writer - writer = tf.python_io.TFRecordWriter(train_filename) + train_writer = tf.python_io.TFRecordWriter(train_filename) + test_writer = tf.python_io.TFRecordWriter(test_filename) for i, row in refs.iterrows(): img_full_path = os.path.join(imgs_path, row['image']) @@ -38,9 +46,14 @@ def main(): })) logger.info("Creation TFRECORDS {}/{} : {}".format(i+1, refs.index.size, img_full_path)) - writer.write(example.SerializeToString()) + + if row['train'] == 1: + train_writer.write(example.SerializeToString()) + else: + test_writer.write(example.SerializeToString()) - writer.close() + train_writer.close() + test_writer.close() def _int64_feature(value): @@ -51,7 +64,8 @@ def _bytes_feature(value): def loadImg(path): - raw = np.asarray(Image.open(path)).tostring() + raw = np.asarray(Image.open(path)) / 255 + raw = raw.tostring() return path, raw @@ -66,5 +80,6 @@ if __name__ == '__main__': imgs_list = os.listdir(imgs_path) refs_path = os.path.join(str(project_dir), 'data', 'external', 'refs', 'references_labels.csv') train_filename = os.path.join(str(project_dir), 'data', 'processed', 'train.tfrecords') + test_filename = os.path.join(str(project_dir), 'data', 'processed', 'test.tfrecords') main()