One-hot text classification
1. Read document data
// A highlighted block
# Load the raw training data: one sample per line of train.txt.
train_path = os.path.join("..", "data", "train" + ".txt")
with open(train_path, encoding="utf-8") as f:
    all_data = f.read().split("\n")
print('\n', all_data)
2. Split data
// A highlighted block
# Re-read the training file and split each tab-separated line into
# its text and label parts.
with open(os.path.join("..", "data", "train" + ".txt"), encoding="utf-8") as f:
    all_data = f.read().split("\n")

texts = []
labels = []
for data in all_data:
    if data:  # skip empty lines (e.g. the trailing newline)
        t, l = data.split("\t")
        texts.append(t)
        labels.append(l)
print(texts, labels)  # fixed: the original line was missing the closing ")"
Result:
3. Build word2idx and idx2onehot
def built_curpus(train_texts):
    """Build a character vocabulary and its one-hot lookup matrix.

    Args:
        train_texts: iterable of strings; every character becomes a token.

    Returns:
        (word_2_index, index_2_onehot): a dict mapping token -> index with
        "<PAD>" = 0 and "<UNK>" = 1 reserved, and a float32 identity matrix
        whose row i is the one-hot vector for index i.
    """
    word_2_index = {"<PAD>": 0, "<UNK>": 1}
    for text in train_texts:
        for word in text:
            # Assign the next free index the first time a character is seen.
            word_2_index.setdefault(word, len(word_2_index))
    # Removed leftover debug print of the full vocabulary.
    return word_2_index, np.eye(len(word_2_index), dtype=np.float32)
Result:
4. Build the dataset
class OhDataset(Dataset):
    """Dataset yielding (one-hot encoded text, int label) pairs.

    Each text is truncated to max_len characters, characters are mapped to
    indices via word_2_index (unknown -> 1, padding -> 0), and indices are
    turned into one-hot rows of index_2_onehot.
    """

    def __init__(self, texts, labels, word_2_index, index_2_onehot, max_len):
        self.texts = texts
        self.labels = labels
        self.word_2_index = word_2_index
        self.index_2_onehot = index_2_onehot
        self.max_len = max_len

    def __getitem__(self, index):
        # 1. Fetch the raw sample.
        text = self.texts[index]
        label = int(self.labels[index])
        # 2. Truncate to at most max_len characters.
        text = text[:self.max_len]
        # 3. text -> indices; unknown characters map to <UNK> (index 1).
        #    Fixed: use self.word_2_index — the original referenced the
        #    module-level global instead of the vocab given to __init__.
        text_index = [self.word_2_index.get(ch, 1) for ch in text]
        # Pad with <PAD> (index 0) up to max_len.
        text_index = text_index + [0] * (self.max_len - len(text_index))
        # 4. indices -> one-hot matrix of shape (max_len, vocab_size).
        text_onehot = self.index_2_onehot[text_index]
        return text_onehot, label

    def __len__(self):
        return len(self.labels)
# Wrap the held-out split in a Dataset and a non-shuffled DataLoader.
test_dataset = OhDataset(
    test_texts, test_labels, word_2_index, index_2_onehot, max_len
)
test_dataloader = DataLoader(test_dataset, batch_size=10, shuffle=False)
5. Build the model
class OhModel(nn.Module):
    """Two-layer classifier over one-hot encoded text.

    forward() returns the cross-entropy loss when labels are given (and
    None otherwise); argmax predictions are stored in self.pre either way.
    """

    def __init__(self, curpus_len, hidden_num, class_num, max_len):
        super().__init__()
        self.linear1 = nn.Linear(curpus_len, hidden_num)
        self.active = nn.ReLU()
        self.flatten = nn.Flatten()
        # After flattening: (batch, max_len * hidden_num).
        self.linear2 = nn.Linear(max_len * hidden_num, class_num)
        self.cross_loss = nn.CrossEntropyLoss()

    def forward(self, text_onehot, labels=None):
        # text_onehot: assumed (batch, max_len, curpus_len) float tensor
        # — matches what OhDataset batches produce; confirm for other callers.
        hidden = self.linear1(text_onehot)
        hidden_act = self.active(hidden)
        hidden_f = self.flatten(hidden_act)
        p = self.linear2(hidden_f)
        # Store class predictions as a plain list for evaluation code.
        self.pre = torch.argmax(p, dim=-1).detach().cpu().numpy().tolist()
        if labels is not None:
            # Fixed: the original referenced the undefined name `lables`,
            # which raised NameError on every training step.
            loss = self.cross_loss(p, labels)
            return loss