1 change: 1 addition & 0 deletions fedlab_benchmarks/leaf/README_zh_cn.md
@@ -80,6 +80,7 @@ bash preprocess.sh -s niid --sf 0.05 -k 0 -t sample
 cd fedlab_benchmarks/datasets/data/shakespeare
 bash preprocess.sh -s niid --sf 0.2 -k 0 -t sample
 # bash preprocess.sh -s niid --sf 1.0 -k 0 -t sample # get 660 users (with default --tf 0.9)
+# bash preprocess.sh -s niid --sf 1.0 -k 0 -t user # get 1129 users (with default --tf 0.9)
 # bash preprocess.sh -s iid --iu 1.0 --sf 1.0 -k 0 -t sample # get all 1129 users
 
 cd fedlab_benchmarks/datasets/data/sent140
16 changes: 6 additions & 10 deletions fedlab_benchmarks/leaf/dataset/sent140_dataset.py
@@ -44,9 +44,7 @@ def __init__(self, client_id: int, client_str: str, data: list, targets: list,
         self.data_token = []
         self.data_tokens_tensor = []
         self.targets_tensor = []
-        self.vocab = None
         self.tokenizer = tokenizer if tokenizer else Tokenizer()
-        self.fix_len = None
 
         self._process_data_target()
         if is_to_tokens:
@@ -76,16 +74,14 @@ def encode(self, vocab: 'Vocab', fix_len: int):
         if len(self.data_tokens_tensor) > 0:
             self.data_tokens_tensor.clear()
             self.targets_tensor.clear()
-        self.vocab = vocab
-        self.fix_len = fix_len
-        pad_idx = self.vocab.get_index('<pad>')
+        pad_idx = vocab.get_index('<pad>')
         assert self.data_token is not None
         for tokens in self.data_token:
-            self.data_tokens_tensor.append(self.__encode_tokens(tokens, pad_idx))
+            self.data_tokens_tensor.append(self.__encode_tokens(tokens, vocab, pad_idx, fix_len))
         for target in self.targets:
             self.targets_tensor.append(torch.tensor(target))
 
-    def __encode_tokens(self, tokens, pad_idx) -> torch.Tensor:
+    def __encode_tokens(self, tokens, vocab, pad_idx, fix_len) -> torch.Tensor:
         """encode `fix_len` length for token_data to get indices list in `self.vocab`
         if one sentence length is shorter than fix_len, it will use pad word for padding to fix_len
         if one sentence length is longer than fix_len, it will cut the first max_words words
@@ -96,9 +92,9 @@ def __encode_tokens(self, tokens, pad_idx) -> torch.Tensor:
         Returns:
             integer list of indices with `fix_len` length for tokens input
         """
-        x = [pad_idx for _ in range(self.fix_len)]
-        for idx, word in enumerate(tokens[:self.fix_len]):
-            x[idx] = self.vocab.get_index(word)
+        x = [pad_idx for _ in range(fix_len)]
+        for idx, word in enumerate(tokens[:fix_len]):
+            x[idx] = vocab.get_index(word)
         return torch.tensor(x)
 
     def __len__(self):
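Taken together, these edits stop caching `vocab` and `fix_len` on the dataset: both are passed to `encode` at call time and forwarded explicitly to `__encode_tokens`. A minimal sketch of the resulting call pattern, assuming `dataset` is a `Sent140Dataset` for one client and `vocab` is a `Vocab` object with a `'<pad>'` entry and a `get_index()` method, as used in the diff (how those objects are constructed is not shown here and is an assumption, not a verified FedLab interface):

# Hypothetical usage sketch of the refactored encode() API; `dataset` and
# `vocab` are assumed to be built by the surrounding benchmark code.
fix_len = 300  # every tokenized sentence is padded/truncated to this length

# vocab and fix_len are now call-time arguments rather than attributes stored
# on the dataset, so the same dataset can be re-encoded with a different
# vocabulary or sequence length without rebuilding it.
dataset.encode(vocab, fix_len)

first_tokens = dataset.data_tokens_tensor[0]  # LongTensor of shape (fix_len,)
first_label = dataset.targets_tensor[0]       # scalar tensor from torch.tensor(target)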