大数跨境
0
0

Pytorch 数据流中常见Trick总结

Pytorch 数据流中常见Trick总结 极市平台
2021-12-07
0
导读:Pytorch建模过程中的思考
↑ 点击蓝字 关注极市平台

作者丨zlhroughlove@知乎(已授权)
来源丨https://zhuanlan.zhihu.com/p/441317369
编辑丨极市平台

极市导读

 

本文对使用Pytorch建模过程中的一些思考以及遇到的问题做了一个总结,希望能给各位读者一个比较通用的模板~ >>加入极市CV技术交流群,走在计算机视觉的最前沿

前言

在使用Pytorch建模时,常见的流程为先写Model,再写Dataset,最后写Trainer。Dataset 是整个项目开发中投入时间第二多,也是中间关键的步骤。往往需要事先对于其设计有明确的思考,不然可能会因为Dataset的一些问题又要去调整Model,Trainer。本文将目前开发中的一些思考以及遇到的问题做一个总结,提供给各位读者一个比较通用的模版,抛砖引玉~

一、Dataset的定义

from torch.utils.data import Dataset, DataLoader, RandomSampler

对于不同类型的建模任务,模型的输入各不相同。自然语言,多模态,点击率预估,往往这些场景输入模型的数据并不是来自于单一文件,而且可能无法全部存入内存。Dataset需要整合项目的数据,对于单条样本涉及到的数据做一个提取与归纳。不但如此,项目可能还涉及到多种模型,任务的训练。Dataset需要为不同的模型以及训练任务提供不同的单条样本输入,作为一个数据生成器,把后续模型训练任务需要的所有基础数据,标签全返回了。所以往往我们可以定义一个BaseDataset类,继承torch.utils.data.Dataset,这个类可以初始化一些文件路径,配置等。后面不同的模型训练任务定义相应的Dataset类继承BaseDataset。

Dataset通用的结构为:

class BaseDataset(Dataset):

    def __init__(self, config):
        self.config = config
        if os.path.isfile(config.file_path) is False:
            raise ValueError(f"Input file path {config.file_path} not found")
        logger.info(f"Creating features from dataset file at {config.file_path}")
        # 一次性全读进内存
        self.data = joblib.load(config.file_path)
        self.nums = len(self.data)

    def __len__(self):
        return self.nums

    def __getitem__(self, i) -> Dict[str, tensor]:
        sample_i = self.data[i]
        return {"f1":torch.tensor(sample_i["f1"]).long(),"f2":torch.tensor(sample_i["f2"]).long(),torch.LongTensor([sample_i["label"]])}

如果无法全部读取进内存需要再__getitem__方法内构建数据,做自然语言则可以吧tokenizer初始化到该类中,在__getitem__方法内完成tokenizer。改方法的输出推荐做成字典形式。

对于不同的训练任务可以通过以下方法返回响应的数据生成器

def build_dataset(task_type, features, **kwargs):
    assert task_type in ['task1''task2'], 'task mismatch'

    if task_type == 'task1':
        dataset = task1Dataset(features))
    else:
        dataset = task2Dataset(features)

    return dataset

有时模型的训练任务需要做数据增强,对比学习,构造多种的预训练任务输入。Dataset的职能边界是提供一套基础的单样本数据输入生成器。如果是MLM任务,可以在Dataset内生成maskposition以及label。如果是在batch内的对比学习则应该在DataLoader生产batch数据后再进行。

二、DataLoader的定义

DataLoader的作用是对Dataset进行多进程高效地构建每个训练批次的数据。传入的数据可以认为是长度为batch大小的多个__getitem__ 方法返回的字典list。DataLoader的职能边界是根据Dataset提供的单条样本数据有选择的构建一个batch的模型输入数据。

其通常的结构为对Train,Valid,Test分别建立:

train_sampler = RandomSampler(train_dataset)
train_loader = DataLoader(dataset=train_dataset,
                              batch_size=args.train_batch_size,
                              sampler=train_sampler,
                              shuffle=(train_sampler is None)
                              collate_fn=None, # 一般不用设置
                              num_workers=4)

首先对于sampler 还有一种定义方式:

sampler = torch.utils.data.distributed.DistributedSampler(dataset)

至于batch内数据是否需要做shuffle也需要根据损失函数确定(对比学习慎用)

DataLoader会自动合并__getitem__ 方法返回的字典内每个key内每个tensor,在tensor的第0维度新增一个batch大小的维度。如果该方法返回的每条样本长度不同无法拼接,batchsize>1就会报错。但是又一些任务在还没有确定后续的批样本对应的任务时,Dataset可能返回的字典里每个key可能就是长度不同的tensor,甚至是list,这时候需要使用collate_fn参数告诉DataLoader如何取样。我们可以定义自己的函数来准确地实现想要的功能。

如果__getitem__方法返回的是tuple((list, list)) 可以使用:

def merge_sample(x):
    return zip(*x)

train_loader = DataLoader(dataset=train_dataset,
                              batch_size=args.train_batch_size,
                              sampler=train_sampler,
                              shuffle=(train_sampler is None)
                              collate_fn=merge_sample,
                              num_workers=4)

拼接数据,后续再做进一步处理。(此时list内数据还是不等长,无法转为tensor)

如果__getitem_方法返回的是Dict[str,tensor],自定义的collate_fn方法内需要实现:List[Dict[str,tensor(xx)]]->Dict[str,tensor(bs,xx)]的操作,pad_sequence过程也可以在自定义方法内实现。(总之collate_fn中不但可以处理不等长数据,还可以对一个batch的数据做精修。当然也可以在DataLoader之后再做修改batch内的数据。)

值得注意的是在cpu环境下,如果要自定义collate_fn,num_workers必须设置为0,不然就会有问题..

通过以下方式可以检查一下输入后续模型的数据是否已经是想要的格式

for step, batch_data in enumerate(train_loader):
    if step < 1:
        print(batch_data)
    else:
        break

之后数据将数据放入gpu device, 一个batch的数据进入device端后就与内存上的数据不再互相干扰。之后数据就可以喂给模型了:

for key in batch_data.keys():
    batch_data[key] = batch_data[key].to(device)
loss = model(**batch_data)

如果觉得有用,就请分享到朋友圈吧!

△点击卡片关注极市平台,获取最新CV干货

公众号后台回复“transformer”获取最新Transformer综述论文下载~


极市干货
课程/比赛:珠港澳人工智能算法大赛保姆级零基础人工智能教程
算法trick目标检测比赛中的tricks集锦从39个kaggle竞赛中总结出来的图像分割的Tips和Tricks
技术综述:一文弄懂各种loss function工业图像异常检测最新研究总结(2019-2020)


CV技术社群邀请函 #

△长按添加极市小助手
添加极市小助手微信(ID : cvmart4)

备注:姓名-学校/公司-研究方向-城市(如:小极-北大-目标检测-深圳)


即可申请加入极市目标检测/图像分割/工业检测/人脸/医学影像/3D/SLAM/自动驾驶/超分辨率/姿态估计/ReID/GAN/图像增强/OCR/视频理解等技术交流群


每月大咖直播分享、真实项目需求对接、求职内推、算法竞赛、干货资讯汇总、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企视觉开发者互动交流~



觉得有用麻烦给个在看啦~  
【声明】内容源于网络
0
0
极市平台
为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
内容 8155
粉丝 0
极市平台 为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
总阅读8.7k
粉丝0
内容8.2k