torch GPU memory analysis - clearing GPU memory for generative models

1. Introduction to the problem

This article focuses on how to quickly and easily free the GPU memory occupied by the current process when running a generative model. Beyond GPU memory management itself, it also shows how custom components can be used to control the generation process flexibly.

In the previous article, torch GPU memory analysis - how to release GPU memory without closing the process, I used an experiment to analyze how torch uses GPU memory and showed how to release that memory in code without terminating the process. In recent experiments, however, I found that the method introduced there does not work well for generative models.

In the previous article, the method used was:

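import torch

# Run a tiny forward pass on a short slice of the input, then drop every reference and empty the cache
real_inputs = inputs['input_ids'][..., :2].to(model.device)
with torch.no_grad():
    logits = model(real_inputs, tail)  # `tail` comes from the previous article's setup and is not defined here
del real_inputs
del logits
torch.cuda.empty_cache()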
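The del + torch.cuda.empty_cache() pair is needed because PyTorch manages GPU memory with a caching allocator: when the last reference to a tensor is dropped, its block goes back to PyTorch's cache (still counted by torch.cuda.memory_reserved() and still visible in nvidia-smi), and empty_cache() only hands unoccupied cached blocks back to the driver. The sketch below is a deliberately simplified pure-Python model of that behaviour; ToyCachingAllocator and its methods are hypothetical and are not PyTorch's real allocator, which also splits and rounds blocks.

```python
class ToyCachingAllocator:
    """Simplified model of a caching allocator (hypothetical, not PyTorch's real implementation)."""

    def __init__(self):
        self.allocated = 0  # bytes held by live tensors
        self.reserved = 0   # bytes obtained from the device (live + cached)
        self._cache = []    # sizes of freed blocks kept for reuse

    def malloc(self, size):
        # Reuse a cached block of exactly this size if one exists.
        if size in self._cache:
            self._cache.remove(size)
        else:
            self.reserved += size  # otherwise request new memory from the device
        self.allocated += size
        return size

    def free(self, size):
        # Like `del tensor`: the block goes back to the cache, not to the device.
        self.allocated -= size
        self._cache.append(size)

    def empty_cache(self):
        # Like torch.cuda.empty_cache(): return unused cached blocks to the device.
        self.reserved -= sum(self._cache)
        self._cache.clear()


alloc = ToyCachingAllocator()
block = alloc.malloc(1024)
alloc.free(block)
print(alloc.allocated, alloc.reserved)  # 0 1024
alloc.empty_cache()
print(alloc.allocated, alloc.reserved)  # 0 0
```

After the free, the memory is no longer used by any tensor but is still reserved; only the explicit cache flush brings the reserved amount down.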

However, for a generative model, if you simply replace the model's forward call with generate, as below, you run into problems.

with torch.no_grad():
    output_ids = model.generate(real_inputs)  # generate returns token IDs, not logits
del real_inputs
del output_ids
torch.cuda.empty_cache()

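The reason is that generation produces new tokens one at a time: model.generate calls forward once per new token until a stopping condition (such as max_length or an end-of-sequence token) is met, so it is almost never a single forward pass. The trick of running one small forward pass before clearing memory therefore no longer works.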
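To make the "more than one forward call" point concrete, here is a toy greedy decoding loop. It assumes a hypothetical CountingModel whose forward() only counts its calls and returns a deterministic "next token"; this is a simplified sketch, not the transformers implementation, but it has the same shape: one forward call per new token until max_new_tokens or an EOS token ends the loop.

```python
class CountingModel:
    """Hypothetical stand-in for a language model; counts forward() calls."""

    def __init__(self, eos_id=None):
        self.calls = 0
        self.eos_id = eos_id

    def forward(self, ids):
        self.calls += 1
        # Toy "next token": just the current sequence length (deterministic).
        return len(ids)


def toy_generate(model, prompt, max_new_tokens):
    """Greedy decoding loop: one forward() call per generated token."""
    ids = list(prompt)
    for _ in range(max_new_tokens):
        next_id = model.forward(ids)
        ids.append(next_id)
        if next_id == model.eos_id:  # stop early on end-of-sequence
            break
    return ids


model = CountingModel()
out = toy_generate(model, prompt=[101, 2023], max_new_tokens=5)
print(model.calls)  # 5
print(out)          # [101, 2023, 2, 3, 4, 5, 6]
```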

2. Solution

Since the goal is to simulate a single forward pass of the model, we need to make generate call forward only once. Calling model.forward directly might also solve the problem, but here I take a different approach: a stopping criterion.

Because we want generation to run for only one step, we can simply use one of the built-in criteria:

from transformers.generation.stopping_criteria import MaxNewTokensCriteria, StoppingCriteriaList

# Stop generating once the sequence length reaches start_length + max_new_tokens = 1
empty_cache_helper = StoppingCriteriaList()
empty_cache_helper.append(MaxNewTokensCriteria(start_length=0, max_new_tokens=1))

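This criterion stops generation once the sequence length reaches start_length + max_new_tokens. With start_length=0 that limit is 1, which any non-empty input already satisfies, so generation stops right after the first step: forward runs once and only one new token is produced.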
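The stopping logic can be sketched in the same toy style. ToyMaxNewTokensCriteria below is a hypothetical, simplified mirror of the length check described above (stop when the sequence length reaches start_length + max_new_tokens), and toy_generate_with_criteria evaluates the criteria only after each new token, mirroring how criteria are checked after every decoding step. Neither is the transformers code; dummy_forward is a hypothetical model that always predicts token 0.

```python
class ToyMaxNewTokensCriteria:
    """Simplified length check: done once len(ids) >= start_length + max_new_tokens."""

    def __init__(self, start_length, max_new_tokens):
        self.max_length = start_length + max_new_tokens

    def __call__(self, ids):
        return len(ids) >= self.max_length


def toy_generate_with_criteria(forward, prompt, stopping_criteria, hard_limit=20):
    ids = list(prompt)
    calls = 0
    while len(ids) < hard_limit:  # stands in for the default max_length
        ids.append(forward(ids))
        calls += 1
        # Criteria are checked only after a token has been produced.
        if any(criterion(ids) for criterion in stopping_criteria):
            break
    return ids, calls


def dummy_forward(ids):
    return 0  # hypothetical model that always predicts token 0


ids, calls = toy_generate_with_criteria(dummy_forward, [5, 6, 7], [ToyMaxNewTokensCriteria(0, 1)])
print(calls, ids)  # 1 [5, 6, 7, 0]
```

Because the limit is 1 and the check happens after the first step, the prompt length does not matter: exactly one forward call runs. Without the criterion the same loop would keep going until the hard limit.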

Then, when clearing GPU memory, just pass it to generate:

with torch.no_grad():
    # Only one forward pass runs, because the helper stops generation after the first step
    output_ids = model.generate(real_inputs, stopping_criteria=empty_cache_helper)
del real_inputs
del output_ids
torch.cuda.empty_cache()

If you are not familiar with stopping criteria, see my previous two articles:

Taking beam search as an example: a detailed explanation of the generate method in transformers (Part 1)
Taking beam search as an example: a detailed explanation of the generate method in transformers (Part 2)

In future posts I may use concrete examples to show how to write custom logits processors and stopping criteria. Stay tuned if you are interested.

Source: blog.csdn.net/weixin_44826203/article/details/132067916