mirror of
https://github.com/prise6/smart-iss-posts
synced 2024-05-21 14:56:32 +02:00
features for horizontal - vertical
This commit is contained in:
parent
ee6f60470d
commit
86271f583d
|
@ -25,8 +25,16 @@ def main():
|
|||
refs.rename(columns = lambda x: "label" + str(x) if x != 'image' else x, inplace = True)
|
||||
refs = refs[refs['image'].isin(imgs_list)]
|
||||
|
||||
# try to learn vertical / horizontal
|
||||
refs = refs.query('label1 == 1 | label2 == 1')
|
||||
|
||||
refs = refs.assign(train = pd.Series(np.random.binomial(n = 1, p = 0.6, size = len(refs.index))).values)
|
||||
|
||||
refs.reset_index(inplace = True)
|
||||
|
||||
# initiate tf file writer
|
||||
writer = tf.python_io.TFRecordWriter(train_filename)
|
||||
train_writer = tf.python_io.TFRecordWriter(train_filename)
|
||||
test_writer = tf.python_io.TFRecordWriter(test_filename)
|
||||
|
||||
for i, row in refs.iterrows():
|
||||
img_full_path = os.path.join(imgs_path, row['image'])
|
||||
|
@ -38,9 +46,14 @@ def main():
|
|||
}))
|
||||
|
||||
logger.info("Creation TFRECORDS {}/{} : {}".format(i+1, refs.index.size, img_full_path))
|
||||
writer.write(example.SerializeToString())
|
||||
|
||||
if row['train'] == 1:
|
||||
train_writer.write(example.SerializeToString())
|
||||
else:
|
||||
test_writer.write(example.SerializeToString())
|
||||
|
||||
writer.close()
|
||||
train_writer.close()
|
||||
test_writer.close()
|
||||
|
||||
|
||||
def _int64_feature(value):
|
||||
|
@ -51,7 +64,8 @@ def _bytes_feature(value):
|
|||
|
||||
def loadImg(path):
|
||||
|
||||
raw = np.asarray(Image.open(path)).tostring()
|
||||
raw = np.asarray(Image.open(path)) / 255
|
||||
raw = raw.tostring()
|
||||
|
||||
return path, raw
|
||||
|
||||
|
@ -66,5 +80,6 @@ if __name__ == '__main__':
|
|||
imgs_list = os.listdir(imgs_path)
|
||||
refs_path = os.path.join(str(project_dir), 'data', 'external', 'refs', 'references_labels.csv')
|
||||
train_filename = os.path.join(str(project_dir), 'data', 'processed', 'train.tfrecords')
|
||||
test_filename = os.path.join(str(project_dir), 'data', 'processed', 'test.tfrecords')
|
||||
|
||||
main()
|
||||
|
|
Loading…
Reference in a new issue