博客
关于我
ACNE04 读数据的过程
阅读量:798 次
发布时间:2023-04-17

本文共 1539 字,大约阅读时间需要 5 分钟。

交叉验证设置

通过交叉验证来评估模型性能,分别设置5个验证指标为'0','1','2','3','4'。对于每个验证指标,执行以下操作:

for cross_val_index in cross_val_lists:
log.write('\n\ncross_val_index: ' + cross_val_index + '\n\n')
if True:
trainval_test(cross_val_index, sigma=30 * 0.1, lam=6 * 0.1)

数据集文件路径

训练集和测试集的文件路径分别为:

TRAIN_FILE = './Classification/NNEW_trainval_' + cross_val_index + '.txt'
TEST_FILE = './Classification/NNEW_test_' + cross_val_index + '.txt'

数据集处理

对训练集和测试集分别进行数据增强处理,训练集采用随机裁剪、水平翻转等变换,测试集则仅进行缩放和标准化处理。

dset_train = dataset_processing.DatasetProcessing(
DATA_PATH, TRAIN_FILE, transform=transforms.Compose([
transforms.Scale((256, 256)),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
RandomRotate(rotation_range=20),
normalize,
]))
dset_test = dataset_processing.DatasetProcessing(
DATA_PATH, TEST_FILE, transform=transforms.Compose([
transforms.Scale((224, 224)),
transforms.ToTensor(),
normalize,
]))

批量数据加载

使用DataLoader封装训练集和测试集,分别设置不同的批次大小,训练集采用随机洗牌,测试集则不洗牌。

train_loader = DataLoader(dset_train, 
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=NUM_WORKERS,
pin_memory=False)
test_loader = DataLoader(dset_test,
batch_size=BATCH_SIZE_TEST,
shuffle=False,
num_workers=NUM_WORKERS,
pin_memory=False)

训练流程

每个epoch遍历所有批次,完成一次完整的训练循环。

转载地址:http://nvgfk.baihongyu.com/

你可能感兴趣的文章