Training the PoseNet deep network for 6D pose estimation, implemented in Python 3

0. Related GitHub links

1. Task background

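  • Image Matching Challenge 2023
  • Task in brief: infer the camera pose from an image, i.e. a 3×3 rotation matrix and a 3-dimensional translation vector
  • Data description: train_labels.csv
    • dataset: dataset name
    • scene: scene name
    • image_path: image path
    • rotation_matrix: the 3×3 rotation matrix, stored as 9 values separated by semicolons
    • translation_vector: the 3-dimensional translation vector, stored as 3 values separated by semicolons
      The structure is as follows:
      [Figure: structure of train_labels.csv]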
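  • To make the column format concrete, here is a minimal sketch that parses one label row: rotation_matrix is read row by row into a 3×3 matrix (the same layout the code in section 3 assumes) and translation_vector into a 3-vector. The sample row and the helper name parse_label_row are hypothetical; the values are a made-up 90° rotation about z, not competition data.

```python
import csv
import io

import numpy as np

# Hypothetical sample in the train_labels.csv layout (values are made up)
SAMPLE = """dataset,scene,image_path,rotation_matrix,translation_vector
demo,demo_scene,demo/demo_scene/images/img_0001.png,0.0;-1.0;0.0;1.0;0.0;0.0;0.0;0.0;1.0,0.5;-1.2;3.0
"""

def parse_label_row(row):
    """Turn one CSV row (a dict) into a 3x3 rotation matrix and a 3-vector."""
    # Nine ';'-separated values, filled into the matrix row by row
    R = np.array([float(v) for v in row["rotation_matrix"].split(";")]).reshape(3, 3)
    # Three ';'-separated values for the translation
    t = np.array([float(v) for v in row["translation_vector"].split(";")])
    return R, t

rows = list(csv.DictReader(io.StringIO(SAMPLE)))
R, t = parse_label_row(rows[0])
# A valid rotation matrix is orthonormal with determinant +1
print(np.allclose(R @ R.T, np.eye(3)), round(float(np.linalg.det(R)), 6))
```

Checking orthonormality and the determinant is a cheap way to catch a row/column mix-up before converting the matrices to quaternions.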

2. Conversion between quaternions and rotation matrices

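import numpy as np

# Convert a rotation matrix to a quaternion (w, x, y, z); see https://zhuanlan.zhihu.com/p/45404840
def matrix2quaternion(m):
    # m: 3x3 rotation matrix (np.ndarray)
    # Note: divides by 4*w, so it is unstable for rotations close to 180 degrees (w close to 0)
    w = ((np.trace(m) + 1) ** 0.5) / 2
    x = (m[2][1] - m[1][2]) / (4 * w)
    y = (m[0][2] - m[2][0]) / (4 * w)
    z = (m[1][0] - m[0][1]) / (4 * w)
    return w, x, y, z

def quaternion2matrix(q):
    # q: unit quaternion (w, x, y, z) as a list or tuple; normalize network outputs first
    w, x, y, z = q
    return np.array([[1-2*y*y-2*z*z, 2*x*y-2*z*w, 2*x*z+2*y*w],
                     [2*x*y+2*z*w, 1-2*x*x-2*z*z, 2*y*z-2*x*w],
                     [2*x*z-2*y*w, 2*y*z+2*x*w, 1-2*x*x-2*y*y]])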
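  • matrix2quaternion divides by 4w, so it breaks down when w is close to 0, i.e. for rotations of about 180° (trace close to -1). A common remedy, sketched here as an alternative and not as the author's code, is Shepperd's method: find the largest of |w|, |x|, |y|, |z| from the trace and the diagonal entries and divide by that component instead. The name matrix2quaternion_robust is hypothetical; quaternion2matrix is repeated so the block runs on its own.

```python
import numpy as np

def quaternion2matrix(q):
    # Same formula as above: unit quaternion (w, x, y, z) -> 3x3 rotation matrix
    w, x, y, z = q
    return np.array([[1-2*y*y-2*z*z, 2*x*y-2*z*w, 2*x*z+2*y*w],
                     [2*x*y+2*z*w, 1-2*x*x-2*z*z, 2*y*z-2*x*w],
                     [2*x*z-2*y*w, 2*y*z+2*x*w, 1-2*x*x-2*y*y]])

def matrix2quaternion_robust(m):
    """Rotation matrix -> unit quaternion (w, x, y, z) with w >= 0.

    Branches on the largest of trace and diagonal entries (Shepperd's method),
    so the divisor s is never close to zero.
    """
    m = np.asarray(m, dtype=float)
    t = np.trace(m)
    if t > max(m[0, 0], m[1, 1], m[2, 2]):
        s = 2.0 * np.sqrt(1.0 + t)                            # s = 4w
        w, x = 0.25 * s, (m[2, 1] - m[1, 2]) / s
        y, z = (m[0, 2] - m[2, 0]) / s, (m[1, 0] - m[0, 1]) / s
    elif m[0, 0] >= m[1, 1] and m[0, 0] >= m[2, 2]:
        s = 2.0 * np.sqrt(1.0 + m[0, 0] - m[1, 1] - m[2, 2])  # s = 4x
        w, x = (m[2, 1] - m[1, 2]) / s, 0.25 * s
        y, z = (m[0, 1] + m[1, 0]) / s, (m[0, 2] + m[2, 0]) / s
    elif m[1, 1] >= m[2, 2]:
        s = 2.0 * np.sqrt(1.0 + m[1, 1] - m[0, 0] - m[2, 2])  # s = 4y
        w, x = (m[0, 2] - m[2, 0]) / s, (m[0, 1] + m[1, 0]) / s
        y, z = 0.25 * s, (m[1, 2] + m[2, 1]) / s
    else:
        s = 2.0 * np.sqrt(1.0 + m[2, 2] - m[0, 0] - m[1, 1])  # s = 4z
        w, x = (m[1, 0] - m[0, 1]) / s, (m[0, 2] + m[2, 0]) / s
        y, z = (m[1, 2] + m[2, 1]) / s, 0.25 * s
    q = np.array([w, x, y, z])
    # q and -q are the same rotation; return the one with w >= 0
    return -q if q[0] < 0 else q

# 180 degrees about the x axis: trace = -1, where matrix2quaternion divides by zero
print(matrix2quaternion_robust(np.diag([1.0, -1.0, -1.0])))
```

Because q and -q describe the same rotation but are far apart in Euclidean distance, fixing the sign (here w ≥ 0) also keeps the regression targets consistent.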

3. Processing train_labels.csv

  • Compute the quaternion for each rotation_matrix entry and store it in a new column, rotation_matrix_quaternion (implemented with map and a list comprehension):
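import numpy as np
import pandas as pd

def m(a):
    # a: the rotation_matrix string, 9 values separated by ';' (row by row)
    A = np.array([float(i) for i in a.split(';')]).reshape(3, 3)
    # Store plain Python floats so the CSV cell reads "(w, x, y, z)"
    return tuple(float(v) for v in matrix2quaternion(A))

change_train_labels = 1
if change_train_labels:
    train_labels = pd.read_csv('/kaggle/input/image-matching-challenge-2023/train/train_labels.csv')
    train_labels['rotation_matrix_quaternion'] = [i for i in map(m, train_labels['rotation_matrix'])]
    # The index is written as an extra first column, which get_data() relies on (it reads i[4], i[6], i[7])
    train_labels.to_csv('/kaggle/working/my_train_labels.csv')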
  • The labels are read from '/kaggle/input/image-matching-challenge-2023/train/train_labels.csv' and the result is written to '/kaggle/working/my_train_labels.csv'.
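  • The quaternion column ends up in the CSV as the text of a Python tuple, and get_data() in section 6 parses it back by cutting out the text between the parentheses and splitting on commas. A minimal round-trip check of that format, assuming plain Python floats are stored; the helper names quaternion_to_cell and cell_to_quaternion are hypothetical:

```python
import numpy as np

def quaternion_to_cell(q):
    # Convert to Python floats so the cell reads "(w, x, y, z)"; with NumPy >= 2
    # the repr of NumPy scalars is "np.float64(...)", which the parser below cannot read
    return str(tuple(float(v) for v in q))

def cell_to_quaternion(cell):
    # Same parsing expression as in get_data(): text between "(" and ")", split on ","
    return [float(v) for v in cell.split('(')[1].split(')')[0].split(',')]

q = np.array([0.5, 0.5, 0.5, 0.5])  # a unit quaternion (w, x, y, z)
cell = quaternion_to_cell(q)
print(cell)
print(np.allclose(cell_to_quaternion(cell), q))
```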

4. Network model

  • Base class for building the network (TensorFlow 1.x graph-mode API):
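import numpy as np
# The class uses TensorFlow 1.x graph-mode APIs (tf.placeholder_with_default,
# tf.variable_scope, tf.get_variable). Under TensorFlow 2.x, use the v1 compatibility module:
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

DEFAULT_PADDING = 'SAME'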


def layer(op):
    '''Decorator for composable network layers.'''

    def layer_decorated(self, *args, **kwargs):
        # Automatically set a name if not provided.
        name = kwargs.setdefault('name', self.get_unique_name(op.__name__))
        # Figure out the layer inputs.
        if len(self.terminals) == 0:
            raise RuntimeError('No input variables found for layer %s.' % name)
        elif len(self.terminals) == 1:
            layer_input = self.terminals[0]
        else:
            layer_input = list(self.terminals)
        # Perform the operation and get the output.
        layer_output = op(self, layer_input, *args, **kwargs)
        # Add to layer LUT.
        self.layers[name] = layer_output
        # This output is now the input for the next layer.
        self.feed(layer_output)
        # Return self for chained calls.
        return self

    return layer_decorated


class Network(object):

    def __init__(self, inputs, trainable=True):
        # The input nodes for this network
        self.inputs = inputs
        # The current list of terminal nodes
        self.terminals = []
        # Mapping from layer names to layers
        self.layers = dict(inputs)
        # If true, the resulting variables are set as trainable
        self.trainable = trainable
        # Switch variable for dropout
        self.use_dropout = tf.placeholder_with_default(tf.constant(1.0),
                                                       shape=[],
                                                       name='use_dropout')
        self.setup()

    def setup(self):
        '''Construct the network. '''
        raise NotImplementedError('Must be implemented by the subclass.')

    def load(self, data_path, session, ignore_missing=False):
        '''Load network weights.
        data_path: The path to the numpy-serialized network weights
        session: The current TensorFlow session
        ignore_missing: If true, serialized weights for missing layers are ignored.
        '''
        data_dict = np.load(data_path,allow_pickle=True,encoding="latin1").item()
        for op_name in data_dict:
            with tf.variable_scope(op_name, reuse=True):
                for param_name, data in data_dict[op_name].items():
                    try:
                        var = tf.get_variable(param_name)
                        session.run(var.assign(data))
                    except ValueError:
                        if not ignore_missing:
                            raise

    def feed(self, *args):
        '''Set the input(s) for the next operation by replacing the terminal nodes.
        The arguments can be either layer names or the actual layers.
        '''
        assert len(args) != 0
        self.terminals = []
        for fed_layer in args:
            if isinstance(fed_layer, str):
                try:
                    fed_layer = self.layers[fed_layer]
                except KeyError:
                    raise KeyError('Unknown layer name fed: %s' % fed_layer)
            self.terminals.append(fed_layer)
        return self

    def get_output(self):
        '''Returns the current network output.'''
        return self.terminals[-1]

    def get_unique_name(self, prefix):
        '''Returns an index-suffixed unique name for the given prefix.
        This is used for auto-generating layer names based on the type-prefix.
        '''
        ident = sum(t.startswith(prefix) for t, _ in self.layers.items()) + 1
        return '%s_%d' % (prefix, ident)

    def make_var(self, name, shape):
        '''Creates a new TensorFlow variable.'''
        return tf.get_variable(name, shape, trainable=self.trainable)

    def validate_padding(self, padding):
        '''Verifies that the padding is one of the supported ones.'''
        assert padding in ('SAME', 'VALID')

    @layer
    def conv(self,
             input,
             k_h,
             k_w,
             c_o,
             s_h,
             s_w,
             name,
             relu=True,
             padding=DEFAULT_PADDING,
             group=1,
             biased=True):
        # Verify that the padding is acceptable
        self.validate_padding(padding)
        # Get the number of channels in the input
        c_i = input.get_shape()[-1]
        # Verify that the grouping parameter is valid
        assert c_i % group == 0
        assert c_o % group == 0
        # Convolution for a given input and kernel
        convolve = lambda i, k: tf.nn.conv2d(i, k, [1, s_h, s_w, 1], padding=padding)
        with tf.variable_scope(name) as scope:
            kernel = self.make_var('weights', shape=[k_h, k_w, int(int(c_i) / group), c_o])
            if group == 1:
                # This is the common-case. Convolve the input without any further complications.
                output = convolve(input, kernel)
            else:
                # Split the input into groups and then convolve each of them independently
                input_groups = tf.split(input, group, axis=3)
                kernel_groups = tf.split(kernel, group, axis=3)
                output_groups = [convolve(i, k) for i, k in zip(input_groups, kernel_groups)]
                # Concatenate the groups
                output = tf.concat(output_groups, axis=3)
            # Add the biases
            if biased:
                biases = self.make_var('biases', [c_o])
                output = tf.nn.bias_add(output, biases)
            if relu:
                # ReLU non-linearity
                output = tf.nn.relu(output, name=scope.name)
            return output

    @layer
    def relu(self, input, name):
        return tf.nn.relu(input, name=name)

    @layer
    def max_pool(self, input, k_h, k_w, s_h, s_w, name, padding=DEFAULT_PADDING):
        self.validate_padding(padding)
        return tf.nn.max_pool(input,
                              ksize=[1, k_h, k_w, 1],
                              strides=[1, s_h, s_w, 1],
                              padding=padding,
                              name=name)

    @layer
    def avg_pool(self, input, k_h, k_w, s_h, s_w, name, padding=DEFAULT_PADDING):
        self.validate_padding(padding)
        return tf.nn.avg_pool(input,
                              ksize=[1, k_h, k_w, 1],
                              strides=[1, s_h, s_w, 1],
                              padding=padding,
                              name=name)

    @layer
    def lrn(self, input, radius, alpha, beta, name, bias=1.0):
        return tf.nn.local_response_normalization(input,
                                                  depth_radius=radius,
                                                  alpha=alpha,
                                                  beta=beta,
                                                  bias=bias,
                                                  name=name)

    @layer
    def concat(self, inputs, axis, name):
        return tf.concat(values=inputs, axis=axis, name=name)

    @layer
    def add(self, inputs, name):
        return tf.add_n(inputs, name=name)

    @layer
    def fc(self, input, num_out, name, relu=True):
        with tf.variable_scope(name) as scope:
            input_shape = input.get_shape()
            if input_shape.ndims == 4:
                # The input is spatial. Vectorize it first.
                dim = 1
                for d in input_shape[1:].as_list():
                    dim *= d
                feed_in = tf.reshape(input, [-1, dim])
            else:
                feed_in, dim = (input, input_shape[-1].value)
            weights = self.make_var('weights', shape=[dim, num_out])
            biases = self.make_var('biases', [num_out])
            op = tf.nn.relu_layer if relu else tf.nn.xw_plus_b
            fc = op(feed_in, weights, biases, name=scope.name)
            return fc

    @layer
    def softmax(self, input, name):
        # A list (not a lazy map object) so that len() and indexing work in Python 3
        input_shape = input.get_shape().as_list()
        if len(input_shape) > 2:
            # For certain models (like NiN), the singleton spatial dimensions
            # need to be explicitly squeezed, since they're not broadcast-able
            # in TensorFlow's NHWC ordering (unlike Caffe's NCHW).
            if input_shape[1] == 1 and input_shape[2] == 1:
                input = tf.squeeze(input, axis=[1, 2])
            else:
                raise ValueError('Rank 2 tensor input expected for softmax!')
        # name must be passed by keyword: the second positional argument is axis
        return tf.nn.softmax(input, name=name)

    @layer
    def batch_normalization(self, input, name, scale_offset=True, relu=False):
        # NOTE: Currently, only inference is supported
        with tf.variable_scope(name) as scope:
            shape = [input.get_shape()[-1]]
            if scale_offset:
                scale = self.make_var('scale', shape=shape)
                offset = self.make_var('offset', shape=shape)
            else:
                scale, offset = (None, None)
            output = tf.nn.batch_normalization(
                input,
                mean=self.make_var('mean', shape=shape),
                variance=self.make_var('variance', shape=shape),
                offset=offset,
                scale=scale,
                # TODO: This is the default Caffe batch norm eps
                # Get the actual eps from parameters
                variance_epsilon=1e-5,
                name=name)
            if relu:
                output = tf.nn.relu(output)
            return output

    @layer
    def dropout(self, input, keep_prob, name):
        keep = 1 - self.use_dropout + (self.use_dropout * keep_prob)
        return tf.nn.dropout(input, keep, name=name)
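  • The Network class rests on one pattern: every method wrapped by @layer takes the current terminal node(s) as input, stores its output under a unique name in self.layers, and makes that output the next input, so layers can be chained and later branched with feed(). The pure-Python toy below (no TensorFlow; Node and ToyNetwork are hypothetical stand-ins) reproduces only that bookkeeping:

```python
class Node:
    """Stand-in for a tensor: only records how it was computed."""
    def __init__(self, desc):
        self.desc = desc

    def __repr__(self):
        return self.desc


def layer(op):
    """Same bookkeeping as the decorator above, without TensorFlow."""
    def layer_decorated(self, *args, **kwargs):
        name = kwargs.setdefault('name', self.get_unique_name(op.__name__))
        if not self.terminals:
            raise RuntimeError('No input variables found for layer %s.' % name)
        # One terminal -> single input; several terminals -> list of inputs
        if len(self.terminals) == 1:
            layer_input = self.terminals[0]
        else:
            layer_input = list(self.terminals)
        self.layers[name] = op(self, layer_input, *args, **kwargs)
        self.feed(self.layers[name])  # the output becomes the next input
        return self                   # enables .conv(...).conv(...) chaining
    return layer_decorated


class ToyNetwork:
    def __init__(self, inputs):
        self.terminals = []
        self.layers = dict(inputs)

    def feed(self, *args):
        # Strings are looked up as layer names, anything else is used directly
        self.terminals = [self.layers[a] if isinstance(a, str) else a for a in args]
        return self

    def get_unique_name(self, prefix):
        ident = sum(t.startswith(prefix) for t in self.layers) + 1
        return '%s_%d' % (prefix, ident)

    @layer
    def conv(self, input, name):
        return Node('conv(%r)' % input)

    @layer
    def concat(self, inputs, name):
        return Node('concat(%s)' % ', '.join(repr(t) for t in inputs))


net = ToyNetwork({'data': Node('x')})
net.feed('data').conv(name='a').conv(name='b')  # two chained layers
net.feed('data').conv()                         # no name given -> 'conv_1'
net.feed('b', 'conv_1').concat(name='out')      # two terminals -> list input
print(net.layers['out'])
```

This is exactly how the inception blocks in GoogLeNet are wired: several feed(...) chains start from the same named layer, and a final feed with four names hands a list to concat.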
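  • Build the GoogLeNet backbone. Each of its three heads (cls1, cls2, cls3) regresses a position (xyz) and an orientation quaternion (wpqr):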
class GoogLeNet(Network):
    def setup(self):
        (self.feed('data')
             .conv(7, 7, 64, 2, 2, name='conv1')
             .max_pool(3, 3, 2, 2, name='pool1')
             .lrn(2, 2e-05, 0.75, name='norm1')
             .conv(1, 1, 64, 1, 1, name='reduction2')
             .conv(3, 3, 192, 1, 1, name='conv2')
             .lrn(2, 2e-05, 0.75, name='norm2')
             .max_pool(3, 3, 2, 2, name='pool2')
             .conv(1, 1, 96, 1, 1, name='icp1_reduction1')
             .conv(3, 3, 128, 1, 1, name='icp1_out1'))

        (self.feed('pool2')
             .conv(1, 1, 16, 1, 1, name='icp1_reduction2')
             .conv(5, 5, 32, 1, 1, name='icp1_out2'))

        (self.feed('pool2')
             .max_pool(3, 3, 1, 1, name='icp1_pool')
             .conv(1, 1, 32, 1, 1, name='icp1_out3'))

        (self.feed('pool2')
             .conv(1, 1, 64, 1, 1, name='icp1_out0'))

        (self.feed('icp1_out0', 
                   'icp1_out1', 
                   'icp1_out2', 
                   'icp1_out3')
             .concat(3, name='icp2_in')
             .conv(1, 1, 128, 1, 1, name='icp2_reduction1')
             .conv(3, 3, 192, 1, 1, name='icp2_out1'))

        (self.feed('icp2_in')
             .conv(1, 1, 32, 1, 1, name='icp2_reduction2')
             .conv(5, 5, 96, 1, 1, name='icp2_out2'))

        (self.feed('icp2_in')
             .max_pool(3, 3, 1, 1, name='icp2_pool')
             .conv(1, 1, 64, 1, 1, name='icp2_out3'))

        (self.feed('icp2_in')
             .conv(1, 1, 128, 1, 1, name='icp2_out0'))

        (self.feed('icp2_out0', 
                   'icp2_out1', 
                   'icp2_out2', 
                   'icp2_out3')
             .concat(3, name='icp2_out')
             .max_pool(3, 3, 2, 2, name='icp3_in')
             .conv(1, 1, 96, 1, 1, name='icp3_reduction1')
             .conv(3, 3, 208, 1, 1, name='icp3_out1'))

        (self.feed('icp3_in')
             .conv(1, 1, 16, 1, 1, name='icp3_reduction2')
             .conv(5, 5, 48, 1, 1, name='icp3_out2'))

        (self.feed('icp3_in')
             .max_pool(3, 3, 1, 1, name='icp3_pool')
             .conv(1, 1, 64, 1, 1, name='icp3_out3'))

        (self.feed('icp3_in')
             .conv(1, 1, 192, 1, 1, name='icp3_out0'))

        (self.feed('icp3_out0', 
                   'icp3_out1', 
                   'icp3_out2', 
                   'icp3_out3')
             .concat(3, name='icp3_out')
             .avg_pool(5, 5, 3, 3, padding='VALID', name='cls1_pool')
             .conv(1, 1, 128, 1, 1, name='cls1_reduction_pose')
             .fc(1024, name='cls1_fc1_pose')
             .fc(3, relu=False, name='cls1_fc_pose_xyz'))

        (self.feed('cls1_fc1_pose')
             .fc(4, relu=False, name='cls1_fc_pose_wpqr'))

        (self.feed('icp3_out')
             .conv(1, 1, 112, 1, 1, name='icp4_reduction1')
             .conv(3, 3, 224, 1, 1, name='icp4_out1'))

        (self.feed('icp3_out')
             .conv(1, 1, 24, 1, 1, name='icp4_reduction2')
             .conv(5, 5, 64, 1, 1, name='icp4_out2'))

        (self.feed('icp3_out')
             .max_pool(3, 3, 1, 1, name='icp4_pool')
             .conv(1, 1, 64, 1, 1, name='icp4_out3'))

        (self.feed('icp3_out')
             .conv(1, 1, 160, 1, 1, name='icp4_out0'))

        (self.feed('icp4_out0', 
                   'icp4_out1', 
                   'icp4_out2', 
                   'icp4_out3')
             .concat(3, name='icp4_out')
             .conv(1, 1, 128, 1, 1, name='icp5_reduction1')
             .conv(3, 3, 256, 1, 1, name='icp5_out1'))

        (self.feed('icp4_out')
             .conv(1, 1, 24, 1, 1, name='icp5_reduction2')
             .conv(5, 5, 64, 1, 1, name='icp5_out2'))

        (self.feed('icp4_out')
             .max_pool(3, 3, 1, 1, name='icp5_pool')
             .conv(1, 1, 64, 1, 1, name='icp5_out3'))

        (self.feed('icp4_out')
             .conv(1, 1, 128, 1, 1, name='icp5_out0'))

        (self.feed('icp5_out0', 
                   'icp5_out1', 
                   'icp5_out2', 
                   'icp5_out3')
             .concat(3, name='icp5_out')
             .conv(1, 1, 144, 1, 1, name='icp6_reduction1')
             .conv(3, 3, 288, 1, 1, name='icp6_out1'))

        (self.feed('icp5_out')
             .conv(1, 1, 32, 1, 1, name='icp6_reduction2')
             .conv(5, 5, 64, 1, 1, name='icp6_out2'))

        (self.feed('icp5_out')
             .max_pool(3, 3, 1, 1, name='icp6_pool')
             .conv(1, 1, 64, 1, 1, name='icp6_out3'))

        (self.feed('icp5_out')
             .conv(1, 1, 112, 1, 1, name='icp6_out0'))

        (self.feed('icp6_out0', 
                   'icp6_out1', 
                   'icp6_out2', 
                   'icp6_out3')
             .concat(3, name='icp6_out')
             .avg_pool(5, 5, 3, 3, padding='VALID', name='cls2_pool')
             .conv(1, 1, 128, 1, 1, name='cls2_reduction_pose')
             .fc(1024, name='cls2_fc1')
             .fc(3, relu=False, name='cls2_fc_pose_xyz'))

        (self.feed('cls2_fc1')
             .fc(4, relu=False, name='cls2_fc_pose_wpqr'))

        (self.feed('icp6_out')
             .conv(1, 1, 160, 1, 1, name='icp7_reduction1')
             .conv(3, 3, 320, 1, 1, name='icp7_out1'))

        (self.feed('icp6_out')
             .conv(1, 1, 32, 1, 1, name='icp7_reduction2')
             .conv(5, 5, 128, 1, 1, name='icp7_out2'))

        (self.feed('icp6_out')
             .max_pool(3, 3, 1, 1, name='icp7_pool')
             .conv(1, 1, 128, 1, 1, name='icp7_out3'))

        (self.feed('icp6_out')
             .conv(1, 1, 256, 1, 1, name='icp7_out0'))

        (self.feed('icp7_out0', 
                   'icp7_out1', 
                   'icp7_out2', 
                   'icp7_out3')
             .concat(3, name='icp7_out')
             .max_pool(3, 3, 2, 2, name='icp8_in')
             .conv(1, 1, 160, 1, 1, name='icp8_reduction1')
             .conv(3, 3, 320, 1, 1, name='icp8_out1'))

        (self.feed('icp8_in')
             .conv(1, 1, 32, 1, 1, name='icp8_reduction2')
             .conv(5, 5, 128, 1, 1, name='icp8_out2'))

        (self.feed('icp8_in')
             .max_pool(3, 3, 1, 1, name='icp8_pool')
             .conv(1, 1, 128, 1, 1, name='icp8_out3'))

        (self.feed('icp8_in')
             .conv(1, 1, 256, 1, 1, name='icp8_out0'))

        (self.feed('icp8_out0', 
                   'icp8_out1', 
                   'icp8_out2', 
                   'icp8_out3')
             .concat(3, name='icp8_out')
             .conv(1, 1, 192, 1, 1, name='icp9_reduction1')
             .conv(3, 3, 384, 1, 1, name='icp9_out1'))

        (self.feed('icp8_out')
             .conv(1, 1, 48, 1, 1, name='icp9_reduction2')
             .conv(5, 5, 128, 1, 1, name='icp9_out2'))

        (self.feed('icp8_out')
             .max_pool(3, 3, 1, 1, name='icp9_pool')
             .conv(1, 1, 128, 1, 1, name='icp9_out3'))

        (self.feed('icp8_out')
             .conv(1, 1, 384, 1, 1, name='icp9_out0'))

        (self.feed('icp9_out0', 
                   'icp9_out1', 
                   'icp9_out2', 
                   'icp9_out3')
             .concat(3, name='icp9_out')
             .avg_pool(7, 7, 1, 1, padding='VALID', name='cls3_pool')
             .fc(2048, name='cls3_fc1_pose')
             .fc(3, relu=False, name='cls3_fc_pose_xyz'))

        (self.feed('cls3_fc1_pose')
             .fc(4, relu=False, name='cls3_fc_pose_wpqr'))
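  • To see what each regression head receives, the sketch below traces the spatial size of a 224×224 input through the strided layers of setup(), using TensorFlow's output-size rules (SAME: ceil(n/s); VALID: floor((n-k)/s)+1), and sums the channel counts of the relevant inception blocks. It is arithmetic only; the helper names same and valid are hypothetical.

```python
import math

def same(n, s):
    # TensorFlow 'SAME' padding: output = ceil(n / stride)
    return math.ceil(n / s)

def valid(n, k, s):
    # TensorFlow 'VALID' padding: output = floor((n - k) / stride) + 1
    return (n - k) // s + 1

n = 224
n = same(n, 2)  # conv1, 7x7 stride 2   -> 112
n = same(n, 2)  # pool1, 3x3 stride 2   -> 56
n = same(n, 2)  # pool2, 3x3 stride 2   -> 28 (icp1, icp2)
n = same(n, 2)  # icp3_in, stride 2     -> 14 (icp3 ... icp6)
cls12_side = valid(n, 5, 3)  # cls1_pool / cls2_pool: 5x5 stride 3, VALID
n = same(n, 2)  # icp8_in, stride 2     -> 7 (icp8, icp9)
cls3_side = valid(n, 7, 1)   # cls3_pool: 7x7 stride 1, VALID

# Channels of an inception output = sum of its four branches
icp3_out = 192 + 208 + 48 + 64    # input of the cls1 head
icp6_out = 112 + 288 + 64 + 64    # input of the cls2 head
icp9_out = 384 + 384 + 128 + 128  # input of the cls3 head

# Flattened input of the first fc layer of each head
cls12_fc_in = cls12_side * cls12_side * 128  # after the 1x1 'reduction_pose' conv (128)
cls3_fc_in = cls3_side * cls3_side * icp9_out
print(cls12_side, cls3_side, cls12_fc_in, cls3_fc_in)
```

These flattened sizes fix the shapes of the first fc weights in each head, so feeding images of another size would also change those weight shapes and pretrained values loaded with Network.load would no longer fit.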

5. Image preprocessing

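  • PoseNet takes 224×224 input images. Because this task is sensitive to image scale, viewpoint and similar factors, the images are not resized; instead a 224×224 window is cut from the center of each image. The center-crop function: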
def centeredCrop(img, output_side_length):
    # Cut an output_side_length x output_side_length window from the center of
    # img (H x W x C) without resizing; assumes both sides are at least that long.
    height, width = img.shape[:2]
    # Integer offsets: '/' yields floats in Python 3, which cannot be used as slice indices
    height_offset = (height - output_side_length) // 2
    width_offset = (width - output_side_length) // 2
    cropped_img = img[height_offset:height_offset + output_side_length,
                      width_offset:width_offset + output_side_length]
    return cropped_img
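  • The preprocessing function calls the center-crop function above, subtracts the mean image (computed per channel and per pixel over all cropped images) from every image, and returns the images in H×W×C layout (224×224×3), matching the TensorFlow input placeholder: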
import cv2
import numpy as np
from tqdm import tqdm

def preprocess(images):
    images_out = []  # final result
    # Read and center-crop every image
    images_cropped = []
    for i in tqdm(range(len(images))):
        X = cv2.imread(images[i])  # BGR, H x W x 3, uint8; None if the file cannot be read
        if X is None:
            raise FileNotFoundError(images[i])
        #X = cv2.resize(X, (455, 256))
        X = centeredCrop(X, 224)
        images_cropped.append(X)
    # Compute the mean image over all crops (per channel and per pixel)
    N = 0
    mean = np.zeros((1, 3, 224, 224))
    for X in tqdm(images_cropped):
        X = np.transpose(X, (2, 0, 1))  # HWC -> CHW: (3, 224, 224)
        mean[0][0] += X[0, :, :]
        mean[0][1] += X[1, :, :]
        mean[0][2] += X[2, :, :]
        N += 1
    mean[0] /= N
    # Subtract the mean from all images
    for X in tqdm(images_cropped):
        X = np.transpose(X, (2, 0, 1))
        X = X - mean                    # (1, 3, 224, 224), float64
        X = np.squeeze(X)               # (3, 224, 224)
        X = np.transpose(X, (1, 2, 0))  # back to HWC: (224, 224, 3)
        images_out.append(X)
    return images_out
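  • When debugging, you can process just a couple of images for a quick check instead of all of them, e.g. by looping over len(images)*0+2 images in the first loop.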
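  • The three loops in preprocess() compute a per-pixel, per-channel mean image over all crops and subtract it from each crop. The same arithmetic can be written with NumPy broadcasting. In the sketch below, the function name mean_subtract is hypothetical, random arrays stand in for the cropped images, and the mean is returned as well (the original function does not return it) so it could be reused on other images.

```python
import numpy as np

def mean_subtract(images_cropped):
    """images_cropped: list of equally sized H x W x 3 uint8 arrays."""
    stack = np.stack(images_cropped).astype(np.float64)  # (N, H, W, 3)
    mean = stack.mean(axis=0)                            # (H, W, 3) mean image
    return list(stack - mean), mean                      # mean broadcast over N

rng = np.random.default_rng(0)
crops = [rng.integers(0, 256, size=(224, 224, 3), dtype=np.uint8) for _ in range(4)]
out, mean = mean_subtract(crops)
print(len(out), out[0].shape)
```

Casting to float64 before subtracting matters: subtracting in uint8 would wrap around instead of producing negative values.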


  • The Network class, centeredCrop and preprocess shown above are collected in one file, network.py (there with the debug setting len(images)*0+2 in preprocess).


6. Data loading

  • Basic configuration (my_train_labels is the path written by the CSV-processing step in section 3):
batch_size = 75
max_iterations = 30000
# Path of the processed labels file written in section 3
my_train_labels = '/kaggle/working/my_train_labels.csv'
  • A datasource class that stores the images and their poses as pairs:
class datasource(object):
    def __init__(self, images, poses):
        self.images = images
        self.poses = poses
  • Generator that yields single samples, reshuffled after every pass over the data:
import random

def gen_data(source):
    while True:
        indices = list(range(len(source.images)))
        random.shuffle(indices)
        for i in indices:
            image = source.images[i]
            pose_x = source.poses[i][0:3]  # translation x, y, z
            pose_q = source.poses[i][3:7]  # quaternion w, x, y, z
            yield image, pose_x, pose_q
  • Assemble batches
def gen_data_batch(source):
    data_gen = gen_data(source)
    while True:
        image_batch = []
        pose_x_batch = []
        pose_q_batch = []
        for _ in range(batch_size):
            image, pose_x, pose_q = next(data_gen)
            image_batch.append(image)
            pose_x_batch.append(pose_x)
            pose_q_batch.append(pose_q)
        yield np.array(image_batch), np.array(pose_x_batch), np.array(pose_q_batch)
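gen_data and gen_data_batch build each batch one sample at a time in Python. Suppose the preprocessed images are stacked into a single (N, 224, 224, 3) array and the poses into an (N, 7) array laid out as x, y, z, w, p, q, r. Under that assumption, the same shuffle-per-epoch batching can be sketched with NumPy fancy indexing. batch_generator is a hypothetical name. As with the original generator, a batch may straddle an epoch boundary.

```python
import numpy as np

def batch_generator(images, poses, batch_size, seed=0):
    """Endlessly yield (image_batch, xyz_batch, wpqr_batch).

    images: array (N, H, W, C); poses: array (N, 7) as x, y, z, w, p, q, r.
    Each epoch uses a fresh random permutation; a batch may span two epochs.
    """
    images = np.asarray(images)
    poses = np.asarray(poses, dtype=np.float32)
    n = len(images)
    rng = np.random.default_rng(seed)
    order = rng.permutation(n)
    pos = 0
    while True:
        idx = []
        while len(idx) < batch_size:
            if pos == n:  # epoch finished: reshuffle and start over
                order = rng.permutation(n)
                pos = 0
            take = min(batch_size - len(idx), n - pos)
            idx.extend(order[pos:pos + take].tolist())
            pos += take
        idx = np.asarray(idx)
        # Same split as gen_data: first 3 values are xyz, last 4 the quaternion
        yield images[idx], poses[idx, 0:3], poses[idx, 3:7]
```

In the training loop this could stand in for gen_data_batch, e.g. batch_generator(np.stack(datasource.images), datasource.poses, batch_size).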
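  • The final data-loading function (change the path in the middle to wherever your training images live; with the directory layout shown in the task background, that is the folder containing train_labels.csv)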
def get_data():
    poses = []
    images = []
    for i in pd.read_csv(my_train_labels).itertuples():
        # i[4]: image_path
        # i[6]: translation_vector, "x;y;z"
        # i[7]: rotation_matrix_quaternion, "(w, x, y, z)"
        p0, p1, p2 = (float(v) for v in i[6].split(';'))
        p3, p4, p5, p6 = (float(v) for v in i[7].split('(')[1].split(')')[0].split(','))
        poses.append((p0, p1, p2, p3, p4, p5, p6))
        images.append('/kaggle/input/image-matching-challenge-2023/train/' + i[4])
    images = preprocess(images)
    return datasource(images, poses)
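The positional indices in get_data follow from how the CSV was written. itertuples() puts the DataFrame index at position 0. Because to_csv was called without index=False, it also wrote an extra unnamed index column, which lands at position 1. That puts image_path at 4, translation_vector at 6 and the quaternion at 7. Below is a minimal sketch that reads the same file by column name instead. parse_pose and load_labels are hypothetical helpers, and SAMPLE_CSV_TEXT is a made-up one-row example. It assumes the quaternion tuple was written as plain numbers, e.g. "(1.0, 0.0, 0.0, 0.0)". Under NumPy 2.x, the repr of np.float64 scalars becomes "np.float64(...)", which may break both this parser and the original one.

```python
import io
import numpy as np
import pandas as pd

# Illustrative one-row CSV in the layout produced by to_csv (unnamed index column first)
SAMPLE_CSV_TEXT = (
    ",dataset,scene,image_path,rotation_matrix,translation_vector,rotation_matrix_quaternion\n"
    '0,ds,sc,ds/sc/images/a.png,1.0;0.0;0.0;0.0;1.0;0.0;0.0;0.0;1.0,1.0;2.0;3.0,"(1.0, 0.0, 0.0, 0.0)"\n'
)

def parse_pose(translation_vector, quaternion_str):
    """'x;y;z' and '(w, x, y, z)' strings -> 7 floats: x, y, z, w, x, y, z."""
    xyz = [float(v) for v in translation_vector.split(';')]
    wxyz = [float(v) for v in quaternion_str.strip().strip('()').split(',')]
    return xyz + wxyz

def load_labels(csv_source, image_root):
    """Read the processed label CSV by column name rather than tuple position."""
    df = pd.read_csv(csv_source)
    poses = np.array([parse_pose(t, q) for t, q in
                      zip(df['translation_vector'], df['rotation_matrix_quaternion'])],
                     dtype=np.float32)
    paths = [image_root + p for p in df['image_path']]
    return paths, poses
```

For example, load_labels(io.StringIO(SAMPLE_CSV_TEXT), '/root/') returns one path and a (1, 7) pose array.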

7. Preparing the network and data placeholders

images = tf.placeholder(tf.float32, [batch_size, 224, 224, 3])
poses_x = tf.placeholder(tf.float32, [batch_size, 3])
poses_q = tf.placeholder(tf.float32, [batch_size, 4])
datasource = get_data()

net = GoogLeNet({'data': images})

p1_x = net.layers['cls1_fc_pose_xyz']
p1_q = net.layers['cls1_fc_pose_wpqr']
p2_x = net.layers['cls2_fc_pose_xyz']
p2_q = net.layers['cls2_fc_pose_wpqr']
p3_x = net.layers['cls3_fc_pose_xyz']
p3_q = net.layers['cls3_fc_pose_wpqr']

l1_x = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(p1_x, poses_x)))) * 0.3
l1_q = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(p1_q, poses_q)))) * 150
l2_x = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(p2_x, poses_x)))) * 0.3
l2_q = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(p2_q, poses_q)))) * 150
l3_x = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(p3_x, poses_x)))) * 1
l3_q = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(p3_q, poses_q)))) * 500

loss = l1_x + l1_q + l2_x + l2_q + l3_x + l3_q
opt = tf.train.AdamOptimizer(learning_rate=0.0001, beta1=0.9, beta2=0.999, epsilon=0.00000001, use_locking=False, name='Adam').minimize(loss)

# Set GPU options
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.6833)

init = tf.global_variables_initializer()
saver = tf.train.Saver()
outputFile = "PoseNet.ckpt"
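  • Running this cell a second time raises an error because the network's variables already exist in the default graph; call tf.reset_default_graph() before building it again. Note also that datasource = get_data() rebinds the name of the datasource class, so get_data() itself fails if it is re-run in the same session.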
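The six loss terms above come from PoseNet's three output heads: the two auxiliary heads cls1 and cls2, and the final head cls3. Each head contributes a weighted position error and a weighted quaternion error. tf.reduce_sum is called without an axis, so each term is the L2 norm of the residual over the whole batch, not a per-sample average. The large quaternion weights (150, 150, 500) balance orientation error against position error. Below is a NumPy sketch that mirrors the graph's computation, useful for checking loss values offline. posenet_loss is a hypothetical helper, and the weights are copied from the code above.

```python
import numpy as np

# (xyz weight, quaternion weight) for heads cls1, cls2, cls3, copied from the TF graph above
HEAD_WEIGHTS = [(0.3, 150.0), (0.3, 150.0), (1.0, 500.0)]

def posenet_loss(preds, true_x, true_q):
    """preds: list of three (pred_x, pred_q) pairs with shapes (B, 3) and (B, 4)."""
    total = 0.0
    for (pred_x, pred_q), (wx, wq) in zip(preds, HEAD_WEIGHTS):
        # Summing over the whole batch before the sqrt gives the Frobenius norm
        # of the batch residual, exactly like tf.sqrt(tf.reduce_sum(tf.square(...)))
        total += wx * np.sqrt(np.sum((pred_x - true_x) ** 2))
        total += wq * np.sqrt(np.sum((pred_q - true_q) ** 2))
    return total
```

For example, a position error of 5 in one sample (and no quaternion error) at every head gives 0.3*5 + 0.3*5 + 1*5 = 8.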

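8. Training

  • The code below runs in both CPU and GPU environments; it prints the loss every 20 iterations and saves the weights every 500 iterations.
  • The pretrained weights are loaded from '/kaggle/input/tensorflow-posenet-master/tensorflow-posenet-master/posenet.npy', converted from the official Caffe weights; download link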
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    # Initialise all variables, then overwrite them with the pretrained PoseNet weights
    sess.run(init)
    net.load('/kaggle/input/tensorflow-posenet-master/tensorflow-posenet-master/posenet.npy', sess)

    data_gen = gen_data_batch(datasource)
    for i in range(max_iterations):
        np_images, np_poses_x, np_poses_q = next(data_gen)
        feed = {images: np_images, poses_x: np_poses_x, poses_q: np_poses_q}

        # One call runs the update step and fetches the loss from the same forward pass
        _, np_loss = sess.run([opt, loss], feed_dict=feed)
        if i % 20 == 0:
            print("iteration: " + str(i) + "\n\t" + "Loss is: " + str(np_loss))
        if i % 500 == 0:
            saver.save(sess, outputFile)
            print("Intermediate file saved at: " + outputFile)
    saver.save(sess, outputFile)
    print("Final file saved at: " + outputFile)
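After training, predictions are commonly judged by two measures. Position error is the Euclidean distance between the predicted and true translations. Orientation error is the angle between the predicted and true quaternions. The PoseNet paper reports the median of each. Below is a minimal NumPy sketch of both metrics, assuming the (w, x, y, z) ordering used for the labels in this post. The function names are hypothetical. The network's quaternion output is not constrained to unit length, so it is normalised before comparison. The absolute value accounts for q and -q describing the same rotation.

```python
import numpy as np

def quaternion_angle_deg(q_pred, q_true):
    """Angle in degrees between two rotations given as (w, x, y, z) quaternions."""
    q_pred = np.asarray(q_pred, dtype=np.float64)
    q_true = np.asarray(q_true, dtype=np.float64)
    # The network output is not unit length, so normalise both before comparing
    q_pred = q_pred / np.linalg.norm(q_pred)
    q_true = q_true / np.linalg.norm(q_true)
    # q and -q encode the same rotation, hence the absolute value
    d = np.clip(abs(np.dot(q_pred, q_true)), 0.0, 1.0)
    return float(np.degrees(2.0 * np.arccos(d)))

def translation_error(x_pred, x_true):
    """Euclidean distance between predicted and true camera translations."""
    diff = np.asarray(x_pred, dtype=np.float64) - np.asarray(x_true, dtype=np.float64)
    return float(np.linalg.norm(diff))
```

Over a validation set, np.median of each list of errors gives the summary numbers.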
  • Results:
    [screenshot]

9. Full code
  • Download link for the full notebook
  • Download link for the TensorFlow weight file
  • Full code:
import pandas as pd
import numpy as np
import tensorflow.compat.v1 as tf  # TF1-style graph API (placeholders, Session) under TF 2.x
tf.disable_v2_behavior()
import random
import cv2
from tqdm import tqdm

change_train_labels = 1
def matrix2quaternion(m):
    #m:array
    w = ((np.trace(m) + 1) ** 0.5) / 2
    x = (m[2][1] - m[1][2]) / (4 * w)
    y = (m[0][2] - m[2][0]) / (4 * w)
    z = (m[1][0] - m[0][1]) / (4 * w)
    return w,x,y,z
def quaternion2matrix(q):
    #q:list
    w,x,y,z = q
    return np.array([[1-2*y*y-2*z*z, 2*x*y-2*z*w, 2*x*z+2*y*w],
             [2*x*y+2*z*w, 1-2*x*x-2*z*z, 2*y*z-2*x*w],
             [2*x*z-2*y*w, 2*y*z+2*x*w, 1-2*x*x-2*y*y]])

def m(a):
    a = a.split(';')
    a = [float(i) for i in a]
    A = np.array([[a[0],a[1],a[2]],
                [a[3],a[4],a[5]],
                [a[6],a[7],a[8]]])
    return matrix2quaternion(A)

if change_train_labels:
    train_labels = pd.read_csv('/kaggle/input/image-matching-challenge-2023/train/train_labels.csv')
    train_labels['rotation_matrix_quaternion'] = [i for i in map(m,train_labels['rotation_matrix'])]
    train_labels.to_csv('/kaggle/working/my_train_labels.csv')

DEFAULT_PADDING = 'SAME'


def layer(op):
    '''Decorator for composable network layers.'''

    def layer_decorated(self, *args, **kwargs):
        # Automatically set a name if not provided.
        name = kwargs.setdefault('name', self.get_unique_name(op.__name__))
        # Figure out the layer inputs.
        if len(self.terminals) == 0:
            raise RuntimeError('No input variables found for layer %s.' % name)
        elif len(self.terminals) == 1:
            layer_input = self.terminals[0]
        else:
            layer_input = list(self.terminals)
        # Perform the operation and get the output.
        layer_output = op(self, layer_input, *args, **kwargs)
        # Add to layer LUT.
        self.layers[name] = layer_output
        # This output is now the input for the next layer.
        self.feed(layer_output)
        # Return self for chained calls.
        return self

    return layer_decorated


class Network(object):

    def __init__(self, inputs, trainable=True):
        # The input nodes for this network
        self.inputs = inputs
        # The current list of terminal nodes
        self.terminals = []
        # Mapping from layer names to layers
        self.layers = dict(inputs)
        # If true, the resulting variables are set as trainable
        self.trainable = trainable
        # Switch variable for dropout
        self.use_dropout = tf.placeholder_with_default(tf.constant(1.0),
                                                       shape=[],
                                                       name='use_dropout')
        self.setup()

    def setup(self):
        '''Construct the network. '''
        raise NotImplementedError('Must be implemented by the subclass.')

    def load(self, data_path, session, ignore_missing=False):
        '''Load network weights.
        data_path: The path to the numpy-serialized network weights
        session: The current TensorFlow session
        ignore_missing: If true, serialized weights for missing layers are ignored.
        '''
        data_dict = np.load(data_path,allow_pickle=True,encoding="latin1").item()
        for op_name in data_dict:
            with tf.variable_scope(op_name, reuse=True):
                for param_name, data in data_dict[op_name].items():
                    try:
                        var = tf.get_variable(param_name)
                        session.run(var.assign(data))
                    except ValueError:
                        if not ignore_missing:
                            raise

    def feed(self, *args):
        '''Set the input(s) for the next operation by replacing the terminal nodes.
        The arguments can be either layer names or the actual layers.
        '''
        assert len(args) != 0
        self.terminals = []
        for fed_layer in args:
            if isinstance(fed_layer, str):
                try:
                    fed_layer = self.layers[fed_layer]
                except KeyError:
                    raise KeyError('Unknown layer name fed: %s' % fed_layer)
            self.terminals.append(fed_layer)
        return self

    def get_output(self):
        '''Returns the current network output.'''
        return self.terminals[-1]

    def get_unique_name(self, prefix):
        '''Returns an index-suffixed unique name for the given prefix.
        This is used for auto-generating layer names based on the type-prefix.
        '''
        ident = sum(t.startswith(prefix) for t, _ in self.layers.items()) + 1
        return '%s_%d' % (prefix, ident)

    def make_var(self, name, shape):
        '''Creates a new TensorFlow variable.'''
        return tf.get_variable(name, shape, trainable=self.trainable)

    def validate_padding(self, padding):
        '''Verifies that the padding is one of the supported ones.'''
        assert padding in ('SAME', 'VALID')

    @layer
    def conv(self,
             input,
             k_h,
             k_w,
             c_o,
             s_h,
             s_w,
             name,
             relu=True,
             padding=DEFAULT_PADDING,
             group=1,
             biased=True):
        # Verify that the padding is acceptable
        self.validate_padding(padding)
        # Get the number of channels in the input
        c_i = input.get_shape()[-1]
        # Verify that the grouping parameter is valid
        assert c_i % group == 0
        assert c_o % group == 0
        # Convolution for a given input and kernel
        convolve = lambda i, k: tf.nn.conv2d(i, k, [1, s_h, s_w, 1], padding=padding)
        with tf.variable_scope(name) as scope:
            kernel = self.make_var('weights', shape=[k_h, k_w, int(int(c_i) / group), c_o])
            if group == 1:
                # This is the common-case. Convolve the input without any further complications.
                output = convolve(input, kernel)
            else:
                # Split the input into groups and then convolve each of them independently
                input_groups = tf.split(input, group, axis=3)
                kernel_groups = tf.split(kernel, group, axis=3)
                output_groups = [convolve(i, k) for i, k in zip(input_groups, kernel_groups)]
                # Concatenate the groups
                output = tf.concat(output_groups, axis=3)
            # Add the biases
            if biased:
                biases = self.make_var('biases', [c_o])
                output = tf.nn.bias_add(output, biases)
            if relu:
                # ReLU non-linearity
                output = tf.nn.relu(output, name=scope.name)
            return output

    @layer
    def relu(self, input, name):
        return tf.nn.relu(input, name=name)

    @layer
    def max_pool(self, input, k_h, k_w, s_h, s_w, name, padding=DEFAULT_PADDING):
        self.validate_padding(padding)
        return tf.nn.max_pool(input,
                              ksize=[1, k_h, k_w, 1],
                              strides=[1, s_h, s_w, 1],
                              padding=padding,
                              name=name)

    @layer
    def avg_pool(self, input, k_h, k_w, s_h, s_w, name, padding=DEFAULT_PADDING):
        self.validate_padding(padding)
        return tf.nn.avg_pool(input,
                              ksize=[1, k_h, k_w, 1],
                              strides=[1, s_h, s_w, 1],
                              padding=padding,
                              name=name)

    @layer
    def lrn(self, input, radius, alpha, beta, name, bias=1.0):
        return tf.nn.local_response_normalization(input,
                                                  depth_radius=radius,
                                                  alpha=alpha,
                                                  beta=beta,
                                                  bias=bias,
                                                  name=name)

    @layer
    def concat(self, inputs, axis, name):
        return tf.concat(values=inputs, axis=axis, name=name)

    @layer
    def add(self, inputs, name):
        return tf.add_n(inputs, name=name)

    @layer
    def fc(self, input, num_out, name, relu=True):
        with tf.variable_scope(name) as scope:
            input_shape = input.get_shape()
            if input_shape.ndims == 4:
                # The input is spatial. Vectorize it first.
                dim = 1
                for d in input_shape[1:].as_list():
                    dim *= d
                feed_in = tf.reshape(input, [-1, dim])
            else:
                feed_in, dim = (input, input_shape[-1].value)
            weights = self.make_var('weights', shape=[dim, num_out])
            biases = self.make_var('biases', [num_out])
            op = tf.nn.relu_layer if relu else tf.nn.xw_plus_b
            fc = op(feed_in, weights, biases, name=scope.name)
            return fc

    @layer
    def softmax(self, input, name):
        # as_list() returns a real list (a Python 3 map object has no len())
        input_shape = input.get_shape().as_list()
        if len(input_shape) > 2:
            # For certain models (like NiN), the singleton spatial dimensions
            # need to be explicitly squeezed, since they're not broadcast-able
            # in TensorFlow's NHWC ordering (unlike Caffe's NCHW).
            if input_shape[1] == 1 and input_shape[2] == 1:
                input = tf.squeeze(input, axis=[1, 2])
            else:
                raise ValueError('Rank 2 tensor input expected for softmax!')
        # Pass name by keyword: the second positional argument of softmax is axis
        return tf.nn.softmax(input, name=name)

    @layer
    def batch_normalization(self, input, name, scale_offset=True, relu=False):
        # NOTE: Currently, only inference is supported
        with tf.variable_scope(name) as scope:
            shape = [input.get_shape()[-1]]
            if scale_offset:
                scale = self.make_var('scale', shape=shape)
                offset = self.make_var('offset', shape=shape)
            else:
                scale, offset = (None, None)
            output = tf.nn.batch_normalization(
                input,
                mean=self.make_var('mean', shape=shape),
                variance=self.make_var('variance', shape=shape),
                offset=offset,
                scale=scale,
                # TODO: This is the default Caffe batch norm eps
                # Get the actual eps from parameters
                variance_epsilon=1e-5,
                name=name)
            if relu:
                output = tf.nn.relu(output)
            return output

    @layer
    def dropout(self, input, keep_prob, name):
        keep = 1 - self.use_dropout + (self.use_dropout * keep_prob)
        return tf.nn.dropout(input, keep, name=name)

def centeredCrop(img, output_side_length):
    height, width, depth = img.shape
    new_height = output_side_length
    new_width = output_side_length
    if height > width:
        new_height = output_side_length * height / width
    else:
        new_width = output_side_length * width / height
    # Slice indices must be integers in Python 3, so cast the offsets
    height_offset = int((new_height - output_side_length) / 2)
    width_offset = int((new_width - output_side_length) / 2)
    cropped_img = img[height_offset:height_offset + output_side_length,
                      width_offset:width_offset + output_side_length]
    return cropped_img
def preprocess(images):
    images_out = [] #final result
    #Resize and crop and compute mean!
    images_cropped = []
    for i in tqdm(range(len(images))):
        #print(images[i])
        X = cv2.imread(images[i])
        #X = cv2.resize(X, (455, 256))
        X = centeredCrop(X, 224)
        images_cropped.append(X)
    #compute images mean
    N = 0
    mean = np.zeros((1, 3, 224, 224))
    for X in tqdm(images_cropped):
        X = np.transpose(X, (2, 0, 1))  # HWC -> CHW, shape (3, 224, 224)
        mean[0][0] += X[0,:,:]
        mean[0][1] += X[1,:,:]
        mean[0][2] += X[2,:,:]
        N += 1
    mean[0] /= N
    #Subtract mean from all images
    for X in tqdm(images_cropped):
        X = np.transpose(X,(2,0,1))
        X = X - mean
        X = np.squeeze(X)
        X = np.transpose(X, (1,2,0))
        images_out.append(X)
    return images_out

class GoogLeNet(Network):
    def setup(self):
        (self.feed('data')
             .conv(7, 7, 64, 2, 2, name='conv1')
             .max_pool(3, 3, 2, 2, name='pool1')
             .lrn(2, 2e-05, 0.75, name='norm1')
             .conv(1, 1, 64, 1, 1, name='reduction2')
             .conv(3, 3, 192, 1, 1, name='conv2')
             .lrn(2, 2e-05, 0.75, name='norm2')
             .max_pool(3, 3, 2, 2, name='pool2')
             .conv(1, 1, 96, 1, 1, name='icp1_reduction1')
             .conv(3, 3, 128, 1, 1, name='icp1_out1'))

        (self.feed('pool2')
             .conv(1, 1, 16, 1, 1, name='icp1_reduction2')
             .conv(5, 5, 32, 1, 1, name='icp1_out2'))

        (self.feed('pool2')
             .max_pool(3, 3, 1, 1, name='icp1_pool')
             .conv(1, 1, 32, 1, 1, name='icp1_out3'))

        (self.feed('pool2')
             .conv(1, 1, 64, 1, 1, name='icp1_out0'))

        (self.feed('icp1_out0', 
                   'icp1_out1', 
                   'icp1_out2', 
                   'icp1_out3')
             .concat(3, name='icp2_in')
             .conv(1, 1, 128, 1, 1, name='icp2_reduction1')
             .conv(3, 3, 192, 1, 1, name='icp2_out1'))

        (self.feed('icp2_in')
             .conv(1, 1, 32, 1, 1, name='icp2_reduction2')
             .conv(5, 5, 96, 1, 1, name='icp2_out2'))

        (self.feed('icp2_in')
             .max_pool(3, 3, 1, 1, name='icp2_pool')
             .conv(1, 1, 64, 1, 1, name='icp2_out3'))

        (self.feed('icp2_in')
             .conv(1, 1, 128, 1, 1, name='icp2_out0'))

        (self.feed('icp2_out0', 
                   'icp2_out1', 
                   'icp2_out2', 
                   'icp2_out3')
             .concat(3, name='icp2_out')
             .max_pool(3, 3, 2, 2, name='icp3_in')
             .conv(1, 1, 96, 1, 1, name='icp3_reduction1')
             .conv(3, 3, 208, 1, 1, name='icp3_out1'))

        (self.feed('icp3_in')
             .conv(1, 1, 16, 1, 1, name='icp3_reduction2')
             .conv(5, 5, 48, 1, 1, name='icp3_out2'))

        (self.feed('icp3_in')
             .max_pool(3, 3, 1, 1, name='icp3_pool')
             .conv(1, 1, 64, 1, 1, name='icp3_out3'))

        (self.feed('icp3_in')
             .conv(1, 1, 192, 1, 1, name='icp3_out0'))

        (self.feed('icp3_out0', 
                   'icp3_out1', 
                   'icp3_out2', 
                   'icp3_out3')
             .concat(3, name='icp3_out')
             .avg_pool(5, 5, 3, 3, padding='VALID', name='cls1_pool')
             .conv(1, 1, 128, 1, 1, name='cls1_reduction_pose')
             .fc(1024, name='cls1_fc1_pose')
             .fc(3, relu=False, name='cls1_fc_pose_xyz'))

        (self.feed('cls1_fc1_pose')
             .fc(4, relu=False, name='cls1_fc_pose_wpqr'))

        (self.feed('icp3_out')
             .conv(1, 1, 112, 1, 1, name='icp4_reduction1')
             .conv(3, 3, 224, 1, 1, name='icp4_out1'))

        (self.feed('icp3_out')
             .conv(1, 1, 24, 1, 1, name='icp4_reduction2')
             .conv(5, 5, 64, 1, 1, name='icp4_out2'))

        (self.feed('icp3_out')
             .max_pool(3, 3, 1, 1, name='icp4_pool')
             .conv(1, 1, 64, 1, 1, name='icp4_out3'))

        (self.feed('icp3_out')
             .conv(1, 1, 160, 1, 1, name='icp4_out0'))

        (self.feed('icp4_out0', 
                   'icp4_out1', 
                   'icp4_out2', 
                   'icp4_out3')
             .concat(3, name='icp4_out')
             .conv(1, 1, 128, 1, 1, name='icp5_reduction1')
             .conv(3, 3, 256, 1, 1, name='icp5_out1'))

        (self.feed('icp4_out')
             .conv(1, 1, 24, 1, 1, name='icp5_reduction2')
             .conv(5, 5, 64, 1, 1, name='icp5_out2'))

        (self.feed('icp4_out')
             .max_pool(3, 3, 1, 1, name='icp5_pool')
             .conv(1, 1, 64, 1, 1, name='icp5_out3'))

        (self.feed('icp4_out')
             .conv(1, 1, 128, 1, 1, name='icp5_out0'))

        (self.feed('icp5_out0', 
                   'icp5_out1', 
                   'icp5_out2', 
                   'icp5_out3')
             .concat(3, name='icp5_out')
             .conv(1, 1, 144, 1, 1, name='icp6_reduction1')
             .conv(3, 3, 288, 1, 1, name='icp6_out1'))

        (self.feed('icp5_out')
             .conv(1, 1, 32, 1, 1, name='icp6_reduction2')
             .conv(5, 5, 64, 1, 1, name='icp6_out2'))

        (self.feed('icp5_out')
             .max_pool(3, 3, 1, 1, name='icp6_pool')
             .conv(1, 1, 64, 1, 1, name='icp6_out3'))

        (self.feed('icp5_out')
             .conv(1, 1, 112, 1, 1, name='icp6_out0'))

        (self.feed('icp6_out0', 
                   'icp6_out1', 
                   'icp6_out2', 
                   'icp6_out3')
             .concat(3, name='icp6_out')
             .avg_pool(5, 5, 3, 3, padding='VALID', name='cls2_pool')
             .conv(1, 1, 128, 1, 1, name='cls2_reduction_pose')
             .fc(1024, name='cls2_fc1')
             .fc(3, relu=False, name='cls2_fc_pose_xyz'))

        (self.feed('cls2_fc1')
             .fc(4, relu=False, name='cls2_fc_pose_wpqr'))

        (self.feed('icp6_out')
             .conv(1, 1, 160, 1, 1, name='icp7_reduction1')
             .conv(3, 3, 320, 1, 1, name='icp7_out1'))

        (self.feed('icp6_out')
             .conv(1, 1, 32, 1, 1, name='icp7_reduction2')
             .conv(5, 5, 128, 1, 1, name='icp7_out2'))

        (self.feed('icp6_out')
             .max_pool(3, 3, 1, 1, name='icp7_pool')
             .conv(1, 1, 128, 1, 1, name='icp7_out3'))

        (self.feed('icp6_out')
             .conv(1, 1, 256, 1, 1, name='icp7_out0'))

        (self.feed('icp7_out0', 
                   'icp7_out1', 
                   'icp7_out2', 
                   'icp7_out3')
             .concat(3, name='icp7_out')
             .max_pool(3, 3, 2, 2, name='icp8_in')
             .conv(1, 1, 160, 1, 1, name='icp8_reduction1')
             .conv(3, 3, 320, 1, 1, name='icp8_out1'))

        (self.feed('icp8_in')
             .conv(1, 1, 32, 1, 1, name='icp8_reduction2')
             .conv(5, 5, 128, 1, 1, name='icp8_out2'))

        (self.feed('icp8_in')
             .max_pool(3, 3, 1, 1, name='icp8_pool')
             .conv(1, 1, 128, 1, 1, name='icp8_out3'))

        (self.feed('icp8_in')
             .conv(1, 1, 256, 1, 1, name='icp8_out0'))

        (self.feed('icp8_out0', 
                   'icp8_out1', 
                   'icp8_out2', 
                   'icp8_out3')
             .concat(3, name='icp8_out')
             .conv(1, 1, 192, 1, 1, name='icp9_reduction1')
             .conv(3, 3, 384, 1, 1, name='icp9_out1'))

        (self.feed('icp8_out')
             .conv(1, 1, 48, 1, 1, name='icp9_reduction2')
             .conv(5, 5, 128, 1, 1, name='icp9_out2'))

        (self.feed('icp8_out')
             .max_pool(3, 3, 1, 1, name='icp9_pool')
             .conv(1, 1, 128, 1, 1, name='icp9_out3'))

        (self.feed('icp8_out')
             .conv(1, 1, 384, 1, 1, name='icp9_out0'))

        (self.feed('icp9_out0', 
                   'icp9_out1', 
                   'icp9_out2', 
                   'icp9_out3')
             .concat(3, name='icp9_out')
             .avg_pool(7, 7, 1, 1, padding='VALID', name='cls3_pool')
             .fc(2048, name='cls3_fc1_pose')
             .fc(3, relu=False, name='cls3_fc_pose_xyz'))

        (self.feed('cls3_fc1_pose')
             .fc(4, relu=False, name='cls3_fc_pose_wpqr'))

batch_size = 75
max_iterations = 30000
# Path of the label CSV (with quaternions) written above
my_train_labels = '/kaggle/working/my_train_labels.csv'

class datasource(object):
    def __init__(self, images, poses):
        self.images = images
        self.poses = poses

def centeredCrop(img, output_side_length):
    height, width, depth = img.shape
    new_height = output_side_length
    new_width = output_side_length
    if height > width:
        new_height = output_side_length * height / width
    else:
        new_width = output_side_length * width / height
    height_offset = (new_height - output_side_length) / 2
    width_offset = (new_width - output_side_length) / 2
    cropped_img = img[int(height_offset):int(height_offset + output_side_length),
                        int(width_offset):int(width_offset + output_side_length)]
    return cropped_img

def gen_data(source):
    while True:
        indices = list(range(len(source.images)))
        random.shuffle(indices)
        for i in indices:
            image = source.images[i]
            pose_x = source.poses[i][0:3]
            pose_q = source.poses[i][3:7]
            yield image, pose_x, pose_q

def gen_data_batch(source):
    data_gen = gen_data(source)
    while True:
        image_batch = []
        pose_x_batch = []
        pose_q_batch = []
        for _ in range(batch_size):
            image, pose_x, pose_q = next(data_gen)
            image_batch.append(image)
            pose_x_batch.append(pose_x)
            pose_q_batch.append(pose_q)
        yield np.array(image_batch), np.array(pose_x_batch), np.array(pose_q_batch)

def get_data():
    poses = []
    images = []
    for i in pd.read_csv(my_train_labels).itertuples():
        # i[4]: image_path
        # i[6]: translation_vector, "x;y;z"
        # i[7]: rotation_matrix_quaternion, "(w, x, y, z)"
        p0,p1,p2 = i[6].split(';')
        p3,p4,p5,p6 = i[7].split('(')[1].split(')')[0].split(',')
        p0 = float(p0)
        p1 = float(p1)
        p2 = float(p2)
        p3 = float(p3)
        p4 = float(p4)
        p5 = float(p5)
        p6 = float(p6)
        poses.append((p0,p1,p2,p3,p4,p5,p6))
        images.append('/kaggle/input/image-matching-challenge-2023/train/' + i[4])
        #print(poses,images)
    images = preprocess(images)
    return datasource(images, poses)


images = tf.placeholder(tf.float32, [batch_size, 224, 224, 3])
poses_x = tf.placeholder(tf.float32, [batch_size, 3])
poses_q = tf.placeholder(tf.float32, [batch_size, 4])
datasource = get_data()

net = GoogLeNet({'data': images})

p1_x = net.layers['cls1_fc_pose_xyz']
p1_q = net.layers['cls1_fc_pose_wpqr']
p2_x = net.layers['cls2_fc_pose_xyz']
p2_q = net.layers['cls2_fc_pose_wpqr']
p3_x = net.layers['cls3_fc_pose_xyz']
p3_q = net.layers['cls3_fc_pose_wpqr']

l1_x = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(p1_x, poses_x)))) * 0.3
l1_q = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(p1_q, poses_q)))) * 150
l2_x = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(p2_x, poses_x)))) * 0.3
l2_q = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(p2_q, poses_q)))) * 150
l3_x = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(p3_x, poses_x)))) * 1
l3_q = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(p3_q, poses_q)))) * 500

loss = l1_x + l1_q + l2_x + l2_q + l3_x + l3_q
opt = tf.train.AdamOptimizer(learning_rate=0.0001, beta1=0.9, beta2=0.999, epsilon=0.00000001, use_locking=False, name='Adam').minimize(loss)

# Set GPU options
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.6833)

init = tf.global_variables_initializer()
saver = tf.train.Saver()
outputFile = "PoseNet.ckpt"


with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    # Initialise all variables, then overwrite them with the pretrained PoseNet weights
    sess.run(init)
    net.load('/kaggle/input/tensorflow-posenet-master/tensorflow-posenet-master/posenet.npy', sess)

    data_gen = gen_data_batch(datasource)
    for i in range(max_iterations):
        np_images, np_poses_x, np_poses_q = next(data_gen)
        feed = {images: np_images, poses_x: np_poses_x, poses_q: np_poses_q}

        # One call runs the update step and fetches the loss from the same forward pass
        _, np_loss = sess.run([opt, loss], feed_dict=feed)
        if i % 20 == 0:
            print("iteration: " + str(i) + "\n\t" + "Loss is: " + str(np_loss))
        if i % 500 == 0:
            saver.save(sess, outputFile)
            print("Intermediate file saved at: " + outputFile)
    saver.save(sess, outputFile)
    print("Final file saved at: " + outputFile)
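matrix2quaternion, used above to build the labels, divides by 4 * w with w = sqrt(trace + 1) / 2. It therefore becomes numerically unstable, and returns inf/nan, as the trace approaches -1, i.e. for rotations of roughly 180 degrees. A standard remedy is Shepperd's method: branch on the largest of the trace and the diagonal terms. The sketch below implements that standard technique. It is not the author's code, and matrix2quaternion_robust is a hypothetical name. It keeps the (w, x, y, z) order and returns the quaternion with w >= 0, matching the original's sign convention. quaternion2matrix is copied from the post so the block runs on its own.

```python
import numpy as np

def quaternion2matrix(q):
    # Copied from the post: (w, x, y, z) -> 3x3 rotation matrix
    w, x, y, z = q
    return np.array([[1-2*y*y-2*z*z, 2*x*y-2*z*w, 2*x*z+2*y*w],
                     [2*x*y+2*z*w, 1-2*x*x-2*z*z, 2*y*z-2*x*w],
                     [2*x*z-2*y*w, 2*y*z+2*x*w, 1-2*x*x-2*y*y]])

def matrix2quaternion_robust(m):
    """Rotation matrix -> (w, x, y, z), stable for all rotations (Shepperd's method)."""
    m = np.asarray(m, dtype=np.float64)
    tr = np.trace(m)
    if tr > 0:
        s = 2.0 * np.sqrt(tr + 1.0)  # s = 4w
        w = 0.25 * s
        x = (m[2, 1] - m[1, 2]) / s
        y = (m[0, 2] - m[2, 0]) / s
        z = (m[1, 0] - m[0, 1]) / s
    elif m[0, 0] > m[1, 1] and m[0, 0] > m[2, 2]:
        s = 2.0 * np.sqrt(1.0 + m[0, 0] - m[1, 1] - m[2, 2])  # s = 4x
        w = (m[2, 1] - m[1, 2]) / s
        x = 0.25 * s
        y = (m[0, 1] + m[1, 0]) / s
        z = (m[0, 2] + m[2, 0]) / s
    elif m[1, 1] > m[2, 2]:
        s = 2.0 * np.sqrt(1.0 + m[1, 1] - m[0, 0] - m[2, 2])  # s = 4y
        w = (m[0, 2] - m[2, 0]) / s
        x = (m[0, 1] + m[1, 0]) / s
        y = 0.25 * s
        z = (m[1, 2] + m[2, 1]) / s
    else:
        s = 2.0 * np.sqrt(1.0 + m[2, 2] - m[0, 0] - m[1, 1])  # s = 4z
        w = (m[1, 0] - m[0, 1]) / s
        x = (m[0, 2] + m[2, 0]) / s
        y = (m[1, 2] + m[2, 1]) / s
        z = 0.25 * s
    q = np.array([w, x, y, z])
    # q and -q are the same rotation; keep w >= 0 like the original function
    return -q if q[0] < 0 else q
```

For a 180-degree rotation about x, diag(1, -1, -1), the original function divides by zero, while this version returns (0, 1, 0, 0).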

10. Environment versions

  • !pip list
Package                                Version              Editable project location
-------------------------------------- -------------------- -------------------------
absl-py                                1.4.0
accelerate                             0.12.0
access                                 1.1.9
affine                                 2.4.0
aiobotocore                            2.5.0
aiofiles                               22.1.0
aiohttp                                3.8.4
aiohttp-cors                           0.7.0
aioitertools                           0.11.0
aiorwlock                              1.3.0
aiosignal                              1.3.1
aiosqlite                              0.19.0
albumentations                         1.3.0
alembic                                1.11.1
altair                                 5.0.0
annoy                                  1.17.2
ansiwrap                               0.8.4
anyio                                  3.6.2
apache-beam                            2.46.0
aplus                                  0.11.0
appdirs                                1.4.4
argon2-cffi                            21.3.0
argon2-cffi-bindings                   21.2.0
array-record                           0.2.0
arrow                                  1.2.3
arviz                                  0.12.1
astroid                                2.15.5
astropy                                5.3
asttokens                              2.2.1
astunparse                             1.6.3
async-timeout                          4.0.2
atpublic                               3.1.1
attrs                                  23.1.0
audioread                              3.0.0
autopep8                               2.0.2
Babel                                  2.12.1
backcall                               0.2.0
backoff                                2.2.1
backports.functools-lru-cache          1.6.4
bayesian-optimization                  1.4.3
bayespy                                0.5.25
beatrix-jupyterlab                     2023.58.190319
beautifulsoup4                         4.12.2
bidict                                 0.22.1
biopython                              1.81
blake3                                 0.2.1
bleach                                 6.0.0
blessed                                1.20.0
blinker                                1.6.2
blis                                   0.7.9
blosc2                                 2.0.0
bokeh                                  2.4.3
boltons                                23.0.0
Boruta                                 0.3
boto3                                  1.26.100
botocore                               1.29.76
bq-helper                              0.4.1                /src/bq-helper
bqplot                                 0.12.39
branca                                 0.6.0
brewer2mpl                             1.4.1
brotlipy                               0.7.0
cached-property                        1.5.2
cachetools                             4.2.4
Cartopy                                0.21.1
catalogue                              2.0.8
catalyst                               22.4
catboost                               1.2
category-encoders                      2.6.1
certifi                                2023.5.7
cesium                                 0.12.1
cffi                                   1.15.1
cftime                                 1.6.2
charset-normalizer                     2.1.1
chex                                   0.1.7
cleverhans                             4.0.0
click                                  8.1.3
click-plugins                          1.1.1
cligj                                  0.7.2
cloud-tpu-client                       0.10
cloud-tpu-profiler                     2.4.0
cloudpickle                            2.2.1
cmaes                                  0.9.1
cmdstanpy                              1.1.0
cmudict                                1.0.13
colorama                               0.4.6
colorcet                               3.0.1
colorful                               0.5.5
colorlog                               6.7.0
colorlover                             0.3.0
comm                                   0.1.3
commonmark                             0.9.1
conda                                  23.3.1
conda-content-trust                    0+unknown
conda-package-handling                 2.0.2
conda_package_streaming                0.7.0
confection                             0.0.4
contextily                             1.3.0
contourpy                              1.0.7
convertdate                            2.4.0
crcmod                                 1.7
cryptography                           40.0.2
cubinlinker                            0.2.2
cuda-python                            11.8.1
cudf                                   23.4.1
cufflinks                              0.17.3
cuml                                   23.4.1
cupy                                   11.6.0
CVXcanon                               0.1.2
cycler                                 0.11.0
cymem                                  2.0.7
cysignals                              1.11.2
Cython                                 0.29.34
cytoolz                                0.12.0
daal                                   2023.1.1
daal4py                                2023.1.1
dask                                   2023.5.0
dask-cuda                              23.4.0
dask-cudf                              23.4.1
dataclasses                            0.8
dataclasses-json                       0.5.7
datasets                               2.1.0
datashader                             0.14.4
datashape                              0.5.2
datatile                               1.0.3
db-dtypes                              1.1.1
deap                                   1.3.3
debugpy                                1.6.7
decorator                              5.1.1
defusedxml                             0.7.1
Delorean                               1.0.0
deprecat                               2.1.1
Deprecated                             1.2.13
deprecation                            2.1.0
descartes                              1.1.0
dill                                   0.3.6
dipy                                   1.7.0
distlib                                0.3.6
distributed                            2023.3.2.1
dm-tree                                0.1.8
docker                                 6.1.1
docker-pycreds                         0.4.0
docopt                                 0.6.2
docstring-parser                       0.15
docstring-to-markdown                  0.12
docutils                               0.20.1
earthengine-api                        0.1.354
easydict                               1.10
easyocr                                1.6.2
ecos                                   2.0.12
eli5                                   0.13.0
emoji                                  2.2.0
en-core-web-lg                         3.5.0
en-core-web-sm                         3.5.0
entrypoints                            0.4
ephem                                  4.1.4
esda                                   2.4.3
essentia                               2.1b6.dev1034
et-xmlfile                             1.1.0
etils                                  1.2.0
executing                              1.2.0
explainable-ai-sdk                     1.3.3
fastai                                 2.7.12
fastapi                                0.95.1
fastavro                               1.7.4
fastcore                               1.5.29
fastdownload                           0.0.7
fasteners                              0.18
fastjsonschema                         2.16.3
fastprogress                           1.0.3
fastrlock                              0.8
fasttext                               0.9.2
fbpca                                  1.0
feather-format                         0.4.1
featuretools                           1.26.0
filelock                               3.12.0
Fiona                                  1.8.22
fire                                   0.5.0
fitter                                 1.5.2
flake8                                 6.0.0
flashtext                              2.7
Flask                                  2.3.2
flatbuffers                            23.3.3
flax                                   0.6.10
flit_core                              3.8.0
folium                                 0.14.0
fonttools                              4.39.3
fqdn                                   1.5.1
frozendict                             2.3.8
frozenlist                             1.3.3
fsspec                                 2023.5.0
funcy                                  2.0
fury                                   0.9.0
future                                 0.18.3
fuzzywuzzy                             0.18.0
gast                                   0.4.0
gatspy                                 0.3
gcsfs                                  2023.5.0
gensim                                 4.3.1
geographiclib                          2.0
Geohash                                1.0
geojson                                3.0.1
geopandas                              0.13.0
geoplot                                0.5.1
geopy                                  2.3.0
geoviews                               1.9.6
ggplot                                 0.11.5
giddy                                  2.3.4
gitdb                                  4.0.10
GitPython                              3.1.31
google-api-core                        1.33.2
google-api-python-client               2.86.0
google-apitools                        0.5.31
google-auth                            2.17.3
google-auth-httplib2                   0.1.0
google-auth-oauthlib                   1.0.0
google-cloud-aiplatform                0.6.0a1
google-cloud-artifact-registry         1.8.1
google-cloud-automl                    1.0.1
google-cloud-bigquery                  2.34.4
google-cloud-bigtable                  1.7.3
google-cloud-core                      2.3.2
google-cloud-datastore                 2.15.2
google-cloud-dlp                       3.12.1
google-cloud-language                  2.6.1
google-cloud-monitoring                2.14.2
google-cloud-pubsub                    2.16.1
google-cloud-pubsublite                1.8.1
google-cloud-recommendations-ai        0.7.1
google-cloud-resource-manager          1.10.0
google-cloud-spanner                   3.33.0
google-cloud-storage                   1.44.0
google-cloud-translate                 3.8.4
google-cloud-videointelligence         2.8.3
google-cloud-vision                    2.8.0
google-crc32c                          1.5.0
google-pasta                           0.2.0
google-resumable-media                 2.5.0
googleapis-common-protos               1.57.1
gplearn                                0.4.2
gpustat                                1.0.0
gpxpy                                  1.5.0
graphviz                               0.20.1
greenlet                               2.0.2
grpc-google-iam-v1                     0.12.6
grpcio                                 1.51.1
grpcio-status                          1.48.1
gviz-api                               1.10.0
gym                                    0.26.2
gym-notices                            0.0.8
Gymnasium                              0.26.3
gymnasium-notices                      0.0.1
h11                                    0.14.0
h2o                                    3.40.0.4
h5py                                   3.8.0
haversine                              2.8.0
hdfs                                   2.7.0
hep-ml                                 0.7.2
hijri-converter                        2.3.1
hmmlearn                               0.3.0
holidays                               0.24
holoviews                              1.16.0
hpsklearn                              0.1.0
html5lib                               1.1
htmlmin                                0.1.12
httplib2                               0.21.0
httptools                              0.5.0
huggingface-hub                        0.14.1
humanize                               4.6.0
hunspell                               0.5.5
husl                                   4.0.3
hydra-slayer                           0.4.1
hyperopt                               0.2.7
hypertools                             0.8.0
ibis-framework                         5.1.0
idna                                   3.4
igraph                                 0.10.4
imagecodecs                            2023.3.16
ImageHash                              4.3.1
imageio                                2.28.1
imbalanced-learn                       0.10.1
imgaug                                 0.4.0
implicit                               0.5.2
importlib-metadata                     5.2.0
importlib-resources                    5.12.0
inequality                             1.0.0
ipydatawidgets                         4.3.3
ipykernel                              6.23.0
ipyleaflet                             0.17.2
ipympl                                 0.7.0
ipython                                8.13.2
ipython-genutils                       0.2.0
ipython-sql                            0.5.0
ipyvolume                              0.6.1
ipyvue                                 1.9.0
ipyvuetify                             1.8.10
ipywebrtc                              0.6.0
ipywidgets                             7.7.1
isoduration                            20.11.0
isort                                  5.12.0
isoweek                                1.3.3
itsdangerous                           2.1.2
Janome                                 0.4.2
jaraco.classes                         3.2.3
jax                                    0.4.10
jaxlib                                 0.4.7+cuda11.cudnn86
jedi                                   0.18.2
jeepney                                0.8.0
jieba                                  0.42.1
Jinja2                                 3.1.2
jmespath                               1.0.1
joblib                                 1.2.0
json5                                  0.9.11
jsonpatch                              1.32
jsonpointer                            2.0
jsonschema                             4.17.3
jupyter_client                         7.4.9
jupyter-console                        6.6.3
jupyter_core                           5.3.0
jupyter-events                         0.6.3
jupyter-http-over-ws                   0.0.8
jupyter-lsp                            1.5.1
jupyter_server                         2.5.0
jupyter_server_fileid                  0.9.0
jupyter-server-mathjax                 0.2.6
jupyter_server_proxy                   4.0.0
jupyter_server_terminals               0.4.4
jupyter_server_ydoc                    0.8.0
jupyter-ydoc                           0.2.4
jupyterlab                             3.6.3
jupyterlab-git                         0.41.0
jupyterlab-lsp                         4.1.0
jupyterlab-pygments                    0.2.2
jupyterlab_server                      2.22.1
jupyterlab-widgets                     3.0.7
jupytext                               1.14.5
kaggle                                 1.5.13
kaggle-environments                    1.12.0
keras                                  2.12.0
keras-tuner                            1.3.5
keyring                                23.13.1
keyrings.google-artifactregistry-auth  1.1.2
kfp                                    1.8.21
kfp-pipeline-spec                      0.1.16
kfp-server-api                         1.8.5
kiwisolver                             1.4.4
kmapper                                2.0.1
kmodes                                 0.12.2
korean-lunar-calendar                  0.3.1
kornia                                 0.6.12
kt-legacy                              1.0.5
kubernetes                             25.3.0
langcodes                              3.3.0
langid                                 1.1.6
lazy_loader                            0.2
lazy-object-proxy                      1.9.0
learntools                             0.3.4
leven                                  1.0.4
Levenshtein                            0.21.0
libclang                               16.0.0
libmambapy                             1.4.2
libpysal                               4.7.0
librosa                                0.10.0.post2
lightgbm                               3.3.2
lightning-utilities                    0.8.0
lime                                   0.2.0.1
line-profiler                          4.0.3
llvmlite                               0.39.1
lml                                    0.1.0
locket                                 1.0.0
LunarCalendar                          0.0.9
lxml                                   4.9.2
lz4                                    4.3.2
Mako                                   1.2.4
mamba                                  1.4.2
mapclassify                            2.5.0
marisa-trie                            0.8.0
Markdown                               3.4.3
markdown-it-py                         2.2.0
markovify                              0.9.4
MarkupSafe                             2.1.2
marshmallow                            3.19.0
marshmallow-enum                       1.5.1
matplotlib                             3.6.3
matplotlib-inline                      0.1.6
matplotlib-venn                        0.11.9
mccabe                                 0.7.0
mdit-py-plugins                        0.3.5
mdurl                                  0.1.2
memory-profiler                        0.61.0
mercantile                             1.2.1
mgwr                                   2.1.2
missingno                              0.5.2
mistune                                0.8.4
mizani                                 0.9.1
ml-dtypes                              0.1.0
mlcrate                                0.2.0
mlens                                  0.2.3
mlxtend                                0.22.0
mmh3                                   4.0.0
mne                                    1.4.0
mnist                                  0.2.2
mock                                   5.0.2
momepy                                 0.6.0
more-itertools                         9.1.0
mpld3                                  0.5.9
mpmath                                 1.3.0
msgpack                                1.0.5
msgpack-numpy                          0.4.8
multidict                              6.0.4
multimethod                            1.9.1
multipledispatch                       0.6.0
multiprocess                           0.70.14
munch                                  3.0.0
munkres                                1.1.4
murmurhash                             1.0.9
mypy-extensions                        1.0.0
nb-conda                               2.2.1
nb-conda-kernels                       2.3.1
nbclassic                              1.0.0
nbclient                               0.5.13
nbconvert                              6.4.5
nbdime                                 3.2.0
nbformat                               5.8.0
nest-asyncio                           1.5.6
netCDF4                                1.6.3
networkx                               3.1
nibabel                                5.1.0
nilearn                                0.10.1
ninja                                  1.11.1
nltk                                   3.2.4
nose                                   1.3.7
notebook                               6.5.4
notebook-executor                      0.2
notebook_shim                          0.2.3
numba                                  0.56.4
numexpr                                2.8.4
numpy                                  1.23.5
nvidia-ml-py                           11.495.46
nvtx                                   0.2.5
oauth2client                           4.1.3
oauthlib                               3.2.2
objsize                                0.6.1
odfpy                                  1.4.1
olefile                                0.46
onnx                                   1.14.0
opencensus                             0.11.2
opencensus-context                     0.1.3
opencv-contrib-python                  4.5.4.60
opencv-python                          4.5.4.60
opencv-python-headless                 4.5.4.60
openpyxl                               3.1.2
openslide-python                       1.2.0
opentelemetry-api                      1.17.0
opentelemetry-exporter-otlp            1.17.0
opentelemetry-exporter-otlp-proto-grpc 1.17.0
opentelemetry-exporter-otlp-proto-http 1.17.0
opentelemetry-proto                    1.17.0
opentelemetry-sdk                      1.17.0
opentelemetry-semantic-conventions     0.38b0
opt-einsum                             3.3.0
optax                                  0.1.5
optuna                                 3.1.1
orbax-checkpoint                       0.2.3
orderedmultidict                       1.0.1
orjson                                 3.8.12
ortools                                9.4.1874
osmnx                                  1.1.1
overrides                              6.5.0
packaging                              21.3
pandas                                 1.5.3
pandas-datareader                      0.10.0
pandas-profiling                       3.6.6
pandas-summary                         0.2.0
pandasql                               0.7.3
pandocfilters                          1.5.0
panel                                  0.14.4
papermill                              2.4.0
param                                  1.13.0
parso                                  0.8.3
parsy                                  2.1
partd                                  1.4.0
path                                   16.6.0
path.py                                12.5.0
pathos                                 0.3.0
pathtools                              0.1.2
pathy                                  0.10.1
patsy                                  0.5.3
pdf2image                              1.16.3
pexpect                                4.8.0
phik                                   0.12.3
pickleshare                            0.7.5
Pillow                                 9.5.0
pip                                    23.1.2
pkgutil_resolve_name                   1.3.10
platformdirs                           3.5.0
plotly                                 5.14.1
plotly-express                         0.4.1
plotnine                               0.10.1
pluggy                                 1.0.0
pointpats                              2.3.0
polars                                 0.17.15
polyglot                               16.7.4
pooch                                  1.6.0
pox                                    0.3.2
ppca                                   0.0.4
ppft                                   1.7.6.6
preprocessing                          0.1.13
preshed                                3.0.8
prettytable                            3.7.0
progressbar2                           4.2.0
prometheus-client                      0.16.0
promise                                2.3
prompt-toolkit                         3.0.38
pronouncing                            0.2.0
prophet                                1.1.1
proto-plus                             1.22.2
protobuf                               3.20.3
psutil                                 5.9.3
ptxcompiler                            0.8.1
ptyprocess                             0.7.0
pudb                                   2022.1.3
PuLP                                   2.7.0
pure-eval                              0.2.2
py-cpuinfo                             9.0.0
py-lz4framed                           0.14.0
py-spy                                 0.3.14
py4j                                   0.10.9.7
pyaml                                  23.5.9
PyArabic                               0.6.15
pyarrow                                10.0.1
pyasn1                                 0.4.8
pyasn1-modules                         0.2.7
PyAstronomy                            0.19.0
pybind11                               2.10.4
pyclipper                              1.3.0.post4
pycodestyle                            2.10.0
pycolmap                               0.4.0
pycosat                                0.6.4
pycparser                              2.21
pycryptodome                           3.18.0
pyct                                   0.5.0
pycuda                                 2022.2.2
pydantic                               1.10.7
pydegensac                             0.1.2
pydicom                                2.3.1
pydocstyle                             6.3.0
pydot                                  1.4.2
pydub                                  0.25.1
pyemd                                  1.0.0
pyerfa                                 2.0.0.3
pyexcel-io                             0.6.6
pyexcel-ods                            0.6.0
pyfasttext                             0.4.6
pyflakes                               3.0.1
pygltflib                              1.15.6
Pygments                               2.15.1
PyJWT                                  2.6.0
pykalman                               0.9.5
pyLDAvis                               3.2.2
pylibraft                              23.4.1
pylint                                 2.17.4
pymc3                                  3.11.5
PyMeeus                                0.5.12
pymongo                                3.13.0
Pympler                                1.0.1
pynndescent                            0.5.10
pynvml                                 11.4.1
pynvrtc                                9.2
pyocr                                  0.8.3
pyOpenSSL                              23.1.1
pyparsing                              3.0.9
pypdf                                  3.9.0
pyproj                                 3.5.0
pyrsistent                             0.19.3
pysal                                  23.1
pyshp                                  2.3.1
PySocks                                1.7.1
pytesseract                            0.3.10
python-bidi                            0.4.2
python-dateutil                        2.8.2
python-dotenv                          1.0.0
python-igraph                          0.10.4
python-json-logger                     2.0.7
python-Levenshtein                     0.21.0
python-louvain                         0.16
python-lsp-jsonrpc                     1.0.0
python-lsp-server                      1.7.3
python-slugify                         8.0.1
python-utils                           3.5.2
pythreejs                              2.4.2
pytoolconfig                           1.2.5
pytools                                2022.1.14
pytorch-ignite                         0.4.12
pytorch-lightning                      2.0.2
pytz                                   2023.3
pyu2f                                  0.1.5
PyUpSet                                0.1.1.post7
pyviz-comms                            2.2.1
PyWavelets                             1.4.1
PyYAML                                 5.4.1
pyzmq                                  25.0.2
qgrid                                  1.3.1
qtconsole                              5.4.3
QtPy                                   2.3.1
quantecon                              0.7.0
quantities                             0.14.1
qudida                                 0.0.4
raft-dask                              23.4.1
randomgen                              1.23.1
rapidfuzz                              3.0.0
rasterio                               1.3.7
rasterstats                            0.18.0
ray                                    2.4.0
ray-cpp                                2.4.0
regex                                  2023.5.5
requests                               2.28.2
requests-oauthlib                      1.3.1
requests-toolbelt                      0.10.1
responses                              0.18.0
retrying                               1.3.3
rfc3339-validator                      0.1.4
rfc3986-validator                      0.1.1
rgf-python                             3.12.0
rich                                   12.6.0
rmm                                    23.4.1
rope                                   1.8.0
rsa                                    4.9
Rtree                                  1.0.1
ruamel.yaml                            0.17.24
ruamel.yaml.clib                       0.2.7
ruamel-yaml-conda                      0.15.100
s2sphere                               0.2.5
s3fs                                   2023.5.0
s3transfer                             0.6.1
safetensors                            0.3.1
scattertext                            0.1.19
scikit-image                           0.20.0
scikit-learn                           1.2.2
scikit-learn-intelex                   2023.1.1
scikit-multilearn                      0.2.0
scikit-optimize                        0.9.0
scikit-plot                            0.3.7
scikit-surprise                        1.1.3
scipy                                  1.10.1
seaborn                                0.12.2
SecretStorage                          3.3.3
segment-anything                       1.0
segregation                            2.4.2
semver                                 3.0.0
Send2Trash                             1.8.2
sentencepiece                          0.1.99
sentry-sdk                             1.24.0
setproctitle                           1.3.2
setuptools                             59.8.0
setuptools-git                         1.2
setuptools-scm                         7.1.0
shap                                   0.41.0
Shapely                                1.8.5.post1
shellingham                            1.5.1
simpervisor                            0.4
SimpleITK                              2.2.1
simplejson                             3.19.1
six                                    1.16.0
sklearn-pandas                         2.2.0
slicer                                 0.0.7
smart-open                             6.3.0
smhasher                               0.150.1
smmap                                  5.0.0
sniffio                                1.3.0
snowballstemmer                        2.2.0
snuggs                                 1.4.7
sortedcontainers                       2.4.0
soundfile                              0.12.1
soupsieve                              2.3.2.post1
soxr                                   0.3.5
spacy                                  3.5.3
spacy-legacy                           3.0.12
spacy-loggers                          1.0.4
spaghetti                              1.7.2
spectral                               0.23.1
spglm                                  1.0.8
sphinx-rtd-theme                       0.2.4
spint                                  1.0.7
splot                                  1.1.5.post1
spopt                                  0.5.0
spreg                                  1.3.2
spvcm                                  0.3.0
SQLAlchemy                             2.0.12
sqlglot                                11.7.1
sqlparse                               0.4.4
squarify                               0.4.3
srsly                                  2.4.6
stack-data                             0.6.2
starlette                              0.26.1
statsmodels                            0.13.5
stemming                               1.0.1
stop-words                             2018.7.23
stopit                                 1.1.2
strip-hints                            0.1.10
stumpy                                 1.11.1
sympy                                  1.12
tables                                 3.8.0
tabulate                               0.9.0
tangled-up-in-unicode                  0.2.0
tbb                                    2021.9.0
tblib                                  1.7.0
tenacity                               8.2.2
tensorboard                            2.12.3
tensorboard-data-server                0.7.0
tensorboard-plugin-profile             2.11.2
tensorboardX                           2.6
tensorflow                             2.12.0
tensorflow-addons                      0.20.0
tensorflow-cloud                       0.1.16
tensorflow-datasets                    4.9.2
tensorflow-decision-forests            1.3.0
tensorflow-estimator                   2.12.0
tensorflow-gcs-config                  2.12.0
tensorflow-hub                         0.12.0
tensorflow-io                          0.31.0
tensorflow-io-gcs-filesystem           0.31.0
tensorflow-metadata                    0.14.0
tensorflow-probability                 0.20.0
tensorflow-serving-api                 2.12.1
tensorflow-text                        2.12.1
tensorflow-transform                   0.14.0
tensorflowjs                           3.15.0
tensorpack                             0.11
tensorstore                            0.1.36
termcolor                              2.3.0
terminado                              0.17.1
testpath                               0.6.0
text-unidecode                         1.3
textblob                               0.17.1
texttable                              1.6.7
textwrap3                              0.9.2
Theano                                 1.0.5
Theano-PyMC                            1.1.2
thinc                                  8.1.10
threadpoolctl                          3.1.0
tifffile                               2023.4.12
timm                                   0.9.2
tinycss2                               1.2.1
tobler                                 0.10
tokenizers                             0.13.3
toml                                   0.10.2
tomli                                  2.0.1
tomlkit                                0.11.8
toolz                                  0.12.0
torch                                  2.0.0
torchaudio                             2.0.1
torchdata                              0.6.0
torchinfo                              1.8.0
torchmetrics                           0.11.4
torchtext                              0.15.1
torchvision                            0.15.1
tornado                                6.3.1
TPOT                                   0.11.7
tqdm                                   4.64.1
traceml                                1.0.8
traitlets                              5.9.0
traittypes                             0.2.1
transformers                           4.29.2
treelite                               3.2.0
treelite-runtime                       3.2.0
trueskill                              0.4.5
tsfresh                                0.20.0
typeguard                              2.13.3
typer                                  0.7.0
typing_extensions                      4.5.0
typing-inspect                         0.8.0
tzlocal                                5.0.1
ucx-py                                 0.31.0
ujson                                  5.7.0
umap-learn                             0.5.3
unicodedata2                           15.0.0
Unidecode                              1.3.6
update-checker                         0.18.0
uri-template                           1.2.0
uritemplate                            3.0.1
urllib3                                1.26.15
urwid                                  2.1.2
urwid-readline                         0.13
uvicorn                                0.22.0
uvloop                                 0.17.0
vaex                                   4.16.0
vaex-astro                             0.9.3
vaex-core                              4.16.1
vaex-hdf5                              0.14.1
vaex-jupyter                           0.8.1
vaex-ml                                0.18.1
vaex-server                            0.8.1
vaex-viz                               0.5.4
vecstack                               0.4.0
virtualenv                             20.21.0
visions                                0.7.5
vowpalwabbit                           9.8.0
vtk                                    9.2.6
Wand                                   0.6.11
wandb                                  0.15.3
wasabi                                 1.1.1
watchfiles                             0.19.0
wavio                                  0.0.7
wcwidth                                0.2.6
webcolors                              1.13
webencodings                           0.5.1
websocket-client                       1.5.1
websockets                             11.0.3
Werkzeug                               2.3.4
wfdb                                   4.1.1
whatthepatch                           1.0.5
wheel                                  0.40.0
widgetsnbextension                     3.6.4
witwidget                              1.8.1
woodwork                               0.23.0
Wordbatch                              1.4.9
wordcloud                              1.9.2
wordsegment                            1.3.1
wrapt                                  1.14.1
wurlitzer                              3.0.3
xarray                                 2023.5.0
xarray-einstats                        0.5.1
xgboost                                1.7.5
xvfbwrapper                            0.2.9
xxhash                                 3.2.0
xyzservices                            2023.5.0
y-py                                   0.5.9
yapf                                   0.33.0
yarl                                   1.9.1
ydata-profiling                        4.1.2
yellowbrick                            1.5
ypy-websocket                          0.8.2
zict                                   3.0.0
zipp                                   3.15.0
zstandard                              0.19.0

Reposted from blog.csdn.net/weixin_53610475/article/details/131068797