mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-23 09:29:14 +08:00
tutorials are added
tutorials/09 - Language Model/data_utils.py (normal file, 46 lines added)
@@ -0,0 +1,46 @@
import torch
import os


# Word <-> index mapping, built incrementally as words are added.
class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        # Assign the next free index to any word not seen before.
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __len__(self):
        return len(self.word2idx)


# Reads a text file, builds the vocabulary, and returns the token ids
# reshaped into batches.
class Corpus(object):
    def __init__(self, path='./data'):
        self.dictionary = Dictionary()
        self.train = os.path.join(path, 'train.txt')
        self.test = os.path.join(path, 'test.txt')

    def get_data(self, path, batch_size=20):
        # Add words to the dictionary
        with open(path, 'r') as f:
            tokens = 0
            for line in f:
                words = line.split() + ['<eos>']
                tokens += len(words)
                for word in words:
                    self.dictionary.add_word(word)

        # Tokenize the file content
        ids = torch.LongTensor(tokens)
        token = 0
        with open(path, 'r') as f:
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    ids[token] = self.dictionary.word2idx[word]
                    token += 1

        # Drop the trailing remainder so the tokens split evenly, then
        # reshape to (batch_size, num_batches).
        num_batches = ids.size(0) // batch_size
        ids = ids[:num_batches*batch_size]
        return ids.view(batch_size, -1)
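For context, a minimal usage sketch of the Corpus helper added above; the data path, batch_size, and seq_length values here are illustrative assumptions, not part of this commit:

# Hypothetical usage; './data/train.txt', batch_size, and seq_length
# are assumed values, not defined by the commit itself.
from data_utils import Corpus

corpus = Corpus(path='./data')
ids = corpus.get_data('./data/train.txt', batch_size=20)  # shape: (20, num_batches)
vocab_size = len(corpus.dictionary)

# For language modeling, targets are the inputs shifted by one token.
seq_length = 30
inputs = ids[:, 0:seq_length]
targets = ids[:, 1:seq_length + 1]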