# Copyright (c) OpenMMLab. All rights reserved.
from torch.utils.data import random_split


def split_dataset(dataset, generator=None):
    """Randomly split a dataset into two disjoint halves.

    The first subset receives ``len(dataset) // 2`` samples and the second
    receives the remainder, so for odd lengths the second subset is one
    sample larger. Both subsets are given the ``CLASSES`` attribute of the
    source dataset.

    Args:
        dataset: A map-style dataset supporting ``len()`` and indexing that
            exposes a ``CLASSES`` attribute (an ``AttributeError`` is raised
            otherwise).
        generator (torch.Generator, optional): Generator used for the random
            permutation, allowing a reproducible split. If ``None`` (the
            default), ``random_split`` is called without it, exactly as
            before, and torch's default random state is used.

    Returns:
        list: ``[first_subset, second_subset]``, each a
        ``torch.utils.data.Subset`` of ``dataset``.
    """
    total = len(dataset)
    first_len = total // 2
    lengths = (first_len, total - first_len)
    # Only forward ``generator`` when one is given, so existing callers (and
    # torch versions whose random_split lacks the argument) see no change.
    if generator is None:
        first_dset, second_dset = random_split(dataset, lengths)
    else:
        first_dset, second_dset = random_split(
            dataset, lengths, generator=generator)
    # Subset does not forward attributes of the wrapped dataset, so copy the
    # class names explicitly for downstream code that reads ``.CLASSES``.
    first_dset.CLASSES = dataset.CLASSES
    second_dset.CLASSES = dataset.CLASSES
    return [first_dset, second_dset]