BERT model and code

Reference: https://blog.csdn.net/IT__learning/article/details/120741368

# BERT model
# from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer
from transformers import BertModel, BertTokenizer, BertConfig
import torch

# Load the pretrained model
model_name = 'hfl/chinese-roberta-wwm-ext'
config = BertConfig.from_pretrained(model_name)  # automatically downloads the model configuration from the official model hub (the location is preconfigured in the library)
tokenizer = BertTokenizer.from_pretrained(model_name)  # automatically downloads the checkpoint's vocab.txt file and reads the vocabulary from it
model = BertModel.from_pretrained(model_name)  # automatically downloads the model weights
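# Quick sanity checks on what was loaded (a minimal sketch; the expected values
# assume the 'hfl/chinese-roberta-wwm-ext' checkpoint, which follows the
# BERT-base shape and the standard Chinese BERT vocabulary):
print(config.hidden_size)        # 768
print(config.num_hidden_layers)  # 12
print(tokenizer.vocab_size)      # 21128 for this Chinese vocab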

'''
sa, sb = "我爱武汉!我爱中国!",  "锤子?"
# tokenizer.tokenize only splits into tokens; tokenizer.encode splits and converts to ids
input_id = tokenizer.encode(sa)
input_id = torch.tensor([input_id])  # the model expects tensors in batch form
sequence_output, pooled_output = model(input_id, return_dict=False)  # output shapes are [1, 12, 768] and [1, 768]
# Besides the input ids, BertModel also needs token_type_ids carrying the
# sentence-pair information, and an attention_mask that masks out the PAD positions.
# print("input_id: ", input_id)  # tensor([[ 101, 2769, 4263, 3636, 3727, 8013, 2769, 4263,  704, 1744, 8013,  102]]); the leading 101 ([CLS]) and trailing 102 ([SEP]) are always added
# print("sequence_output: ", sequence_output)  # last_hidden_state
# print("pooled_output: ", pooled_output)  # pooler_output

inputs = tokenizer.encode_plus(sa, text_pair=sb, return_tensors="pt")  # other common optional arguments: max_length, padding (formerly pad_to_max_length), etc.
print(inputs.keys())  # returns a dict holding the ids and masks
# dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
# input_ids: each token's index in the vocabulary
# token_type_ids: distinguishes the two sentences
# attention_mask: marks which positions take part in self-attention, masking PAD out (see the padding sketch right after this block)
sequence_output, pooled_output = model(**inputs, return_dict=False)
'''
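# A minimal sketch of how attention_mask marks PAD positions (it assumes the same
# tokenizer; max_length/padding/truncation are standard encode_plus options):
padded = tokenizer.encode_plus("我爱你", max_length=8, padding="max_length", truncation=True)
print(padded["input_ids"])       # [101, 2769, 4263, 872, 102, 0, 0, 0]; the trailing 0s are [PAD]
print(padded["attention_mask"])  # [1, 1, 1, 1, 1, 0, 0, 0]; 0 masks the PAD positions out of self-attention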

# Use the tokenizer to encode text
# encode returns only the input_ids
ret = tokenizer.encode("我爱你")
# print(ret)  # [101, 2769, 4263, 872, 102]
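# decode is the inverse of encode and makes the added special tokens visible
# (a quick sketch; the spacing comes from WordPiece detokenization):
print(tokenizer.decode(ret))  # [CLS] 我 爱 你 [SEP]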

# encode_plus returns the full set of encodings
input_id = tokenizer.encode_plus("我爱你", "你也爱我")  # the public encode_plus, not the private _encode_plus
print("input_id: ", input_id)
# input_id:
# {'input_ids': [101, 2769, 4263, 872, 102, 872, 738, 4263, 2769, 102],
# 'token_type_ids': [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
# 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
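# In recent transformers versions the tokenizer is itself callable and can return
# batched tensors in one step (a sketch of the equivalent __call__ API; `encoded`
# is just an illustrative name):
encoded = tokenizer("我爱你", "你也爱我", return_tensors="pt")
print(encoded["input_ids"].shape)  # torch.Size([1, 10]) (batch dimension already added)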

# Feed the tokenization result into the model to get the encodings
# Convert to tensors; the batch dimension is added further below with unsqueeze
input_ids = torch.tensor(input_id["input_ids"])
token_type_ids = torch.tensor(input_id["token_type_ids"])
attention_mask_ids = torch.tensor(input_id["attention_mask"])

# Put the model into eval mode. Why eval mode?
model.eval()
# In train mode, a BatchNorm layer keeps updating its running mean and variance from
# every batch it sees; this happens during the forward pass, independently of
# backpropagation and the optimizer step. Those two statistics estimate the per-feature
# mean and variance of the data distribution, and BatchNorm uses them to standardize
# the features. After training they describe the training-set distribution; if you skip
# eval mode at test time they keep updating, so the same input gives a different result
# on every run. eval mode freezes them and makes inference stable.
# (Strictly speaking BERT contains no BatchNorm; it uses LayerNorm, which eval does not
# change. It does contain Dropout, which eval likewise disables, and that is what
# matters here; a small demonstration follows the tensor setup below.)
device = "cpu"
tokens_tensor = input_ids.to(device).unsqueeze(0)
segments_tensors = token_type_ids.to(device).unsqueeze(0)
attention_mask_ids_tensors = attention_mask_ids.to(device).unsqueeze(0)
# Move the model onto the target device as well
model.to(device)
# torch.unsqueeze() expands a tensor's dimensions:
# unsqueeze(0) on a shape-[6] tensor adds a new axis at dim 0, giving shape [1, 6]
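# e.g. a quick check of the shape change:
print(torch.tensor([1, 2, 3]).unsqueeze(0).shape)  # torch.Size([1, 3])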
print(input_ids)  # tensor([ 101, 2769, 4263,  872,  102,  872,  738, 4263, 2769,  102])
print(tokens_tensor)  # tensor([[ 101, 2769, 4263,  872,  102,  872,  738, 4263, 2769,  102]])
print(segments_tensors)  # tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1]])
print(attention_mask_ids)  # tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
print(attention_mask_ids_tensors)  # tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
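# A small demonstration of why eval mode matters (a sketch reusing the tensors
# built above): in train mode Dropout randomizes the output, so two identical
# forward passes differ; in eval mode they match exactly.
with torch.no_grad():
    model.train()
    out_a = model(tokens_tensor).last_hidden_state
    out_b = model(tokens_tensor).last_hidden_state
    model.eval()
    out_c = model(tokens_tensor).last_hidden_state
    out_d = model(tokens_tensor).last_hidden_state
print(torch.allclose(out_a, out_b))  # False (dropout was active)
print(torch.allclose(out_c, out_d))  # True (eval mode is deterministic)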

# Run the encoding
with torch.no_grad():
    # See the model's docstring for the details of the inputs.
    # Pass the tensors by keyword: in BertModel.forward the second positional
    # argument is attention_mask, not token_type_ids, so passing them
    # positionally would silently swap the segment ids and the mask.
    outputs = model(input_ids=tokens_tensor,
                    token_type_ids=segments_tensors,
                    attention_mask=attention_mask_ids_tensors)
    # Recent transformers versions return a ModelOutput object rather than a
    # plain tuple (pass return_dict=False for the old tuple behaviour); its
    # first field is the last layer's hidden states.
    encoded_layers = outputs
print(encoded_layers)
# encoded_layers holds the final encoding result
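# The fields can also be pulled out by attribute (standard ModelOutput access;
# the shapes correspond to the 10-token sentence pair encoded above):
print(outputs.last_hidden_state.shape)  # torch.Size([1, 10, 768])
print(outputs.pooler_output.shape)      # torch.Size([1, 768])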


'''
text1 = '我爱武汉!我爱中国!'
tokeniz_text1 = tokenizer.tokenize(text1)  # splits into single characters
print(tokeniz_text1)  # ['我', '爱', '武', '汉', '!', '我', '爱', '中', '国', '!']
print('tokeniz_text1:', len(tokeniz_text1))  # 10

indexed_tokens_1 = tokenizer.convert_tokens_to_ids(tokeniz_text1)  # looks each token up in the vocabulary and maps it to its id
print('len(indexed_tokens_1):', len(indexed_tokens_1))  # 10
print(indexed_tokens_1)  # [2769, 4263, 3636, 3727, 8013, 2769, 4263, 704, 1744, 8013] converted to ids


input_ids_1 = indexed_tokens_1
segments_ids_1 = [0]*len(input_ids_1)  # optional here, since this is a single sentence
input_masks_1 = [1]*len(input_ids_1)  # optional here, since nothing is padded


input_ids_1_tensor = torch.tensor([input_ids_1])  # add the batch dimension and convert to a tensor
print("input_id_1_tensor: ", input_ids_1_tensor)  #  tensor([[2769, 4263, 3636, 3727, 8013, 2769, 4263,  704, 1744, 8013]])
vector1, pooler1 = model(input_ids_1_tensor, return_dict=False)  # all three tensors could be passed, but for a single unpadded sentence the model defaults token_type_ids to zeros and attention_mask to all ones
# The outputs are the last layer's hidden states and the pooled hidden state of the first token
print("vector1: ", vector1)  # vector1:  last_hidden_state
print("pooler1: ", pooler1)  # pooler1:  pooler_output
'''
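# Note that the manual tokenize/convert_tokens_to_ids path above does NOT add the
# [CLS]/[SEP] special tokens that encode/encode_plus insert automatically. A sketch
# of making the two paths agree, using the standard build_inputs_with_special_tokens helper:
tokens = tokenizer.tokenize("我爱你")
ids = tokenizer.convert_tokens_to_ids(tokens)          # [2769, 4263, 872]
ids = tokenizer.build_inputs_with_special_tokens(ids)  # [101, 2769, 4263, 872, 102]
print(ids == tokenizer.encode("我爱你"))               # True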

Output:

BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=tensor([[[ 3.3619e-01, -2.7759e-02, -2.1758e-01,  ..., -7.5315e-01,
          -2.2568e-01, -6.2780e-01],
         [ 9.7733e-01,  8.8764e-04,  8.3537e-01,  ..., -9.8692e-01,
          -6.1535e-01, -5.0042e-01],
         [ 4.1397e-01,  1.8517e-01, -4.3048e-01,  ..., -4.7202e-01,
          -4.3097e-01, -7.4553e-01],
         ...,
         [ 7.4375e-01,  6.0449e-01, -4.6755e-01,  ..., -5.4740e-01,
           1.0082e-01, -1.3955e+00],
         [ 1.5120e+00,  2.4032e-01, -1.6719e-01,  ..., -1.3916e+00,
           6.3403e-01, -3.8112e-01],
         [ 3.3619e-01, -2.7759e-02, -2.1758e-01,  ..., -7.5315e-01,
          -2.2568e-01, -6.2780e-01]]]), 
pooler_output=tensor([[ 9.9075e-01,  8.8019e-01,  9.8010e-01,  ..., -7.1029e-01,
         -8.1672e-01,  2.7600e-01]]), 
hidden_states=None, 
past_key_values=None,
attentions=None, 
cross_attentions=None)
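
The shapes here: last_hidden_state is [1, 10, 768] (batch, sequence length, hidden size) and pooler_output is [1, 768]. The pooler output is the final hidden state of the first ([CLS]) token passed through an extra dense layer with tanh activation.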

Reposted from blog.csdn.net/weixin_44697051/article/details/121872488