ALBERT-Based Internet Sentence 10-Class Text Classification Project

Overall project introduction:

This project fine-tunes a pre-trained ALBERT model with TensorFlow 1.x to classify internet sentences into 10 categories. On top of ALBERT's pooled output, the model builds one small binary classification head per category and averages the per-head losses.
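The model code below pulls every setting from a Hyperparamters class (imported as hp). As a hedged sketch only, here are the fields that code references, with illustrative values that are not taken from the original project:

# A sketch of the assumed Hyperparamters fields; all values below are
# illustrative placeholders, not the project's actual settings.
class Hyperparamters:
    num_labels = 10                    # number of categories
    sequence_length = 128              # max token length per sentence
    batch_size = 32
    num_train_epochs = 3
    warmup_proportion = 0.1            # fraction of steps used for LR warmup
    learning_rate = 5e-5
    use_tpu = False
    bert_path = './albert_base_zh'     # directory of the pre-trained ALBERT
    init_checkpoint = './albert_base_zh/albert_model.ckpt'
    saved_model_path = './model'       # where fine-tuned checkpoints are saved
    data_dir = './data'                # directory with the training data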
Project data introduction:

Training examples are loaded from hp.data_dir by ClassifyProcessor; each example is an internet sentence paired with its category annotation.

label:

Each sentence's label is encoded as a 0/1 vector of length hp.num_labels (10), matching the label_ids placeholder in the model below.
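The exact file format is in the download link at the end; what the model itself expects, judging from the label_ids placeholder (shape [batch, num_labels] with 0/1 entries), is one binary indicator per category. A minimal sketch of that encoding, assuming each sentence belongs to exactly one of the 10 classes:

# An illustrative sketch (not from the original project) of the label encoding
# implied by the label_ids placeholder.
import numpy as np

NUM_LABELS = 10  # hp.num_labels in the project

def encode_label(class_index):
    # 0/1 vector of length NUM_LABELS; entry i is 1 iff the sentence belongs
    # to class i. The model trains one binary head per entry.
    vec = np.zeros(NUM_LABELS, dtype=np.int32)
    vec[class_index] = 1
    return vec

print(encode_label(3))  # -> [0 0 0 1 0 0 0 0 0 0]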
Main model code:

# -*- coding: utf-8 -*-


import os
import tensorflow as tf
import modeling
import optimization
from utils import time_now_string
from hyperparameters import Hyperparamters as hp
from classifier_utils import ClassifyProcessor



num_labels = hp.num_labels
processor = ClassifyProcessor()
# Load the ALBERT configuration from the pre-trained model directory
bert_config_file = os.path.join(hp.bert_path, 'albert_config.json')
bert_config = modeling.AlbertConfig.from_json_file(bert_config_file)



class NetworkAlbert(object):
    def __init__(self, is_training):
        # Training or not
        self.is_training = is_training

        # Placeholders for token ids, attention masks, segment ids and labels
        self.input_ids = tf.placeholder(tf.int32, shape=[None, hp.sequence_length], name='input_ids')
        self.input_masks = tf.placeholder(tf.int32, shape=[None, hp.sequence_length], name='input_masks')
        self.segment_ids = tf.placeholder(tf.int32, shape=[None, hp.sequence_length], name='segment_ids')
        self.label_ids = tf.placeholder(tf.int32, shape=[None, hp.num_labels], name='label_ids')
               
        # Load the ALBERT model
        self.model = modeling.AlbertModel(
            config=bert_config,
            is_training=self.is_training,
            input_ids=self.input_ids,
            input_mask=self.input_masks,
            token_type_ids=self.segment_ids,
            use_one_hot_embeddings=False)


        # Get the pooled [CLS] feature vector from ALBERT
        output_layer = self.model.get_pooled_output()
        print('output_layer:', output_layer)  # shape: (?, 384)

        # Hidden size of the pooled output
        hidden_size = output_layer.shape[-1].value
           
        with tf.name_scope("Full-connection"):
            # One binary classification head per label: each head predicts
            # whether the sentence belongs to that category
            loss_num_label = []
            logits_num_label = []
            for i in range(hp.num_labels):
                output_weights = tf.get_variable(
                    "output_weights%s" % str(i), [2, hidden_size],
                    initializer=tf.truncated_normal_initializer(stddev=0.02))
                output_bias = tf.get_variable(
                    "output_bias%s" % str(i), [2], initializer=tf.zeros_initializer())
                logits = tf.matmul(output_layer, output_weights, transpose_b=True)
                logits = tf.nn.bias_add(logits, output_bias)
                logits_num_label.append(logits)
                # Labels must be float for softmax_cross_entropy_with_logits
                one_hot_labels = tf.one_hot(self.label_ids[:, i], depth=2, dtype=tf.float32)
                per_example_loss = tf.nn.softmax_cross_entropy_with_logits(labels=one_hot_labels, logits=logits)
                loss_num_label.append(tf.reduce_mean(per_example_loss))
            # Stack per-label logits into shape [batch, num_labels, 2]
            self.logits_num_label = tf.transpose(tf.stack(logits_num_label, 0), [1, 0, 2])
            self.loss_num_label = tf.stack(loss_num_label, 0)
            self.probabilities = tf.nn.sigmoid(self.logits_num_label)


        with tf.variable_scope("Prediction"):
            # Predicted class (0 or 1) from each label head
            self.predictions = tf.to_int32(tf.argmax(self.probabilities, 2))
   
        with tf.variable_scope("loss"):
            # Accuracy across all label heads, logged for TensorBoard
            if self.is_training:
                self.accuracy = tf.reduce_mean(tf.to_float(tf.equal(self.predictions, self.label_ids)))
                tf.summary.scalar('accuracy', self.accuracy)

            # Restore from a saved fine-tuned checkpoint if one exists;
            # otherwise initialize from the pre-trained ALBERT checkpoint
            ckpt = tf.train.get_checkpoint_state(hp.saved_model_path)
            checkpoint_suffix = ".index"
            if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path + checkpoint_suffix):
                print('=' * 10, 'Restoring model from checkpoint!', '=' * 10)
                print("%s - Restoring model from checkpoint ~%s" % (time_now_string(),
                                                                    ckpt.model_checkpoint_path))
            else:
                print('=' * 10, 'Loading the pre-trained ALBERT model for the first time!', '=' * 10)
                tvars = tf.trainable_variables()
                if hp.init_checkpoint:
                    (assignment_map, initialized_variable_names) = \
                        modeling.get_assignment_map_from_checkpoint(tvars,
                                                                    hp.init_checkpoint)
                    tf.train.init_from_checkpoint(hp.init_checkpoint, assignment_map)
                                
            # Loss and optimizer
            if self.is_training:
                # Global step
                self.global_step = tf.Variable(0, name='global_step', trainable=False)
                self.loss = tf.reduce_mean(self.loss_num_label)

                # AdamW optimizer with linear warmup, as in BERT/ALBERT
                train_examples = processor.get_train_examples(hp.data_dir)
                num_train_steps = int(
                    len(train_examples) / hp.batch_size * hp.num_train_epochs)
                num_warmup_steps = int(num_train_steps * hp.warmup_proportion)
                print('num_train_steps:', num_train_steps)
                self.optimizer = optimization.create_optimizer(self.loss,
                                                               hp.learning_rate,
                                                               num_train_steps,
                                                               num_warmup_steps,
                                                               hp.use_tpu,
                                                               Global_step=self.global_step)

                # Summary for TensorBoard
                tf.summary.scalar('loss', self.loss)
                self.merged = tf.summary.merge_all()
                
                
if __name__ == '__main__':
    # Build the classification graph in training mode
    albert = NetworkAlbert(is_training=True)
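
A minimal training-step sketch could be appended after the line above (an illustrative addition, not part of the original project). The zero-filled batch only demonstrates the shapes the placeholders expect; real batches would come from ClassifyProcessor and the project's utils:

    import numpy as np

    def train_one_step(sess, model, input_ids, input_masks, segment_ids, label_ids):
        # Feed one batch and run a single optimizer update, returning the loss
        feed_dict = {model.input_ids: input_ids,
                     model.input_masks: input_masks,
                     model.segment_ids: segment_ids,
                     model.label_ids: label_ids}
        _, loss = sess.run([model.optimizer, model.loss], feed_dict=feed_dict)
        return loss

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Dummy zero batch of size 8, just to show the expected shapes
        zeros = lambda cols: np.zeros([8, cols], dtype=np.int32)
        loss = train_one_step(sess, albert,
                              zeros(hp.sequence_length),
                              zeros(hp.sequence_length),
                              zeros(hp.sequence_length),
                              zeros(hp.num_labels))
        print('loss:', loss)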

Complete data and code: https://download.csdn.net/download/weixin_55771290/87655042

Origin: blog.csdn.net/pythonyanyan/article/details/129975918