Feature extraction using LSH

Locality Sensitive Hashing (LSH) is often used for Approximate Nearest Neighbor (ANN) operations (vector search). The properties of LSH can also be exploited in neural network models that take vectors as input (e.g., audio, video, and text embeddings used as content signals).

Often, the input manifolds in domain-specific models are complex (non-i.i.d.). This complexity makes it very difficult to separate these manifolds using the computationally intensive operations of multilayer perceptrons. A classic scheme for learning complex mappings is to memorize outcomes rather than learn functions. But how do we memorize vectors? The most straightforward approach is an embedding lookup. Embeddings, however, require discrete objects, and vectors are continuous, not discrete. So how do we apply an embedding lookup to vector input? Hash the vector, in such a way that nearby points remain "nearby" after hashing. This is exactly what LSH does, so an embedding layer on top of an LSH operation can be used as a shallow feature extractor.

Locality Sensitive Hashing is an approximate search technique for exactly this kind of problem. Its main idea is to map similar data points into the same hash bucket, so that a search can be carried out within a specific bucket instead of a linear scan through the entire dataset. While this approach does not guarantee finding the exact nearest neighbors, it provides an efficient approximate search method for high-dimensional data.

The core concepts of LSH are as follows:

  1. Locality-sensitive function: a function that maps similar data points into the same hash bucket. It is not strict, however: data points that land in the same bucket are not necessarily truly similar. The design of the locality-sensitive function depends on the type of data being processed and the similarity measure.
  2. Hash bucket: data points are mapped to different hash buckets by the locality-sensitive function. Similar data points are likely to be mapped to the same bucket, providing a starting point for the search.
  3. Hash table: the hash buckets together constitute a hash table; by searching within it, similar data points can be located quickly.
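For cosine similarity, one of the simplest locality-sensitive functions is the random-hyperplane (SimHash-style) signature: record which side of each random hyperplane a vector falls on. The sketch below is purely illustrative; the plane count and dimensions are arbitrary choices, not values from the article.

```python
import torch

def simhash_signature(x: torch.Tensor, planes: torch.Tensor) -> tuple:
    # One bit per hyperplane: which side of the plane the vector falls on.
    # Vectors pointing in similar directions tend to agree on most bits.
    return tuple((x @ planes > 0).long().tolist())

torch.manual_seed(0)
planes = torch.randn(4, 8)  # 8 random hyperplanes in a 4-D space

a = torch.tensor([1.0, 0.2, 0.0, 0.0])
b = torch.tensor([0.9, 0.3, 0.1, 0.0])  # nearly the same direction as a
sig_a = simhash_signature(a, planes)
sig_b = simhash_signature(b, planes)
# a and b are likely (though not guaranteed) to share most of their 8 bits
print(sig_a, sig_b)
```

Each signature bit flips between two vectors with probability proportional to the angle between them, which is what makes the function locality-sensitive for direction.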

The performance of LSH depends on the design of the locality-sensitive function and the construction of the hash buckets: mapping data points into distinct buckets while preserving similarity, and organizing and retrieving them in a hash table. LSH is usually used to solve the Approximate Nearest Neighbor (ANN) search problem, where the goal is to find data points with high similarity to a given query point.
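The index-then-query workflow described above can be sketched in a few lines, using a random-hyperplane hash as the bucket key. All sizes here are illustrative assumptions, not values from the article:

```python
import torch
from collections import defaultdict

def bucket_key(x: torch.Tensor, planes: torch.Tensor) -> tuple:
    # Bucket key: the side of each random hyperplane the vector falls on.
    return tuple((x @ planes > 0).long().tolist())

torch.manual_seed(0)
planes = torch.randn(16, 6)   # 6 hyperplanes -> at most 2**6 = 64 buckets
data = torch.randn(1000, 16)  # toy dataset

# Build the hash table: bucket key -> list of row indices in that bucket.
table = defaultdict(list)
for i, v in enumerate(data):
    table[bucket_key(v, planes)].append(i)

# A query is hashed once, then only its bucket's candidates are scanned,
# instead of a linear pass over all 1000 points.
q = data[0]
candidates = table[bucket_key(q, planes)]
print(len(candidates))
```

In practice, multiple independent hash tables are often combined to reduce the chance that a true neighbor is missed because it landed in a different bucket.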

The choice of LSH algorithm, and the way LSH buckets are converted to embeddings, matters a great deal. Presented here is a direction-aware scheme (one that ignores the magnitude of the vector), built on the following simple LSH algorithm:

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 
 class CosineVectorEmbedding(nn.Module):
     """
     LSH-based vector indexer for highly non-linear ops
     """
 
     def __init__(self, inp_dim: int, emb_dim: int, n_proj: int = 16, num_bins: int = 20):
         super().__init__()
         # Random projection directions, normalized so x_hat @ P yields cosine similarities
         self.register_buffer(
             'projection_mat',
             F.normalize(torch.randn((inp_dim, n_proj)), p=2.0, dim=0),
             persistent=True,
         )
         resolution = 2.0 / num_bins
         # Bin boundaries covering [-1, 1]; bucketize yields num_bins + 1 possible indices
         self.register_buffer(
             'grid',
             torch.linspace(-1, 1, num_bins + 1)[:-1] + 0.5 * resolution,
             persistent=True,
         )
         # Offset so each projection gets its own disjoint range of embedding ids
         self.register_buffer(
             'pos_offset',
             ((num_bins + 1) * torch.arange(n_proj, dtype=torch.long)).reshape(-1, 1, 1),
             persistent=True,
         )
         self.emb = nn.EmbeddingBag((num_bins + 1) * n_proj, emb_dim)
         self.emb_dim = emb_dim
         self.n_proj = n_proj
 
     def forward(self, x):
         bs, seq_len, _ = x.size()
         # Cosine of the (normalized) input with each projection direction
         z = F.normalize(x, p=2.0, dim=-1) @ self.projection_mat
         # Quantize to bin ids, then shift each projection into its own id range
         z = torch.bucketize(z, self.grid).transpose(0, -1)
         z = (z + self.pos_offset).transpose(0, -1).contiguous()
         # EmbeddingBag sums the n_proj bucket embeddings per token
         return self.emb(z.view(-1, self.n_proj)).reshape(bs, seq_len, self.emb_dim)
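To make the indexing arithmetic concrete, here is a short shape walk-through of the quantization stage, using arbitrary example dimensions (the module above performs the same steps):

```python
import torch
import torch.nn.functional as F

bs, seq_len, inp_dim, n_proj, num_bins = 2, 5, 32, 16, 20

proj = F.normalize(torch.randn(inp_dim, n_proj), p=2.0, dim=0)
resolution = 2.0 / num_bins
grid = torch.linspace(-1, 1, num_bins + 1)[:-1] + 0.5 * resolution

x = torch.randn(bs, seq_len, inp_dim)
z = F.normalize(x, p=2.0, dim=-1) @ proj  # cosine with each projection direction
ids = torch.bucketize(z, grid)            # one integer bin id per projection

print(z.shape)  # torch.Size([2, 5, 16])
# Cosines lie in [-1, 1], so ids fall in 0..num_bins: num_bins + 1 bins per projection
print(int(ids.min()) >= 0, int(ids.max()) <= num_bins)
```

After the per-projection offset is added, the flat embedding table holds `(num_bins + 1) * n_proj` rows, one per (projection, bin) pair.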

To illustrate its effectiveness, we applied it while training a RecSys LLM fed 32-dimensional content embeddings as input. We used independent, cascaded LSH embeddings from low to high resolution (inp_dim=32, emb_dim=512, n_proj=32, num_bins=(1, 2, 4, 8, 12, 16, 20)) and summed their outputs. This was compared against a simple projection (nn.Linear(32, 512)).
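One way to realize that cascade is a wrapper that sums one `CosineVectorEmbedding` per resolution. The `CascadedLSHEmbedding` name below is introduced here for illustration, and the class body is repeated so the snippet runs standalone:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CosineVectorEmbedding(nn.Module):
    # Same module as above, repeated so this sketch is self-contained.
    def __init__(self, inp_dim, emb_dim, n_proj=16, num_bins=20):
        super().__init__()
        self.register_buffer('projection_mat',
                             F.normalize(torch.randn(inp_dim, n_proj), p=2.0, dim=0))
        resolution = 2.0 / num_bins
        self.register_buffer('grid',
                             torch.linspace(-1, 1, num_bins + 1)[:-1] + 0.5 * resolution)
        self.register_buffer('pos_offset',
                             ((num_bins + 1) * torch.arange(n_proj)).reshape(-1, 1, 1))
        self.emb = nn.EmbeddingBag((num_bins + 1) * n_proj, emb_dim)
        self.emb_dim, self.n_proj = emb_dim, n_proj

    def forward(self, x):
        bs, seq_len, _ = x.size()
        z = F.normalize(x, p=2.0, dim=-1) @ self.projection_mat
        z = torch.bucketize(z, self.grid).transpose(0, -1)
        z = (z + self.pos_offset).transpose(0, -1).contiguous()
        return self.emb(z.view(-1, self.n_proj)).reshape(bs, seq_len, self.emb_dim)

class CascadedLSHEmbedding(nn.Module):
    # One LSH embedding per resolution; coarse bins capture broad direction,
    # fine bins capture detail, and the outputs are summed.
    def __init__(self, inp_dim=32, emb_dim=512, n_proj=32,
                 num_bins=(1, 2, 4, 8, 12, 16, 20)):
        super().__init__()
        self.levels = nn.ModuleList(
            CosineVectorEmbedding(inp_dim, emb_dim, n_proj, nb) for nb in num_bins
        )

    def forward(self, x):
        return sum(level(x) for level in self.levels)

model = CascadedLSHEmbedding()
out = model(torch.randn(2, 7, 32))
print(out.shape)  # torch.Size([2, 7, 512])
```

Most of the parameters live in the embedding tables, so capacity grows with the number of resolutions while the per-token compute stays a handful of projections and lookups.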

It can be seen that our CosineVectorEmbedding is a better feature extractor than a simple linear transformation (though, of course, at the cost of many more parameters).

https://avoid.overfit.cn/post/2bab364a679f4b6f8d9a1c0bd3096b9b

By Dinesh Ramasamy


Origin blog.csdn.net/m0_46510245/article/details/132256962