如何提取模型变量并展平为一维数组

1. 提取名称包含 linear_model 的变量（排除名称含 Ftrl 的变量），并将它们展平拼接为一维数组：

def get_flat_weights(model, scope="linear_model", exclude="Ftrl"):
    """Return the selected variables of ``model`` concatenated into one 1-D array.

    A variable is selected when ``scope`` is a substring of its name and
    ``exclude`` is not. With the defaults, variables whose names contain
    ``linear_model`` are kept and any whose names contain ``Ftrl`` are dropped.

    Args:
        model: Object exposing ``get_variable_names()`` and
            ``get_variable_value(name)``. This matches the interface of a
            TensorFlow ``tf.estimator.Estimator``, which is presumably what
            callers pass; confirm.
        scope: Substring a variable name must contain to be selected.
        exclude: Substring that disqualifies a variable name. Pass ``None``
            to disable the exclusion.

    Returns:
        numpy.ndarray: 1-D array of all selected values, in the order that
        ``model.get_variable_names()`` lists them. Each variable is flattened
        in row-major (C) order.

    Raises:
        ValueError: If no variable name matches the filter.
    """
    import numpy as np

    weight_names = [
        name
        for name in model.get_variable_names()
        if scope in name and (exclude is None or exclude not in name)
    ]
    if not weight_names:
        # np.concatenate([]) would fail with an opaque "need at least one
        # array" message; naming the filter makes a wrong scope easy to spot.
        raise ValueError(
            f"no variables matching scope={scope!r} (excluding {exclude!r})"
        )

    weight_values = [model.get_variable_value(name) for name in weight_names]
    # np.ravel also accepts 0-d arrays and plain Python scalars, unlike
    # ndarray.flatten, which only exists on ndarray objects.
    return np.concatenate([np.ravel(value) for value in weight_values], axis=0)

# Flatten the selected weights of three models into 1-D arrays for comparison.
# NOTE(review): `model`, `model_l1` and `model_l2` are not defined in this
# snippet. Judging by the names, they are presumably the same model trained
# without regularization, with L1, and with L2 regularization; confirm.
# Element-wise comparison only makes sense if all three share the same
# variable names and shapes, so the flattened arrays line up.
weights_flat = get_flat_weights(model)
weights_flat_l1 = get_flat_weights(model_l1)
weights_flat_l2 = get_flat_weights(model_l2)

2. 以基准模型中权重非零的位置作为掩码，从三个模型中取出对应位置的权重：

# Boolean mask: True where the base model's flattened weight is non-zero.
weight_mask = weights_flat != 0

# Keep the same positions from each model's flattened weights. Boolean
# indexing requires the L1/L2 arrays to have the same length as
# `weights_flat`; numpy raises IndexError otherwise.
weights_base = weights_flat[weight_mask]
weights_l1 = weights_flat_l1[weight_mask]
weights_l2 = weights_flat_l2[weight_mask]

猜你喜欢

转载自：https://www.cnblogs.com/augustone/p/10506153.html