Skip to content

Latest commit

 

History

History
30 lines (21 loc) · 1.36 KB

File metadata and controls

30 lines (21 loc) · 1.36 KB

数据读取

AGL 样本构建完成之后,目前数据以 csv 的形式存储,每条样本包含几个 column, 其中一个 column 存储 序列化后的 GraphFeature. 其他 column 可能是样本id, label,样本级别特征等信息。

基于Pytorch, AGL提供两个简单的 Dataset,读取这些csv文件,并构建模型所需的训练/验证/测试集。 假设你构建好的 PPI 训练接名称为 ppi_train.csv:

  • AGLTorchMapBasedDataset (map-style dataset)

     from agl.python.dataset.map_based_dataset import AGLTorchMapBasedDataset
     train_data_set = AGLTorchMapBasedDataset("/your_path_to/ppi_train.csv")
     print(train_data_set[0]) # 查看第一条数据
  • AGLIterableDataset (iterable-stype dataset)

    和上述 AGLTorchMapBasedDataset 使用方法类似,但需要指定batch_size (dataloader 中不要再设置batch_size)

    from agl.python.dataset.iterable_dataset import AGLIterableDataset
    train_data_set = AGLIterableDataset(file="/your_path_to/ppi_train_.csv")
    for data in iterable_dataset:
       print(f"data : {data}")
       break

不失一般性,在后面的样例中,我们都以AGLTorchMapBasedDataset作为读取数据的方式。

next 数据解析