Retrieval augmented generation (RAG) with vectorized kernels

This notebook demonstrates how to use LlamaIndex, a popular RAG framework, to index our reference kernel database and then use the retrieved kernels to improve our LLM's code generation.

Goals

  • Index a vector database of vectorized kernels

  • Create a basic retriever

  • Generate new kernel solutions with retrieved kernels in context

  • Compare RAG results with the previous (non-RAG) implementation

# Import LlamaIndex essentials
from llama_index.core import (
    StorageContext,
    load_index_from_storage,
    VectorStoreIndex,
    SimpleDirectoryReader
)

Important: Set your OpenAI API key as an environment variable before running this notebook.
Example:

import os
os.environ['OPENAI_API_KEY'] = 'sk-...your-key-here...'
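
The sketch below is a minimal fail-fast check that reports a missing key before any indexing starts. `require_openai_key` is a hypothetical helper written for this notebook. It is not part of LlamaIndex or npueval.

```python
import os


def require_openai_key(env=None):
    """Return the OpenAI API key, or raise early with a readable message.

    `env` defaults to os.environ; passing a plain dict makes the check easy to test.
    """
    env = os.environ if env is None else env
    key = env.get("OPENAI_API_KEY", "").strip()
    if not key:
        raise RuntimeError(
            "OPENAI_API_KEY is not set; set it before building the index."
        )
    return key


# Usage: call once near the top of the notebook, before any LlamaIndex calls.
# require_openai_key()
```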

Indexing the kernel database

We’ve curated a handful of open-source kernels and stored them in rag/kernels of this repo. The code below reads them as ../rag/kernels, relative to the notebook's directory. These C++ sources will be used to create our vector index.

import os

# Where the vector store will live
PERSIST_DIR = "./rag/vector_database"

# If the database already exists let's not re-index everything
if not os.path.exists(PERSIST_DIR):
    print("Indexing...")
    documents = SimpleDirectoryReader("../rag/kernels").load_data()
    index = VectorStoreIndex.from_documents(documents)
    index.storage_context.persist(persist_dir=PERSIST_DIR)
else:
    storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
    index = load_index_from_storage(storage_context)
Indexing...
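
The cached-index check above only tests whether PERSIST_DIR exists. As a result, kernels added to the kernel directory later are not picked up until the persisted directory is deleted. One way to handle this is to re-index whenever any source file is newer than the persisted index, sketched below. The sketch assumes file modification times are a good enough signal. `index_is_stale` is a hypothetical helper, and it does not notice deleted source files.

```python
from pathlib import Path


def index_is_stale(source_dir, persist_dir):
    """Return True if persist_dir is missing/empty or older than any source file.

    Compares modification times: the oldest persisted file is treated as the
    index build time, and any newer source file marks the index as stale.
    """
    persist = Path(persist_dir)
    if not persist.is_dir():
        return True
    persisted = [p.stat().st_mtime for p in persist.rglob("*") if p.is_file()]
    if not persisted:
        return True
    index_time = min(persisted)
    source_times = [p.stat().st_mtime for p in Path(source_dir).rglob("*") if p.is_file()]
    return any(t > index_time for t in source_times)


# Hypothetical usage in place of the os.path.exists check above:
# if index_is_stale("../rag/kernels", PERSIST_DIR):
#     ...re-run the indexing branch...
```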

Retrieving Similar Kernels

Let’s test this out by retrieving some nodes for a given prompt.

retriever = index.as_retriever(similarity_top_k=2)
nodes = retriever.retrieve("A ReLU kernel")

for node in nodes:
    print(node.metadata['file_name'])
    print(node.score)
relu.cc
0.8285732809964754
conv2dk1_i8.cc
0.7933807263913192

Great! The top result is the ReLU kernel (relu.cc), which is the ideal choice for this prompt.
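
Conceptually, the retriever works in three steps:

  • Embed the query.

  • Score the query against the stored chunk embeddings.

  • Return the `similarity_top_k` best matches along with their scores.

The sketch below mimics only the ranking step, using cosine similarity over hand-made 3-dimensional vectors. The vectors are invented for illustration. The real scores depend on the embedding model and vector store that LlamaIndex is configured with.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def retrieve(query_vec, store, top_k=2):
    """Rank (name, vector) pairs by similarity to query_vec and keep the top_k."""
    scored = [(name, cosine(query_vec, vec)) for name, vec in store]
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:top_k]


# Hand-made "embeddings" standing in for a real embedding model.
store = [
    ("relu.cc", [0.9, 0.1, 0.0]),
    ("conv2dk1_i8.cc", [0.6, 0.6, 0.2]),
    ("reduce_add.cc", [0.0, 0.2, 0.9]),
]
for name, score in retrieve([1.0, 0.2, 0.0], store):
    print(name, round(score, 3))
```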

Using a Real Kernel Prompt

Now, let’s use a real NPUEval prompt and see how our retriever does.

from npueval import dataset
sample = dataset.get_by_name("add_offset_int8")

nodes = retriever.retrieve(sample['prompt'])

for node in nodes:
    print(node.metadata['file_name'])
    print(node.score)
plus1.cpp
0.87517622154528
reduce_add.cc
0.8658736014149861

plus1 is actually a good choice for an add_offset kernel because the two are functionally very similar. However, add_offset must add a runtime offset parameter to each element of the input vector, whereas plus1 adds a hardcoded constant.
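
To make the difference concrete, here is a plain-Python reference model of both behaviors. It is checked against the example in the add_offset_int8 prompt. The helper names are hypothetical. The wraparound (rather than saturating) overflow behavior is an assumption: the prompt's example never overflows, so it does not settle the question.

```python
def wrap_int8(x):
    """Wrap an integer into the int8 range [-128, 127] (two's-complement style)."""
    return (x + 128) % 256 - 128


def plus_one_u8(data):
    """plus1-style behavior: the constant 1 is baked into the kernel (uint8 data)."""
    return [(v + 1) % 256 for v in data]


def add_offset_ref(data, offset):
    """add_offset-style behavior: the value to add is a runtime parameter (int8 data)."""
    return [wrap_int8(v + offset) for v in data]


# Example taken from the add_offset_int8 prompt.
sample_in = [72, -53, 17, 92, -33, 95, 3, -91]
print(add_offset_ref(sample_in, -11))
```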

Add context to the prompt

Now we can craft a combined prompt with the additional context of the retrieved kernel:

context_string = "Reference vectorized code:\n"
for node in nodes:
    context_string += node.node.text

prompt_with_context = sample['prompt'] + "\n" + context_string

print(prompt_with_context[:2000]) # truncated output
/*
This AIE kernel adds a scalar int8 offset to every element of the input int8_t vector (length 256), and writes the result to the output buffer.
>>> add_offset_int8([72, -53, 17, 92, -33, 95, 3, -91], -11)
[61, -64, 6, 81, -44, 84, -8, -102]
This kernel should be optimized for the following input/output buffer shapes and parameters:
in_buffer size: 256
out_buffer size: 256
offset: -11
*/
#include <aie_api/aie.hpp>
#include "aie_kernel_utils.h"

void add_offset_int8(int8_t *in_buffer, int8_t *out_buffer, int8_t offset) {
    // Implementation goes here
}

Reference vectorized code:
// Copyright 2023 Advanced Micro Devices, Inc.
// SPDX-License-Identifier: MIT

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <aie_api/aie.hpp>


void plusone_aie(uint8_t *in_buffer, uint8_t* out_buffer, uint32_t nbytes) {
    ::aie::vector<uint8_t, 32> buffer;
    ::aie::vector<uint8_t, 32> inverted_buffer;
    uint16_t loop_count = (nbytes) >> 5;
    for(int j=0; j<loop_count; j++) {
        buffer = ::aie::load_v<32>(in_buffer);
        inverted_buffer = ::aie::add((uint8_t)1, buffer);
        in_buffer += 32;
        ::aie::store_v((uint8_t*)out_buffer, inverted_buffer);
        out_buffer += 32;
    }
}

extern "C" {

void plusone(uint8_t *in_buffer, uint8_t* out_buffer, uint32_t nbytes) {
    plusone_aie(in_buffer, out_buffer, nbytes);
}

}//===- reduce_add.cc --------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// Copyright (C) 2023-2025, Advanced Micro Devices, Inc.
//
//===----------------------------------------------------------------------===//

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <type_traits>

#include "aie_kernel_utils.h"
#include <aie_api/aie.hpp>

static void _reduce_add_vector(int32_t *restrict in, int32_t *res
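
In the printed context, the two retrieved kernels run together (`}//===- reduce_add.cc`) because the loop concatenates node.node.text with no separator. The variation sketched below labels each kernel with its file name and separates the sources with blank lines. `build_context` is a hypothetical helper that takes (file_name, source_text) pairs. Those pairs could be built from the same `node.metadata['file_name']` and `node.node.text` fields used earlier.

```python
def build_context(retrieved, header="Reference vectorized code:"):
    """Join retrieved kernels into one context string.

    `retrieved` is a list of (file_name, source_text) pairs. Each kernel is
    labeled with its file name, and kernels are separated by a blank line so
    that two sources never run together.
    """
    parts = [header]
    for file_name, text in retrieved:
        parts.append(f"// --- {file_name} ---")
        parts.append(text.rstrip())
        parts.append("")  # blank line between kernels
    return "\n".join(parts)


ctx = build_context([
    ("plus1.cpp", "void plusone(uint8_t *in, uint8_t *out, uint32_t n) {}"),
    ("reduce_add.cc", "//===- reduce_add.cc ---===//"),
])
print(ctx)

# Hypothetical usage in the notebook:
# ctx = build_context([(n.metadata["file_name"], n.node.text) for n in nodes])
# prompt_with_context = sample["prompt"] + "\n" + ctx
```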

Generate new kernel code

We’ll generate a kernel with and without RAG and then compare how they did.

Without RAG

As we’ve seen before, GPT-4o-mini is very good at generating simple C++ code that passes the functional tests. However, this code is typically a plain scalar loop that makes no use of the vector unit.

from npueval.aiecoder import AIECoder
coder = AIECoder(model='gpt-4o-mini', temperature=0.4, attempts=5)
response = coder(sample['prompt'])

print(f"Code after {response['attempt']} re-compilations:")
print(response['response'])
Code after 0 re-compilations:
```cpp
#include <aie_api/aie.hpp>
#include "aie_kernel_utils.h"

void add_offset_int8(int8_t *in_buffer, int8_t *out_buffer, int8_t offset) {
    constexpr int size = 256;
    // Load input data and add offset
    for (int i = 0; i < size; i++) {
        out_buffer[i] = in_buffer[i] + offset;
    }
}
```

With RAG

Now we’ll reset the coder’s history and generate again, this time using the prompt that includes the retrieved plus1 and reduce_add kernels.

coder.reset_history()
response_with_rag = coder(prompt_with_context)
print(f"Code after {response_with_rag['attempt']} re-compilations:")
print(response_with_rag['response'])
Code after 3 re-compilations:
```cpp
#include <aie_api/aie.hpp>

void add_offset_int8(int8_t *in_buffer, int8_t *out_buffer, int8_t offset) {
    ::aie::vector<int8_t, 64> buffer; // Use a vector of size 64 for loading
    ::aie::vector<int8_t, 64> result;
    const int loop_count = 256 / 64; // 256 elements, 64 elements per vector
    ::aie::vector<int8_t, 64> offset_vector = ::aie::broadcast(offset); // Broadcast the offset

    for(int j = 0; j < loop_count; j++) {
        buffer = ::aie::load_v<64>(in_buffer); // Load 64 elements from input buffer
        result = buffer + offset_vector; // Add offset to each element
        ::aie::store_v(out_buffer, result); // Store the result back to output buffer
        in_buffer += 64; // Move to the next segment of input buffer
        out_buffer += 64; // Move to the next segment of output buffer
    }
}
```

It’s immediately clear that this output tells a much more compelling optimization story. It uses the ::aie namespace, AIE API functions such as ::aie::broadcast, and AIE vector types such as ::aie::vector<int8_t, 64> for its buffers. It also needed 3 re-compilations to get there, versus 0 without RAG. So how does this kernel stack up against the vanilla GPT-4o-mini solution?
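
The main structural change is the loop trip count. The scalar kernel’s loop body handles one element per iteration. The RAG kernel handles 64 lanes per iteration, so 256 elements take 4 iterations instead of 256. The plain-Python emulation below checks that chunked processing with a broadcast offset gives the same result as the scalar loop. It only models the data flow, under an assumed int8 wraparound. Iteration counts are not cycle counts; those come from the hardware trace in the next section.

```python
import random

LANES = 64  # vector width used by the RAG-generated kernel


def wrap_int8(x):
    """Wrap an integer into the int8 range [-128, 127]."""
    return (x + 128) % 256 - 128


def scalar_add_offset(data, offset):
    """One element per loop iteration, like the vanilla GPT-4o-mini kernel."""
    return [wrap_int8(v + offset) for v in data], len(data)


def vector_add_offset(data, offset, lanes=LANES):
    """`lanes` elements per loop iteration, like the RAG kernel's loop."""
    assert len(data) % lanes == 0, "kernel assumes size is a multiple of lanes"
    broadcast = [offset] * lanes  # analogous to ::aie::broadcast(offset)
    out, iterations = [], 0
    for start in range(0, len(data), lanes):
        chunk = data[start:start + lanes]  # analogous to ::aie::load_v<64>
        out.extend(wrap_int8(a + b) for a, b in zip(chunk, broadcast))
        iterations += 1
    return out, iterations


data = [random.randint(-128, 127) for _ in range(256)]
scalar_out, scalar_iters = scalar_add_offset(data, -11)
vector_out, vector_iters = vector_add_offset(data, -11)
print(scalar_out == vector_out, scalar_iters, vector_iters)
```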

Comparing results

from pathlib import Path
import json

# Save vanilla gpt-4o-mini result to json
solutions_path = Path("results/rag/gpt-4o-mini")
solutions_path.mkdir(parents=True, exist_ok=True)

result = {"code": coder.extract_codeblock(response['response'])}
solution_file = solutions_path / "add_offset_int8.json"
with solution_file.open('w') as file:
    json.dump(result, file, indent=4)

# Save RAG-enhanced gpt-4o-mini result to json
solutions_path_rag = Path("results/rag/gpt-4o-mini_rag")
solutions_path_rag.mkdir(parents=True, exist_ok=True)

result_rag = {"code": coder.extract_codeblock(response_with_rag['response'])}
solution_file = solutions_path_rag / "add_offset_int8.json"
with solution_file.open('w') as file:
    json.dump(result_rag, file, indent=4)
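
The two save blocks above differ only in the output directory and the response they save. The helper sketched below removes that duplication. It keeps the same `<root>/<model>/<kernel_name>.json` layout with a single "code" field. `save_solution` is a hypothetical name, not an npueval function.

```python
import json
from pathlib import Path


def save_solution(root, model, kernel_name, code):
    """Write {"code": code} to <root>/<model>/<kernel_name>.json and return the path."""
    out_dir = Path(root) / model
    out_dir.mkdir(parents=True, exist_ok=True)
    path = out_dir / f"{kernel_name}.json"
    with path.open("w") as f:
        json.dump({"code": code}, f, indent=4)
    return path


# Hypothetical usage mirroring the cell above:
# save_solution("results/rag", "gpt-4o-mini", "add_offset_int8",
#               coder.extract_codeblock(response["response"]))
# save_solution("results/rag", "gpt-4o-mini_rag", "add_offset_int8",
#               coder.extract_codeblock(response_with_rag["response"]))
```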

We’ll run the functional tests on just these two solutions to collect their performance metrics.

from npueval import run_functional_tests

for model in ['gpt-4o-mini', 'gpt-4o-mini_rag']:
    print(f"Evaluating {model}")
    run_functional_tests([sample],
                         solutions=f"results/rag/{model}",
                         results_path=f"results/rag/evaluations/{model}",
                         overwrite=True)
Evaluating gpt-4o-mini

Kernel: add_offset_int8_wrapper
results/rag/evaluations/gpt-4o-mini/add_offset_int8_wrapper.mlir generated successfully
add_offset_int8_wrapper.xclbin, add_offset_int8_wrapper.bin built
Trace written to results/rag/evaluations/gpt-4o-mini/add_offset_int8_wrapper_trace.json
Result: Pass
Passed: 1/1
Evaluating gpt-4o-mini_rag

Kernel: add_offset_int8_wrapper
results/rag/evaluations/gpt-4o-mini_rag/add_offset_int8_wrapper.mlir generated successfully
add_offset_int8_wrapper.xclbin, add_offset_int8_wrapper.bin built
Trace written to results/rag/evaluations/gpt-4o-mini_rag/add_offset_int8_wrapper_trace.json
Result: Pass
Passed: 1/1

Both solutions pass the functional tests, but the RAG solution needs far fewer cycles for the same workload. It takes 35 total cycles versus 2502, an improvement of over 70x in cycle count.

for model in ['gpt-4o-mini', 'gpt-4o-mini_rag']:
    print(model)
    print(len(model)*"-")
    with open(f"results/rag/evaluations/{model}/add_offset_int8_wrapper.json") as f:
        data = json.load(f)

    print(f"Result:       {data['result']}")
    print(f"Total cycles: {data['total_cycles']}")
    print(f"VPU cycles:   {data['vector_cycles']}")
    print(f"Vector score: {data['vector_score']*100}")
    print()
gpt-4o-mini
-----------
Result:       Pass
Total cycles: 2502
VPU cycles:   0
Vector score: 0.0

gpt-4o-mini_rag
---------------
Result:       Pass
Total cycles: 35
VPU cycles:   2
Vector score: 5.714285714285714

Note that even though the vectorized solution greatly reduces the cycle count, its vector score is still relatively low at about 5.7% (2 VPU cycles out of 35 total). This is because not everything can run on the VPU. Some operations inevitably require the scalar unit, so a perfect 100% is practically unattainable. The vector score is still a useful proxy for vectorization, however. A score of 0 is a clear signal that the VPU is not being used at all, which is not what we want.
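
The headline numbers can be recomputed from the printed metrics. The sketch below hard-codes the values printed above. It assumes the vector score is the share of total cycles spent on the VPU, which is consistent with the reported 5.714285714285714 for 2 VPU cycles out of 35.

```python
# Numbers copied from the printed metrics above.
metrics = {
    "gpt-4o-mini": {"total_cycles": 2502, "vector_cycles": 0},
    "gpt-4o-mini_rag": {"total_cycles": 35, "vector_cycles": 2},
}

speedup = metrics["gpt-4o-mini"]["total_cycles"] / metrics["gpt-4o-mini_rag"]["total_cycles"]
print(f"Cycle-count speedup: {speedup:.1f}x")

for model, m in metrics.items():
    # Assumed definition: percentage of total cycles spent on the VPU.
    share = m["vector_cycles"] / m["total_cycles"] * 100
    print(f"{model}: {share:.2f}% of cycles on the VPU")
```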


Copyright© 2025 AMD, Inc SPDX-License-Identifier: MIT