Tensor Sharing in PyTorch
Torch shared tensors
TL;DR
Use the dedicated `save_model` and `load_model` functions from `safetensors.torch`, which should work in most cases for you. This is not without side effects.
What are shared tensors?
PyTorch uses shared tensors for some computations. This is very useful for reducing memory usage in general.
One classic use case is in transformers, where the embeddings are shared with `lm_head`. Because both use the same matrix, the model has fewer parameters, and gradients flow much better to the embeddings. The embeddings sit at the start of the model, where gradients do not flow easily. `lm_head` sits at the tail, where gradients are very strong. Since they are the same tensor, both benefit.
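The sketch below ties a toy embedding to an output projection by assigning the same `nn.Parameter` to both modules. This is a common way to tie weights in PyTorch. The sizes and the attribute names `embed` and `lm_head` are illustrative assumptions. Both entries appear in `state_dict()`, but they point to a single buffer.

```python
import torch
from torch import nn

vocab_size, hidden = 1000, 64

embed = nn.Embedding(vocab_size, hidden)             # weight shape: (vocab_size, hidden)
lm_head = nn.Linear(hidden, vocab_size, bias=False)  # weight shape: (vocab_size, hidden)

# Tie the weights: both modules now hold the very same Parameter object.
lm_head.weight = embed.weight

model = nn.ModuleDict({"embed": embed, "lm_head": lm_head})
state = model.state_dict()

print(list(state.keys()))  # ['embed.weight', 'lm_head.weight']
# Same underlying memory, so only one copy of the matrix exists.
print(state["embed.weight"].data_ptr() == state["lm_head.weight"].data_ptr())  # True
# parameters() skips duplicates, so the shared matrix is counted once.
print(sum(p.numel() for p in model.parameters()))  # 64000
```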
Why are shared tensors not saved in safetensors?
There are multiple reasons for this:
Not all frameworks support them; for instance, TensorFlow does not. So if someone saves shared tensors in torch, there is no way to load them in a similar fashion elsewhere, and we could not keep the same `Dict[str, Tensor]` API.
It makes lazy loading much harder. Lazy loading is the ability to load only some tensors, or parts of tensors, from a given file. This is trivial to do without shared tensors, but not with them.
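The toy file layout below makes the lazy-loading argument concrete. It is loosely modeled on safetensors: an 8-byte little-endian header length, a JSON header that maps each name to a byte range, then raw little-endian float32 data. The helpers `write_toy_file` and `get_tensor` are hypothetical. So is the overlapping entry `b`, which is a made-up way to encode sharing and not something the real format supports. Reading one name touches only its byte range. However, each read returns a fresh object, so two names that covered the same bytes no longer share memory once loaded.

```python
import io
import json
import struct


def write_toy_file(buf, values, header):
    """Write: 8-byte header length, JSON header, then float32 payload."""
    header_bytes = json.dumps(header).encode("utf-8")
    buf.write(struct.pack("<Q", len(header_bytes)))
    buf.write(header_bytes)
    buf.write(struct.pack(f"<{len(values)}f", *values))


def get_tensor(buf, name):
    """Lazily read a single entry: only its byte range of the payload is read."""
    buf.seek(0)
    (header_len,) = struct.unpack("<Q", buf.read(8))
    header = json.loads(buf.read(header_len))
    start, end = header[name]["data_offsets"]
    buf.seek(8 + header_len + start)
    raw = buf.read(end - start)
    # Decoding creates a brand-new object, independent of any earlier reads.
    return list(struct.unpack(f"<{len(raw) // 4}f", raw))


buf = io.BytesIO()
# "a" covers all 4 floats; "b" is a hypothetical alias of a's first 2 floats.
write_toy_file(
    buf,
    [1.0, 2.0, 3.0, 4.0],
    {"a": {"data_offsets": [0, 16]}, "b": {"data_offsets": [0, 8]}},
)

a = get_tensor(buf, "a")
b = get_tensor(buf, "b")
print(b)  # [1.0, 2.0]

# Once handed out, a and b are unrelated: writing to a does not show up in b.
a[0] = 100.0
print(b[0])  # 1.0
```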
Suppose `b` shares memory with `a`, and a lazy loader hands out `a` first and `b` later. It is then impossible to “reshare” buffers after the fact: once we have given you the `a` tensor, we have no way to give back the same memory when you ask for `b`. (In this particular example we could keep track of the buffers already handed out. This does not work in general, since you could do arbitrary work with `a`, like sending it to another device, before asking for `b`.)
It can lead to much larger files than necessary. If you save a shared tensor that is only a fraction of a larger tensor, PyTorch saves the entire underlying buffer instead of just the part that is needed.
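The size blow-up is easy to observe by serializing into an in-memory buffer. This sketch assumes only `torch`. The 100x100 shape is arbitrary. Exact byte counts vary with the PyTorch version because of archive overhead, so the check only compares orders of magnitude.

```python
import io

import torch


def saved_size(obj):
    """Return the number of bytes torch.save writes for `obj`."""
    buf = io.BytesIO()
    torch.save(obj, buf)
    return len(buf.getvalue())


a = torch.zeros((100, 100))  # 10,000 float32 values = 40,000 bytes of storage
b = a[:1, :]                 # a view of the first row, sharing a's storage

view_size = saved_size({"b": b})          # writes the whole 40,000-byte storage
copy_size = saved_size({"b": b.clone()})  # clone() owns only 100 floats = 400 bytes

print(view_size >= 40_000, copy_size < 10_000)  # True True
```

Calling `.clone()` before saving is the usual way to keep only the slice you actually need.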
With all those reasons mentioned, none of this is set in stone. Shared tensors do not introduce unsafety or denial-of-service potential, so this decision could be revisited if the current workarounds are not satisfactory.
How does it work?
The design is rather simple. We look for all shared tensors, then for all tensors covering the entire buffer (there can be several such tensors). That gives us multiple names that could be saved, and we simply choose the first one.
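The selection step can be sketched in plain PyTorch. This is a simplified illustration, not the library's actual code. It groups entries by the address of their underlying storage, using `untyped_storage()` (available in PyTorch 2.x). It then keeps only the names whose tensor covers the whole buffer and picks the first one in sorted order. The helper `names_to_keep` is hypothetical, and it ignores devices and partially overlapping views.

```python
from collections import defaultdict

import torch


def covers_whole_buffer(t):
    """True if `t` starts at its storage's start and spans all of its bytes."""
    storage = t.untyped_storage()
    return (
        t.data_ptr() == storage.data_ptr()
        and t.nelement() * t.element_size() == storage.nbytes()
    )


def names_to_keep(state_dict):
    """Hypothetical helper: map each kept name to the aliases dropped in its favor."""
    groups = defaultdict(list)
    for name, tensor in state_dict.items():
        groups[tensor.untyped_storage().data_ptr()].append(name)

    kept = {}
    for names in groups.values():
        complete = sorted(n for n in names if covers_whole_buffer(state_dict[n]))
        if not complete:
            raise RuntimeError(f"No tensor covers the whole buffer for {names}")
        keep = complete[0]  # several names may qualify; simply take the first
        kept[keep] = sorted(n for n in names if n != keep)
    return kept


w = torch.randn(4, 4)
state = {"embed.weight": w, "lm_head.weight": w, "first_row": w[:1], "bias": torch.zeros(4)}
print(names_to_keep(state))
# {'embed.weight': ['first_row', 'lm_head.weight'], 'bias': []}
```

`first_row` starts at the beginning of the buffer but covers only a quarter of it, so it is never chosen as the name to save.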
During `load_model`, we load much like `load_state_dict` does, except that we also inspect the model itself to find shared buffers. We then ignore the “missing keys” that were in fact covered by buffer sharing: they were properly loaded, because the buffer they share was loaded under another name. Every other error is raised as-is.
Caveat: this means some keys are dropped from the file. If you check the keys saved on disk, or load the file with `load_state_dict`, you will see some “missing” tensors. Unless shared tensors become directly supported in the format, there is no real way around this.
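The loading side can be sketched the same way, assuming only `torch`. The sketch loads the deduplicated dictionary with `strict=False`. It then looks into the model's own state dict and discards any missing key whose storage was filled under another name. `TiedModel` and `load_with_sharing` are hypothetical stand-ins, not the safetensors implementation. The sketch also shows the caveat: plain `load_state_dict` reports the dropped aliases as missing.

```python
import torch
from torch import nn


class TiedModel(nn.Module):
    """Toy model in which `b` is the very same module as `a`."""

    def __init__(self):
        super().__init__()
        self.a = nn.Linear(4, 4)
        self.b = self.a


def load_with_sharing(model, state_dict):
    """Hypothetical helper: load_state_dict plus sharing-aware missing-key filtering."""
    result = model.load_state_dict(state_dict, strict=False)
    own = model.state_dict()
    # Storages of the model that actually received data from `state_dict`.
    loaded = {own[k].untyped_storage().data_ptr() for k in state_dict if k in own}
    # A "missing" key whose storage was filled under another name was in fact loaded.
    missing = [
        k for k in result.missing_keys
        if own[k].untyped_storage().data_ptr() not in loaded
    ]
    if missing or result.unexpected_keys:
        raise RuntimeError(f"missing={missing}, unexpected={result.unexpected_keys}")
    return missing


source = TiedModel()
# Simulate a deduplicated file: only one name ("a.*") per shared buffer was saved.
saved = {k: v.clone() for k, v in source.state_dict().items() if k.startswith("a.")}

target = TiedModel()
# Plain loading reports the dropped aliases as missing keys...
print(target.load_state_dict(saved, strict=False).missing_keys)  # ['b.weight', 'b.bias']
# ...while the sharing-aware loader accepts them, and `b` sees the loaded values.
print(load_with_sharing(target, saved))               # []
print(torch.equal(target.b.weight, source.a.weight))  # True
```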