Hello everybody, I am currently running a Seq2Seq model using the Multi30k dataset from torchtext. In particular, I structured the code to use torch.utils.data.DataLoader as follows:
import torch
import torch.nn.functional as F
from torch import optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.datasets import Multi30k
from torchtext.vocab import build_vocab_from_iterator

# Define tokenizers
de_tokenizer = get_tokenizer('spacy', language='de_core_news_sm')
en_tokenizer = get_tokenizer('spacy', language='en_core_web_sm')

# Define special tokens
specials = ['<unk>', '<pad>', '<bos>', '<eos>']

# Load the Multi30k dataset
train_data, valid_data, test_data = Multi30k()

# Build one vocabulary per language (index 0 = German, index 1 = English)
de_vocab = build_vocab_from_iterator(yield_tokens(train_data, de_tokenizer, index=0),
                                     specials=specials,
                                     min_freq=2)
en_vocab = build_vocab_from_iterator(yield_tokens(train_data, en_tokenizer, index=1),
                                     specials=specials,
                                     min_freq=2)

# Map out-of-vocabulary words to the <unk> token
de_vocab.set_default_index(de_vocab['<unk>'])
en_vocab.set_default_index(en_vocab['<unk>'])

# Get the index of the <pad> token
pad_idx = de_vocab['<pad>']

# Define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_iterator = DataLoader(train_data, batch_size=BATCH_SIZE, collate_fn=collate_fn)
valid_iterator = DataLoader(valid_data, batch_size=BATCH_SIZE, collate_fn=collate_fn)
test_iterator = DataLoader(test_data, batch_size=BATCH_SIZE, collate_fn=collate_fn)

hidden_size = 256

# Create model and define optimizer
encoder = Encoder(len(de_vocab), hidden_size).to(device)
decoder = Decoder(hidden_size, len(en_vocab)).to(device)
model = Seq2Seq(encoder, decoder, device).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(EPOCHS):
    # Training loop
    model.train()
    for batch in train_iterator:
        input_seq, target_seq = batch
        input_seq = input_seq.to(device)
        target_seq = target_seq.to(device)
        output_seq = model(input_seq, target_seq)
        # Compute loss and perform backpropagation
        loss = F.cross_entropy(output_seq.view(-1, len(en_vocab)), target_seq.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Print some information during training
    print(f"Epoch: {epoch + 1}, Loss: {loss.item():.4f}")

    # Validation loop
    model.eval()
    with torch.no_grad():
        total_loss = 0
        num_batches = 0
        for batch in valid_iterator:
            input_seq, target_seq = batch
            input_seq = input_seq.to(device)
            target_seq = target_seq.to(device)
            output_seq = model(input_seq, target_seq)
            val_loss = F.cross_entropy(output_seq.view(-1, len(en_vocab)), target_seq.view(-1))
            total_loss += val_loss.item()
            num_batches += 1
        average_val_loss = total_loss / num_batches
        print(f"Epoch: {epoch + 1}, Validation Loss: {average_val_loss:.4f}")

# Testing loop
model.eval()
with torch.no_grad():
    total_test_loss = 0
    test_num_batches = 0
    for batch in test_iterator:
        input_seq, target_seq = batch  # collate_fn returns a (src, trg) tuple
        input_seq = input_seq.to(device)
        target_seq = target_seq.to(device)
        print(input_seq)
        print(target_seq)
        output_seq = model(input_seq, target_seq)
        test_loss = F.cross_entropy(output_seq.view(-1, len(en_vocab)), target_seq.view(-1))
        total_test_loss += test_loss.item()
        test_num_batches += 1
    average_test_loss = total_test_loss / test_num_batches
    print(f"Test Loss: {average_test_loss:.4f}")
Everything runs perfectly at training and validation time. However, when I run the test step I get an error at for batch in test_iterator:
Traceback (most recent call last):
  File "mymodel.py", line 202, in <module>
    for batch in test_iterator:
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pytorch/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pytorch/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 674, in _next_data
    data = self._dataset_fetcher.fetch(index) # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pytorch/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 32, in fetch
    data.append(next(self.dataset_iter))
                ^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pytorch/lib/python3.11/site-packages/torch/utils/data/datapipes/_hook_iterator.py", line 154, in __next__
    return self._get_next()
           ^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pytorch/lib/python3.11/site-packages/torch/utils/data/datapipes/_hook_iterator.py", line 142, in _get_next
    result = next(self.iterator)
             ^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pytorch/lib/python3.11/site-packages/torch/utils/data/datapipes/_hook_iterator.py", line 226, in wrap_next
    result = next_func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pytorch/lib/python3.11/site-packages/torch/utils/data/datapipes/datapipe.py", line 381, in __next__
    return next(self._datapipe_iter)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pytorch/lib/python3.11/site-packages/torch/utils/data/datapipes/_hook_iterator.py", line 183, in wrap_generator
    response = gen.send(None)
               ^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pytorch/lib/python3.11/site-packages/torch/utils/data/datapipes/iter/sharding.py", line 75, in __iter__
    for i, item in enumerate(self.source_datapipe):
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pytorch/lib/python3.11/site-packages/torch/utils/data/datapipes/_hook_iterator.py", line 183, in wrap_generator
    response = gen.send(None)
               ^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pytorch/lib/python3.11/site-packages/torch/utils/data/datapipes/iter/combinatorics.py", line 124, in __iter__
    yield from self.datapipe
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pytorch/lib/python3.11/site-packages/torch/utils/data/datapipes/_hook_iterator.py", line 183, in wrap_generator
    response = gen.send(None)
               ^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pytorch/lib/python3.11/site-packages/torch/utils/data/datapipes/iter/combining.py", line 624, in __iter__
    yield from zip(*iterators)
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pytorch/lib/python3.11/site-packages/torch/utils/data/datapipes/_hook_iterator.py", line 183, in wrap_generator
    response = gen.send(None)
               ^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pytorch/lib/python3.11/site-packages/torchdata/datapipes/iter/util/plain_text_reader.py", line 138, in __iter__
    yield from self._helper.return_path(stream, path=path) # type: ignore[misc]
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pytorch/lib/python3.11/site-packages/torchdata/datapipes/iter/util/plain_text_reader.py", line 69, in return_path
    yield from stream
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pytorch/lib/python3.11/site-packages/torchdata/datapipes/iter/util/plain_text_reader.py", line 62, in decode
    yield from stream
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pytorch/lib/python3.11/site-packages/torchdata/datapipes/iter/util/plain_text_reader.py", line 54, in strip_newline
    for line in stream:
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pytorch/lib/python3.11/site-packages/torchdata/datapipes/iter/util/plain_text_reader.py", line 45, in skip_lines
    yield from file
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pytorch/lib/python3.11/site-packages/torch/utils/data/datapipes/utils/common.py", line 369, in __iter__
    yield from self.file_obj
  File "<frozen codecs>", line 322, in decode
UnicodeDecodeError: 'utf-8' codec can't decode byte 0x80 in position 37: invalid start byte
I am using torch version 2.1.1 and torchtext version 0.16.1.
I tried to understand where the error comes from, applying some modifications and try...except blocks to my helper functions, but no luck yet. These are the functions I am using for yielding tokens and collating:
from typing import Iterable, Iterator, List

def yield_tokens(data_iter: Iterable, tokenizer, index: int) -> Iterator[List[str]]:
    # index selects the sentence in each (de, en) pair: 0 = German, 1 = English
    for sample in data_iter:
        try:
            yield tokenizer(sample[index])
        except UnicodeDecodeError:
            yield []  # or some other default value

def collate_fn(batch):
    de_batch, en_batch = [], []
    for (de_item, en_item) in batch:
        de_batch.append(torch.tensor([de_vocab[token] for token in de_tokenizer(de_item)],
                                     dtype=torch.long))
        en_batch.append(torch.tensor([en_vocab[token] for token in en_tokenizer(en_item)],
                                     dtype=torch.long))
    de_batch = pad_sequence(de_batch, padding_value=pad_idx)
    en_batch = pad_sequence(en_batch, padding_value=pad_idx)
    return de_batch, en_batch
Thanks very much for your help 
It seems that your dataset contains bytes that are not valid UTF-8. Are you able to inspect the file contents in detail?
File "/opt/homebrew/Caskroom/miniconda/base/envs/pytorch/lib/python3.11/site-packages/torchdata/datapipes/iter/util/plain_text_reader.py", line 138, in __iter__
yield from self._helper.return_path(stream, path=path)
You can navigate to the above source file and set a breakpoint()
before that statement and check the file path
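Once you have the path, you can also locate the offending bytes directly. Here is a minimal sketch; the path below is a placeholder, so substitute whatever the breakpoint reports:
# Scan a file for the first byte offset that fails UTF-8 decoding.
# The path is hypothetical; use the one found at the breakpoint.
path = ".data/datasets/Multi30k/test.de"

with open(path, "rb") as f:
    raw = f.read()

try:
    raw.decode("utf-8")
    print("File decodes cleanly as UTF-8")
except UnicodeDecodeError as e:
    print(f"Bad byte 0x{raw[e.start]:02x} at offset {e.start}: {e.reason}")
    print(f"Context: {raw[max(0, e.start - 20):e.start + 20]!r}")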
It seems the 'test' dataset from Multi30k is corrupted. Multi30k uses the following URLs to retrieve the data:
# TODO: Update URL to original once the server is back up (see https://github.com/pytorch/text/issues/1756)
URL = {
    "train": r"https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz",
    "valid": r"https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz",
    "test": r"https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/mmt16_task1_test.tar.gz",
}
Now, if we look at the test archive mmt16_task1_test.tar.gz, in addition to the expected test.de, test.en, and test.fr it strangely contains 3 extra files: ._test.de, ._test.en, and ._test.fr:
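You can reproduce this listing yourself with tarfile (assuming the archive has been downloaded to the current directory; the member order may differ):
import tarfile

# List every member of the test archive, including the stray metadata entries
with tarfile.open("mmt16_task1_test.tar.gz", "r:gz") as tar:
    for name in tar.getnames():
        print(name)
# The output includes both 'test.de' and '._test.de', and so on for .en and .fr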
The extra ._test.de, ._test.en, and ._test.fr files seem to be Apple metadata; however, the Multi30k code picks them up instead of the desired files because the _filter_fn function just searches for the file name as a substring of the archive member path:
def _filter_fn(split, language_pair, i, x):
    return f"{_PREFIX[split]}.{language_pair[i]}" in x[0]
There are 2 simple fixes:
Fix option 1) Find mmt16_task1_test.tar.gz in the pytorch cache folder (by default ~/.data/datasets/Multi30k) and manually extract the correct files.
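For example, something along these lines should work; the cache path is an assumption, so adjust it to wherever the archive landed on your machine:
import os
import tarfile

# Hypothetical cache location; adjust to your setup
cache_dir = os.path.expanduser("~/.data/datasets/Multi30k")
archive = os.path.join(cache_dir, "mmt16_task1_test.tar.gz")

with tarfile.open(archive, "r:gz") as tar:
    for member in tar.getmembers():
        # Skip the AppleDouble "._*" metadata entries, keep the real files
        if not os.path.basename(member.name).startswith("._"):
            tar.extract(member, path=cache_dir)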
Fix option 2) Patch the multi30k._filter_fn function to find the correct file:
import torchtext.datasets.multi30k

def _filter_fn(split, language_pair, i, x):
    return f"/{torchtext.datasets.multi30k._PREFIX[split]}.{language_pair[i]}" in x[0]

torchtext.datasets.multi30k._filter_fn = _filter_fn
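One caveat: apply the monkey patch before calling Multi30k(), since the filter appears to be captured when the datapipe graph is built. A minimal sketch of the ordering:
import torchtext.datasets.multi30k
from torchtext.datasets import Multi30k

# 1. Install the patched _filter_fn (defined above) first
torchtext.datasets.multi30k._filter_fn = _filter_fn

# 2. Only then construct the datasets, so they pick up the patched filter
train_data, valid_data, test_data = Multi30k()
If the corrupted test files were already extracted into the cache, you may also need to delete them so they get re-extracted with the patched filter.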
Someone might want to raise this with the data owner so the original file gets fixed.