Xiaobai learns PyTorch series – torch.optim API: Base class (1)

torch.optim is a package implementing various optimization algorithms. Most of the commonly used methods are already supported, and the interface is generic enough that more complex methods can be easily integrated in the future.

How to use the optimizer

To use torch.optim, you first construct an optimizer object that will hold the current state and will update the parameters based on the computed gradients.

Constructing it

To construct an optimizer, you have to give it an iterable containing the parameters (all should be Parameters or Tensors) to optimize. You can then specify optimizer-specific options such as the learning rate, weight decay, etc.

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr=0.0001)

Per-parameter options

The optimizer also supports specifying options for each parameter. To do this, instead of passing an iterable of variables, pass an iterable of dicts. Each of them will define a separate parameter group and should contain a params key containing the list of parameters belonging to it. Other keys should match the keyword arguments accepted by the optimizer and will be used as optimization options for this group.

Note: You can still pass options as keyword arguments. In groups that do not override them, they will be used as defaults. This is useful when you want to vary only one option while keeping all other options consistent across parameter groups.

This is useful, for example, when you want to specify a different learning rate for each part of the model:

optim.SGD([
    {'params': model.base.parameters()},
    {'params': model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)

This means that model.base's parameters will use the default learning rate of 1e-2, model.classifier's parameters will use a learning rate of 1e-3, and all parameters will use a momentum of 0.9.
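
As a quick check, you can inspect optimizer.param_groups to see which options each group received. A minimal sketch, using two stand-in nn.Linear modules in place of model.base and model.classifier:

import torch.nn as nn
import torch.optim as optim

base = nn.Linear(4, 8)          # stand-in for model.base
classifier = nn.Linear(8, 2)    # stand-in for model.classifier

optimizer = optim.SGD([
    {'params': base.parameters()},                     # inherits the default lr=1e-2
    {'params': classifier.parameters(), 'lr': 1e-3},   # overrides lr for this group
], lr=1e-2, momentum=0.9)

for i, group in enumerate(optimizer.param_groups):
    print(i, group['lr'], group['momentum'])
# 0 0.01 0.9
# 1 0.001 0.9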

Perform optimization steps

All optimizers implement a step() method that updates the parameters. It can be used in two ways.
optimizer.step()
This is a simplified version supported by most optimizers. The function can be called once the gradients have been computed, e.g. after backward().

For example:

for input, target in dataset:
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()

optimizer.step(closure)
Some optimization algorithms (such as Conjugate Gradient and LBFGS) need to re-evaluate the function multiple times, so you have to pass in a closure that allows them to recompute your model. The closure should clear the gradients, compute the loss, and return it.

for input, target in dataset:
    def closure():
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        return loss
    optimizer.step(closure)

Base class

This part refers to: https://zhuanlan.zhihu.com/p/87209990
PyTorch's optimizers all inherit from class Optimizer, which is the base class of every optimizer.
The following is the overall structure of Optimizer:

class Optimizer(object):
    def __init__(self, params, defaults):
        self.defaults = defaults
        self._hook_for_profile()
        if isinstance(params, torch.Tensor):
            raise TypeError("params argument given to the optimizer should be "
                            "an iterable of Tensors or dicts, but got " +
                            torch.typename(params))

        self.state = defaultdict(dict)
        self.param_groups = []

        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]
        for param_group in param_groups:
            self.add_param_group(param_group)

    def state_dict(self):
        ...

    def load_state_dict(self, state_dict):
        ...

    def cast(param, value):
        ...

    def zero_grad(self, set_to_none: bool = False):
        ...

    def step(self, closure):
        ...

    def add_param_group(self, param_group):
        ...

__init__

params and defaults are the two important arguments: defaults defines the global default values for the optimization options, while params defines the model parameters together with any group-local option overrides.
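
A minimal sketch (assuming nothing more than a single nn.Linear layer) of how the two arguments end up stored on the optimizer:

import torch.nn as nn
import torch.optim as optim

layer = nn.Linear(3, 1)
optimizer = optim.SGD(layer.parameters(), lr=0.1, momentum=0.9)

print(optimizer.defaults)               # global defaults, e.g. {'lr': 0.1, 'momentum': 0.9, ...}
print(optimizer.param_groups[0]['lr'])  # 0.1 -- the single group inherited the default lr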

add_param_group

When a key of a defaultdict is looked up but does not exist, no KeyError is raised; a default value is returned instead. With defaultdict(dict), that default is an empty dictionary (see the tiny illustration below). The last lines of __init__ call self.add_param_group(param_group) for each group; when a plain iterable of parameters is passed, param_group is a single dictionary whose 'params' key holds param_groups = list(params). The implementation of add_param_group follows the illustration.
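
A tiny illustration of the defaultdict(dict) behaviour described above (plain Python, independent of PyTorch):

from collections import defaultdict

state = defaultdict(dict)
print(state['missing_key'])   # {} -- no KeyError; an empty dict is created and returned
state['p']['step'] = 1        # nested assignment works without initializing state['p'] first
print(dict(state))            # {'missing_key': {}, 'p': {'step': 1}}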

def add_param_group(self, param_group):
        params = param_group['params']
        if isinstance(params, torch.Tensor):
            param_group['params'] = [params]
        elif isinstance(params, set):
            raise TypeError('optimizer parameters need to be organized in ordered collections, but '
                            'the ordering of tensors in sets will change between runs. Please use a list instead.')
        else:
            param_group['params'] = list(params)

        for param in param_group['params']:
            if not isinstance(param, torch.Tensor):
                raise TypeError("optimizer " + torch.typename(param))
            if not param.is_leaf:
                raise ValueError("can't optimize a non-leaf Tensor")

        for name, default in self.defaults.items():
            if default is required and name not in param_group:
                raise ValueError("parameter group didn't specify a value of required optimization parameter " +
                                 name)
            else:
                param_group.setdefault(name, default)  # fill in the optimizer-wide default for options the group did not specify

        params = param_group['params']
        if len(params) != len(set(params)):
            warnings.warn("optimizer contains ", stacklevel=3)

        param_set = set()
        for group in self.param_groups:
            param_set.update(set(group['params']))

        if not param_set.isdisjoint(set(param_group['params'])):  # the new group must not share parameters with existing groups
            raise ValueError("some parameters appear in more than one parameter group")

        self.param_groups.append(param_group)

zero_grad

zero_grad() sets the gradient of every parameter to zero via p.grad.zero_(). The call to p.grad.detach_() detaches the gradient tensor from the graph that created it, making it a leaf. self.param_groups is a list whose elements are dictionaries.

def zero_grad(self):
    r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""
    for group in self.param_groups:
        for p in group['params']:
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()
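
A short usage sketch with a toy nn.Linear layer, showing the effect of zero_grad() after a backward pass; note that the class skeleton above also exposes a set_to_none flag, in which case the gradients become None instead of zero tensors:

import torch
import torch.nn as nn
import torch.optim as optim

layer = nn.Linear(2, 1)
optimizer = optim.SGD(layer.parameters(), lr=0.1)

loss = layer(torch.randn(4, 2)).sum()
loss.backward()
print(layer.weight.grad)      # non-zero gradient after backward()

optimizer.zero_grad()
print(layer.weight.grad)      # tensor([[0., 0.]]) -- zeroed in place (None if set_to_none=True)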

step

step() is the method that actually updates the parameters. In the base class it contains only a single line, raise NotImplementedError; every concrete optimizer overrides it. The parameters of the network and the options of the optimizer are stored in the elements of self.param_groups, which hold both as dictionaries, so each model parameter p can be reached with two levels of iteration. After obtaining the gradient d_p = p.grad.data, the update is adjusted according to whether the optimizer is configured to use momentum or Nesterov momentum, and the final line p.data.add_(-group['lr'], d_p) (in older versions of the SGD code) applies the update to the parameters. self.state is used to save per-parameter buffers between updates, such as the momentum buffer and the number of steps the optimizer has taken.

Let's take the SGD optimizer's step() as an example:

def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            d_p_list = []
            momentum_buffer_list = []
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            maximize = group['maximize']
            lr = group['lr']

            for p in group['params']:
                if p.grad is not None:
                    params_with_grad.append(p)
                    d_p_list.append(p.grad)

                    state = self.state[p]
                    if 'momentum_buffer' not in state:
                        momentum_buffer_list.append(None)
                    else:
                        momentum_buffer_list.append(state['momentum_buffer'])

            F.sgd(params_with_grad,
                  d_p_list,
                  momentum_buffer_list,
                  weight_decay=weight_decay,
                  momentum=momentum,
                  lr=lr,
                  dampening=dampening,
                  nesterov=nesterov,
                  maximize=maximize,)

            # update momentum_buffers in state
            for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
                state = self.state[p]  # save the updated momentum buffer back into the optimizer state
                state['momentum_buffer'] = momentum_buffer
        return loss

F.sgd

def sgd(params: List[Tensor],
        d_p_list: List[Tensor],
        momentum_buffer_list: List[Optional[Tensor]],
        *,
        weight_decay: float,
        momentum: float,
        lr: float,
        dampening: float,
        nesterov: bool,
        maximize: bool):
    for i, param in enumerate(params):
        d_p = d_p_list[i]
        if weight_decay != 0:
            d_p = d_p.add(param, alpha=weight_decay)
        if momentum != 0:
            buf = momentum_buffer_list[i]
            if buf is None:
                buf = torch.clone(d_p).detach()
                momentum_buffer_list[i] = buf
            else:
                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
            if nesterov:
                d_p = d_p.add(buf, alpha=momentum)
            else:
                d_p = buf
        alpha = lr if maximize else -lr
        param.add_(d_p, alpha=alpha)

This implements the momentum (also known as heavy ball) improvement on top of plain SGD.
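
To make the update rule concrete, here is a small hand-computed sketch of one heavy-ball step, mirroring the nesterov=False branch of F.sgd above (the numbers are made up for illustration):

import torch

param = torch.tensor([1.0])
grad = torch.tensor([0.5])
buf = torch.tensor([0.2])               # previous momentum buffer
lr, momentum, dampening = 0.1, 0.9, 0.0

buf = momentum * buf + (1 - dampening) * grad   # 0.9 * 0.2 + 0.5 = 0.68
param = param - lr * buf                        # 1.0 - 0.1 * 0.68 = 0.932
print(param)                                    # tensor([0.9320])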

load_state_dict

Load optimizer state.

def load_state_dict(self, state_dict):

     # deepcopy, to be consistent with module API
     state_dict = deepcopy(state_dict)
     # Validate the state_dict
     groups = self.param_groups
     saved_groups = state_dict['param_groups']

     if len(groups) != len(saved_groups):
         raise ValueError("loaded state dict has a different number of "
                          "parameter groups")
     param_lens = (len(g['params']) for g in groups)
     saved_lens = (len(g['params']) for g in saved_groups)
     if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
         raise ValueError("loaded state dict contains a parameter group "
                          "that doesn't match the size of optimizer's group")

     # Update the state
     id_map = {old_id: p for old_id, p in
               zip(chain.from_iterable((g['params'] for g in saved_groups)),
                   chain.from_iterable((g['params'] for g in groups)))}

     def cast(param, value):
         r"""Make a deep copy of value, casting all tensors to device of param."""
         if isinstance(value, torch.Tensor):
             # Floating-point types are a bit special here. They are the only ones
             # that are assumed to always match the type of params.
             if param.is_floating_point():
                 value = value.to(param.dtype)
             value = value.to(param.device)
             return value
         elif isinstance(value, dict):
             return {k: cast(param, v) for k, v in value.items()}
         elif isinstance(value, container_abcs.Iterable):
             return type(value)(cast(param, v) for v in value)
         else:
             return value

     # Copy state assigned to params (and cast tensors to appropriate types).
     # State that is not assigned to params is copied as is (needed for
     # backward compatibility).
     state = defaultdict(dict)
     for k, v in state_dict['state'].items():
         if k in id_map:
             param = id_map[k]
             state[param] = cast(param, v)
         else:
             state[k] = v

state_dict

Returns the state of the optimizer as a dictionary.

def state_dict(self):
    # Save order indices instead of Tensors
    param_mappings = {}
    start_index = 0

    def pack_group(group):
        nonlocal start_index
        packed = {k: v for k, v in group.items() if k != 'params'}
        param_mappings.update({id(p): i for i, p in enumerate(group['params'], start_index)
                               if id(p) not in param_mappings})
        packed['params'] = [param_mappings[id(p)] for p in group['params']]
        start_index += len(packed['params'])
        return packed

    param_groups = [pack_group(g) for g in self.param_groups]
    # Remap state to use order indices as keys
    packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v
                    for k, v in self.state.items()}
    return {
        'state': packed_state,
        'param_groups': param_groups,
    }
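
In practice, state_dict() and load_state_dict() are used together for checkpointing. A typical sketch (the file name checkpoint.pth is just an example):

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Save model and optimizer state together.
torch.save({'model': model.state_dict(),
            'optimizer': optimizer.state_dict()}, 'checkpoint.pth')

# Later: rebuild the objects, then restore their states.
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])

# The optimizer's state_dict has two entries: 'state' (per-parameter buffers keyed
# by order indices) and 'param_groups' (the option dictionaries described above).
print(optimizer.state_dict().keys())    # dict_keys(['state', 'param_groups'])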

Origin: blog.csdn.net/weixin_42486623/article/details/129917588