An intuitive introduction to text embeddings

Text embeddings are key to LLMs and convert text into vector coordinates.

Intuition is very important to understanding a concept. An intuitive grasp of a tool or concept means you can zoom out to the level of abstraction where you get the whole picture in view. I’ve spent the last four years building and deploying machine learning tools at AI startups. In that time, the technology has exploded in popularity, particularly in my area of specialization, natural language processing (NLP).

At a startup, I don’t often have the luxury of spending months on research and testing—if I do, it’s a bet that makes or breaks the product.

A sharp intuition for how a model will perform—where it will excel and where it will fall down—is essential for thinking through how it can be integrated into a successful product. With the right UX around it, even an imperfect model feels magical. Built wrong, the rare miss produced by even the most rock-solid system looks like a disaster.

A lot of my sense for this comes from the thousands of hours I’ve spent working with these models, seeing where they fall short and where they surprise me with their successes. But if there’s one concept that most informs my intuitions, it’s text embeddings. The ability to take a chunk of text and turn it into a vector, subject to the laws of mathematics, is fundamental to natural language processing. A good grasp of text embeddings will greatly improve your capacity to reason intuitively about how NLP (and a lot of other ML models) should best fit into your product.

So let’s stop for a moment to appreciate text embeddings.

Here is what OpenAI’s text-embedding-ada-002 model does with the paragraph above. Each vertical band in this plot represents a value in one of the embedding space’s 1536 dimensions.

Mathematically, an embedding space, or latent space, is defined as a manifold in which similar items are positioned closer to one another than less similar items. In this case, sentences that are semantically similar should have similar embedded vectors and thus be closer together in the space.

We can frame a lot of useful tasks in terms of text similarity.

  • Search: How similar is a query to a document in your database?
  • Spam filtering: How close is an email to examples of spam?
  • Content moderation: How close is a social media message to known examples of abuse?
  • Conversational agent: Which examples of known intents are closest to the user’s message?

In these cases, you can pre-calculate the embeddings for your targets (i.e. the documents you want to search or examples for classification) and store them in an indexed database. This lets you capture the powerful natural language understanding of deep neural models as text embeddings when you add new items to your database, then run your search or classifier without expensive GPU compute.
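
To make that concrete, here is a minimal sketch of the pattern in Python. The library and model (sentence-transformers and all-MiniLM-L6-v2) are stand-ins I've chosen for illustration; any embedding API slots into the same shape.

```python
import numpy as np
from sentence_transformers import SentenceTransformer

# Any embedding model would work here; this small open-source one is a stand-in.
model = SentenceTransformer("all-MiniLM-L6-v2")

documents = ["How do I reset my password?",
             "Our refund policy explained",
             "Estimated shipping times by region"]

# Pre-calculate and L2-normalize embeddings for the targets once, offline.
doc_vectors = np.array(model.encode(documents))
doc_vectors /= np.linalg.norm(doc_vectors, axis=1, keepdims=True)

def search(query: str, top_k: int = 2):
    # Embed the query, then score it against every stored document
    # with cosine similarity (a dot product, since vectors are normalized).
    q = np.array(model.encode([query]))[0]
    q /= np.linalg.norm(q)
    scores = doc_vectors @ q
    best = np.argsort(-scores)[:top_k]
    return [(documents[i], float(scores[i])) for i in best]

print(search("I forgot my login credentials"))
```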

This direct comparison of text similarity is just one application for text embeddings. Often, embeddings have a place in ML algorithms or neural architectures with further task-specific components built on top. I’ve largely elided these details in the discussion below.

Giving each term in our vocabulary its own dimension and counting its occurrences in each document gives us the classic bag-of-words model. In a computational world, we can scale these dimensions up more smoothly than we could in the case of brick-and-mortar libraries, but the problem is similar in principle. Things just get unwieldy at these high dimensions. Algorithms grind to a halt as the combinatorics explode, and the sparsity (most documents will have a count of 0 for most terms) is problematic for statistics and machine learning.
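
To see the sparsity problem in miniature, here is a quick sketch using scikit-learn's CountVectorizer (my choice of tool, purely for illustration) on a toy corpus:

```python
from sklearn.feature_extraction.text import CountVectorizer

docs = ["the cat sat on the mat",
        "my feline friend naps all day",
        "the dog chased the ball in the park"]

vectorizer = CountVectorizer()
counts = vectorizer.fit_transform(docs)   # sparse term-count matrix

# One dimension per vocabulary term; most entries are zero.
print(counts.shape)                        # (3, number_of_distinct_terms)
print(vectorizer.get_feature_names_out())
print(counts.toarray())                    # note how many zeros even for a toy corpus
```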

What if we could identify some common semantic sense to words like “cat” and “feline?” We could spare our dimensionality budget and make our shelving scheme more intuitive.

And what about terms like “pet” or “mammal?” We can let these contribute to both the cat-axis and the dog-axis of a book they appear in. And if we lost something in collapsing the distinction between “cat” and “feline,” perhaps letting the latter contribute to a “scientific” latent term would recover it.

All we need, then, to project a book into our latent space is a big matrix that defines how much each of the observed terms in our vocabulary contributes to each of our latent terms.

Two classic techniques build a matrix like this: latent semantic analysis (LSA), which uses the singular value decomposition of the term-document matrix (fancy linear algebra, basically), and latent Dirichlet allocation (LDA), which uses a Bayesian statistical method built on the Dirichlet distribution.
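
As a rough sketch of the LSA flavor of this idea, a truncated SVD of that term-count matrix projects each document down to a handful of latent dimensions (again, scikit-learn is simply my illustrative choice):

```python
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import CountVectorizer

docs = ["the cat sat on the mat",
        "my feline friend naps all day",
        "the dog chased the ball in the park"]

counts = CountVectorizer().fit_transform(docs)

# LSA: a truncated SVD of the term-document matrix projects each document
# from many term dimensions down to a few latent "topics".
lsa = TruncatedSVD(n_components=2, random_state=0)
latent = lsa.fit_transform(counts)
print(latent)   # each row is a document's coordinates in the 2-D latent space
```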

LDA and LSA are still widely used for topic modeling. You can often find them powering the “read next” links in an article’s footer. But they’re limited to capturing a broad sense of topicality in a document. The models rely on document inputs being long enough to have a representative sample of words. And with the unordered bag-of-words input, there’s no way to capture the proximity of words, let alone complex syntax and semantics.

Word2vec uses windowed text sampling to create embeddings for individual words. A sliding window moves through the text in the training data, one word at a time. For each position of the window, Word2vec creates a context set. For example, with a window size of 3 in the sentence "the cat sat on the mat", (‘the’, ‘cat’, ‘sat’) are grouped together, just like a set of library books a reader had checked out in the example above. During training, this pushes the vectors for 'the', 'cat', and 'sat' all a little closer together in the latent space.
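
The windowing step itself is simple enough to sketch directly (this is just the context-set extraction, not the training loop):

```python
def context_windows(tokens, window_size=3):
    """Yield each group of window_size adjacent tokens."""
    for i in range(len(tokens) - window_size + 1):
        yield tuple(tokens[i:i + window_size])

sentence = "the cat sat on the mat".split()
for window in context_windows(sentence):
    print(window)
# ('the', 'cat', 'sat')
# ('cat', 'sat', 'on')
# ('sat', 'on', 'the')
# ('on', 'the', 'mat')
```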

A key point here is that we don’t need to spend much time on training data for this model—it uses a large corpus of raw text as-is, and can extract some surprisingly detailed insights about language.

These word embeddings show the power of vector arithmetic. The famous example is the equation king - man + woman ≈ queen. The vector for 'king', minus the vector for 'man' and plus the vector for 'woman', is very close to the vector for 'queen'. A relatively simple model, given a large enough training corpus, can give us a surprisingly rich latent space.
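
You can try the arithmetic yourself with pretrained vectors. Here is a sketch using gensim's downloader to fetch the classic Google News word2vec vectors (a sizeable download), then asking for the nearest neighbors of king - man + woman:

```python
import gensim.downloader as api

# Pretrained word2vec vectors trained on Google News.
vectors = api.load("word2vec-google-news-300")

# king - man + woman ≈ ?
print(vectors.most_similar(positive=["king", "woman"], negative=["man"], topn=3))
```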

The key idea of a recurrent neural network (RNN) is that each token (usually a word or word piece) in our sequence feeds forward into the representation of the next one. We start with the embedding for our first token, t0. For the next token, t1, we take some function (defined by the weights our neural network learns) of the embeddings for t0 and t1, like f(t0, t1). Each new token combines with the previous token in the sequence until we reach the final token, whose embedding is used to represent the whole sequence. The simplest version of this architecture is the fully-recurrent neural network (FRNN).
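
In code, that recurrence is just a loop that folds each token's embedding into a running hidden state. A minimal numpy sketch of a fully-recurrent step (the weights here are random placeholders rather than a trained model):

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 8                                   # embedding / hidden size
W_h = rng.normal(size=(dim, dim)) * 0.1   # hidden-to-hidden weights
W_x = rng.normal(size=(dim, dim)) * 0.1   # input-to-hidden weights

def rnn_encode(token_embeddings):
    h = np.zeros(dim)
    for x in token_embeddings:
        # f(previous state, current token): the whole sequence so far
        # is squeezed into this single hidden vector.
        h = np.tanh(W_h @ h + W_x @ x)
    return h   # the final state stands in for the whole sequence

sequence = rng.normal(size=(5, dim))      # five fake token embeddings
print(rnn_encode(sequence))
```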

This architecture has issues with vanishing gradients that limit the neural network training process. Remember, training a neural network works by making small updates to model parameters based on a loss function that expresses how close the model’s prediction for a training item is to the true value. When that feedback has to flow backward through a long chain of recurrent steps, each multiplication by a small weight shrinks it further, and the gradient for early tokens quickly approaches zero. Their impact on the loss function becomes negligible, as do any updates to the parameters that handle them.
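
A back-of-the-envelope illustration: if each recurrent step scales the gradient by a factor a little below one (0.9 is a made-up but plausible-looking number), the signal reaching the early tokens all but disappears.

```python
# Gradient signal surviving after n recurrent steps, if each step
# multiplies it by roughly 0.9 (an illustrative factor, not a measured one).
for n in (10, 50, 100):
    print(n, 0.9 ** n)
# 10  0.3487...
# 50  0.00515...
# 100 0.0000266...
```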

This is a big problem for long-distance relationships common in text. Consider the sentence "The dog that I adopted from the pound five years ago won the local pet competition." It's important to understand that it's the dog that won the competition despite the fact that none of these words are adjacent in the sequence.

The long short-term memory (LSTM) architecture addresses this vanishing gradient problem. The LSTM uses a long-term memory cell that stably passes information forward parallel to the RNN, while a set of gates passes information in and out of the memory cell.
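
Here is a sketch of an LSTM used as a sequence encoder. PyTorch is my choice for illustration; the final hidden state stands in for the whole sequence:

```python
import torch
import torch.nn as nn

# Toy batch: 2 sequences, 7 tokens each, 32-dimensional token embeddings.
token_embeddings = torch.randn(2, 7, 32)

lstm = nn.LSTM(input_size=32, hidden_size=64, batch_first=True)
outputs, (h_n, c_n) = lstm(token_embeddings)

# h_n holds the final hidden state per sequence: a fixed-size embedding
# of the whole sequence. c_n is the long-term memory cell the gates manage.
sequence_embeddings = h_n[-1]          # shape: (2, 64)
print(sequence_embeddings.shape)
```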

Remember, though, that in the machine learning world a larger training set is almost always better. The fact that the LSTM has to calculate a value for each token sequentially before it can start on the next is a big bottleneck—it’s impossible to parallelize these operations.

The transformer architecture, which is at the heart of the current generation of LLMs, is an evolution of the LSTM concept. Not only does it better capture the context and dependencies between words in a sequence, but it can run in parallel on the GPU with highly-optimized tensor operations.

The transformer uses an attention mechanism to weigh the influence of each token in the sequence on each other token. Along with an embedding value of each token, the attention mechanism learns two more vectors for each token: a query vector and a key vector. How close a token’s query vector is to another token’s key vector determines how much of the second token’s value gets added to the first.
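
In miniature, a single attention head looks something like this; the value, query, and key vectors below are random stand-ins for the learned ones:

```python
import numpy as np

rng = np.random.default_rng(0)
n_tokens, d = 5, 16

# Per-token value, query, and key vectors.
V = rng.normal(size=(n_tokens, d))
Q = rng.normal(size=(n_tokens, d))
K = rng.normal(size=(n_tokens, d))

# How close each token's query is to every other token's key...
scores = Q @ K.T / np.sqrt(d)
weights = np.exp(scores) / np.exp(scores).sum(axis=1, keepdims=True)  # softmax

# ...determines how much of each token's value flows into its new representation.
new_values = weights @ V
print(new_values.shape)   # (5, 16): one updated vector per token
```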

Because we’ve loosened up the sequence bottleneck, we can afford to stack up multiple layers of attention—at each layer, the attention contributes a little meaning to each token from the others in the sequence before moving on to the next layer with the updated values.

If you’ve followed enough so far that we can cobble together a spatial intuition for this attention mechanism, I’ll consider this article a success. Let’s give it a try.

A token’s value vector captures its semantic meaning in a high-dimensional embedding space, much like in our library analogy from earlier. The attention mechanism uses another embedding space for the key and query vectors—a sort of semantic plumbing in the floor between each level of the library. The key vector positions the output end of a pipe that draws some semantic value from the token and pumps it out into the embedding space. The query vector places the input end of a pipe that sucks up the semantic value that nearby key vectors from other tokens pump into the embedding space, and funnels all of this into the token’s new representation on the floor above.

To capture an embedding for a full sequence, we just pick one of these tokens to grab a value vector from and use in the downstream tasks. (Exactly which token this is depends on the specific model. Masked models like BERT use a special [CLS] or [MASK] token, while the autoregressive GPT models use the last token in the sequence.)
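
With the Hugging Face transformers library, one common way to get at these models, pulling that [CLS] vector out of BERT looks roughly like this:

```python
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("Text embeddings are surprisingly useful.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Token 0 is the special [CLS] token; its final hidden state is commonly
# used as an embedding for the whole sequence.
cls_embedding = outputs.last_hidden_state[:, 0, :]
print(cls_embedding.shape)   # (1, 768)
```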

So the transformer architecture can encode sequences really well, but if we want it to understand language well, how do we train it? Remember, when we start training, all these vectors are randomly initialized. Our tokens’ value vectors are distributed at random in their semantic embedding space as are our key and query vectors in theirs. We ask the model to predict a token given the rest of the encoded sequence. The great thing about this task is that we can gather as much text as we can find and turn it into training data. All we have to do is hide one of the tokens in a chunk of text from the model and encode what’s left. We already know what the missing token should be, so we can build a loss function based on how close the prediction is to this known value.
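
You can watch this training objective at work with an already-trained masked model: hide a token and ask the model what belongs there. A quick sketch with the transformers fill-mask pipeline:

```python
from transformers import pipeline

# A pretrained masked language model guessing the hidden token.
fill_mask = pipeline("fill-mask", model="bert-base-uncased")

for guess in fill_mask("The dog that I adopted from the pound [MASK] the competition."):
    print(guess["token_str"], round(guess["score"], 3))
```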

The other beautiful thing is that the difficulty of predicting the right word scales up smoothly. It goes from a general sense of topicality and word order—something even a simple predictive text model on your phone can do pretty well—up through complex syntax and semantics.

The incredible thing here is that as we scale up the number of parameters in these models—things like the size of the embeddings and number of transformer layers—and scale up the size of the training data, the models just keep getting better and smarter.
