Can GRPO teach Qwen3.6-27B to forecast the PPA of RTL it is about to write?
We fine-tune Qwen3.6-27B (4-bit QLoRA) with GRPO. The reward comes from actually synthesizing each generated module with yosys + abc on sky130. We then score area and power with our GNN cost-model surrogate and take Fmax from abc's report. We ran three reward generations:
| run | prompts | reward signal | LoRA target |
|---|---|---|---|
| v2 | 10 design families × 3 freq targets | throughput + freq-target + forecast-consistency | r=16 attention only |
| v3 | same + a required `//+ BUDGET` block | + budget completeness / internal consistency / calibration | r=32 attn+MLP, layers 32-63 |
| v4 (running) | same as v3 | same, but Qwen3 `<think>` reasoning enabled | same as v2 base |
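The synthesis half of the reward can be pictured as one yosys invocation per generated module. The sketch below only builds and runs such a script and is not necessarily the harness used for these runs. Several details are assumptions:

- the pass sequence;
- the liberty file name (a `sky130_fd_sc_hd` typical corner);
- handing each frequency target to abc as a delay budget via `-D`, in picoseconds.

The GNN area/power scoring and the Fmax extraction are left out.

```python
import subprocess


def build_yosys_script(verilog_path: str, top: str, liberty: str, target_mhz: float) -> str:
    """Assemble a ';'-separated yosys command string: generic synthesis, then sky130 mapping."""
    period_ps = round(1e6 / target_mhz)  # abc -D takes its delay target in picoseconds
    passes = [
        f"read_verilog -sv {verilog_path}",
        f"synth -top {top}",
        f"dfflibmap -liberty {liberty}",           # map flip-flops onto library cells
        f"abc -liberty {liberty} -D {period_ps}",  # map combinational logic toward the target delay
        "opt_clean",
        f"stat -liberty {liberty}",                # cell counts and total cell area
    ]
    return "; ".join(passes)


def run_yosys(script: str) -> str:
    """Run yosys on the script and return its log (needs a yosys binary on PATH)."""
    proc = subprocess.run(["yosys", "-p", script], capture_output=True, text=True, check=True)
    return proc.stdout


# Hypothetical liberty path; point it at the sky130 corner actually in use.
LIB = "sky130_fd_sc_hd__tt_025C_1v80.lib"
print(build_yosys_script("alu16.v", "alu16", LIB, 400))
```

For a 400 MHz target the script passes `-D 2500` to abc. Joining the passes with `;` keeps the whole flow in a single `yosys -p` call, so paths containing spaces or semicolons would need quoting.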
v3 forces the model to emit a structured budget block: per-block area, power and FF estimates that must sum to a TOTAL line. Completeness (did it fill the block?) and internal consistency (do the lines sum to the total?) stay rock-solid at ~0.93. Calibration, however, drifts down from 0.42 to 0.36 over training. Calibration measures whether the forecast TOTAL actually matches the synthesized result.
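Checking internal consistency is mechanical once the block is parsed. This is a minimal sketch. It assumes the `//+ name ffs=N area_um2=A dyn_uW=P` line format seen in the samples in this section and an arbitrary 1% tolerance; the exact scoring used in training is not given here. Run on the alu16 budget from this section, it flags the power column: the per-block `dyn_uW` values sum to 322, while the TOTAL line claims 312.

```python
import re

# One budget line: "//+ <block> ffs=<int> area_um2=<num> dyn_uW=<num>"
LINE_RE = re.compile(
    r"^//\+\s+(\S+)\s+ffs=(\d+)\s+area_um2=(\d+(?:\.\d+)?)\s+dyn_uW=(\d+(?:\.\d+)?)$"
)
FIELDS = ("ffs", "area_um2", "dyn_uW")


def parse_budget(text):
    """Return ({block: {field: value}}, total) from a //+ BUDGET block; total is None if missing."""
    rows, total = {}, None
    for line in text.splitlines():
        m = LINE_RE.match(line.strip())
        if m is None:
            continue  # header, ARCH line, or RTL
        values = dict(zip(FIELDS, (float(v) for v in m.group(2, 3, 4))))
        if m.group(1) == "TOTAL":
            total = values
        else:
            rows[m.group(1)] = values
    return rows, total


def inconsistent_fields(rows, total, rel_tol=0.01):
    """List (field, per-block sum, TOTAL) for fields whose sum misses TOTAL by more than rel_tol."""
    bad = []
    for f in FIELDS:
        s = sum(r[f] for r in rows.values())
        if abs(s - total[f]) > rel_tol * max(abs(total[f]), 1.0):
            bad.append((f, s, total[f]))
    return bad


ALU16_BUDGET = """\
//+ BUDGET v1
//+ mux_op ffs=0 area_um2=120 dyn_uW=40
//+ adder ffs=0 area_um2=180 dyn_uW=60
//+ suber ffs=0 area_um2=180 dyn_uW=60
//+ and_gate ffs=0 area_um2=60 dyn_uW=10
//+ or_gate ffs=0 area_um2=60 dyn_uW=10
//+ xor_gate ffs=0 area_um2=60 dyn_uW=10
//+ shl1 ffs=0 area_um2=40 dyn_uW=5
//+ shr1 ffs=0 area_um2=40 dyn_uW=5
//+ not_gate ffs=0 area_um2=40 dyn_uW=10
//+ pass_a ffs=0 area_um2=10 dyn_uW=2
//+ mux_out ffs=0 area_um2=150 dyn_uW=50
//+ overflow_gen ffs=0 area_um2=80 dyn_uW=20
//+ reg_ff ffs=17 area_um2=170 dyn_uW=40
//+ TOTAL ffs=17 area_um2=1190 dyn_uW=312
//+ ARCH: combinational datapath with synchronous register output, no pipeline, stall on n/a (pure combinational eval)
"""

rows, total = parse_budget(ALU16_BUDGET)
print(inconsistent_fields(rows, total))  # [('dyn_uW', 322.0, 312.0)]
```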
Diagnosis: the model learned to fill the template, not to reason about the numbers. The throughput reward term (weight 0.30) carries a richer gradient signal than the calibration term (0.20), so the policy buys throughput at the cost of forecast honesty. Median power forecast error held around 205 µW. That is fine for small modules (counter, FIFO), but the model systematically under-forecasts large designs (matmul, FIR) by 5–30 mW, because most training prompts are small blocks.
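One way to see how a term can be "gradient-richer" is to decompose GRPO's group-normalized advantage into per-term shares. The advantage is (r - mean) / std over the samples drawn for one prompt. This toy sketch keeps only the two weights quoted above, assumes population-std normalization, and uses invented scores rather than measurements. When every sample's calibration is about equally wrong, that term barely moves the advantage, whatever its weight.

```python
import statistics

# Only the two weights quoted in the text; the real reward has further terms.
WEIGHTS = {"throughput": 0.30, "calibration": 0.20}


def advantage_contributions(group):
    """Split each sample's group-normalized advantage (r - mean) / std into per-term shares.

    group: one prompt's samples, each a dict mapping term name -> score in [0, 1].
    Because the advantage is linear in the weighted scores, the shares sum to the full advantage.
    """
    totals = [sum(WEIGHTS[k] * s[k] for k in WEIGHTS) for s in group]
    std = statistics.pstdev(totals) or 1.0  # avoid dividing by zero when all rewards tie
    shares = {}
    for k in WEIGHTS:
        mean_k = statistics.fmean(s[k] for s in group)
        shares[k] = [WEIGHTS[k] * (s[k] - mean_k) / std for s in group]
    return shares


# Invented scores: throughput varies a lot across the group, calibration is uniformly mediocre.
TOY_GROUP = [
    {"throughput": 0.9, "calibration": 0.40},
    {"throughput": 0.2, "calibration": 0.42},
    {"throughput": 0.6, "calibration": 0.38},
    {"throughput": 0.3, "calibration": 0.40},
]

shares = advantage_contributions(TOY_GROUP)
for k, v in shares.items():
    print(k, [round(x, 3) for x in v])
```

In this toy group, throughput's total share of the advantage is dozens of times larger than calibration's. The policy update is therefore driven almost entirely by throughput.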
This is what motivated v4: enable Qwen3's native `<think>` reasoning so the model derives the number token by token before committing to it, instead of one-shotting a plausible-looking budget. That run is in flight on a separate H100.
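With thinking enabled, each completion carries a reasoning span before the answer. Only the answer should reach the budget parser and synthesis; otherwise numbers jotted down while reasoning could be scored as the forecast. This is a minimal sketch with two assumptions, not necessarily what the v4 run does. It assumes Qwen3's `<think>...</think>` delimiters. It treats a completion whose think block never closes (for example, one cut off at the length limit) as having no answer.

```python
def split_reasoning(completion: str) -> tuple[str, str]:
    """Return (reasoning, answer) for a Qwen3-style '<think>...</think>answer' completion."""
    open_tag, close_tag = "<think>", "</think>"
    end = completion.find(close_tag)
    if end == -1:
        # Think block never closed (e.g. truncated at max length): no answer to score.
        return completion, ""
    reasoning = completion[:end].strip()
    if reasoning.startswith(open_tag):
        reasoning = reasoning[len(open_tag):].strip()
    answer = completion[end + len(close_tag):].strip()
    return reasoning, answer


# Made-up completion, for illustration only.
sample = "<think>\n16 outputs x 16 bits -> 256 FFs\n</think>\n\n//+ BUDGET v1\nmodule m; endmodule"
reasoning, answer = split_reasoning(sample)
print(reasoning)               # 16 outputs x 16 bits -> 256 FFs
print(answer.splitlines()[0])  # //+ BUDGET v1
```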
The acid test: every compiled v3 sample gives a pair, namely what the model forecast in its budget block (y-axis) versus what yosys+GNN actually measured (x-axis). A model with real physical intuition would have its points sitting on the green y=x diagonal.
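Distance from the y=x diagonal is easiest to read as a log ratio, log(forecast / measured), which is zero on the diagonal and negative below it. The sketch applies it to the four area pairs from the best-per-family tables in this section. That is a tiny, favorable subset rather than the full scatter. Even so, all four points fall below the diagonal, and the median forecast is about two thirds of the measured area.

```python
import math
import statistics

# (forecast, measured) area in µm² for the best v3 sample of each family, from the tables in this section.
AREA_PAIRS = {
    "alu16": (1190, 1448),
    "uart_tx": (450, 646),
    "prio_enc16": (185, 287),
    "counter32": (420, 741),
}


def log_ratio(forecast: float, measured: float) -> float:
    """log(forecast / measured): 0 on the y=x diagonal, < 0 below it (under-forecast)."""
    return math.log(forecast / measured)


ratios = {name: log_ratio(f, m) for name, (f, m) in AREA_PAIRS.items()}
below = [name for name, r in ratios.items() if r < 0]
median_factor = math.exp(statistics.median(ratios.values()))  # typical forecast / measured
print(f"{len(below)}/{len(ratios)} below y=x, median forecast/measured = {median_factor:.2f}")
# 4/4 below y=x, median forecast/measured = 0.67
```

Working in log space treats a 2× over-forecast and a 2× under-forecast as equally bad. A plain difference would weight the large designs far more heavily.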
Same matmul4x4 prompt, two outputs:

- First: base Qwen3.6-27B with no fine-tuning, which emits correct-looking RTL with zero PPA awareness.
- Second: v3, which emits a `//+ BUDGET` forecast first, then the RTL.

Note that the v3 sample literally writes `area_um2=0` in two of its budget lines (`int8_unpack` and `pack_out`). It follows the format without the numbers meaning anything.
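A little arithmetic shows how far the v3 numbers are from physical. Both samples compute a 4×4 int8 matmul fully in parallel, which needs 4·4·4 = 64 multipliers and 16·3 = 48 adders. The forecast `mac_layer area_um2=96` therefore leaves 1.5 µm² per multiplier before a single adder is placed. The comparison constant below is an assumption for scale, not a value taken from this work. It is about 3.75 µm², roughly one small `sky130_fd_sc_hd` cell such as an inverter.

```python
# Worked check of the v3 forecast line "mac_layer ffs=0 area_um2=96 dyn_uW=120" for matmul4x4.
N = 4                                   # 4x4 matrices of int8
multipliers = N * N * N                 # one A[i][k] * B[k][j] product per (i, j, k)
adders = N * N * (N - 1)                # each of the 16 outputs adds 4 products with 3 adders
mac_layer_area_um2 = 96                 # area the v3 budget assigns to all of the above
area_per_multiplier = mac_layer_area_um2 / multipliers

# Assumption for scale only: roughly the footprint of one small sky130_fd_sc_hd cell (an inverter).
ASSUMED_SMALL_CELL_UM2 = 3.75

print(multipliers, adders, area_per_multiplier)       # 64 48 1.5
print(area_per_multiplier < ASSUMED_SMALL_CELL_UM2)   # True
```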
```verilog
module matmul4x4 (
    input clk,
    input rst, // synchronous, active-high
    input valid_in,
    input signed [127:0] A_flat, // 16 packed int8: A[i][j] = A_flat[(i*4+j)*8 +: 8]
    input signed [127:0] B_flat, // 16 packed int8: B[i][j] = B_flat[(i*4+j)*8 +: 8]
    output valid_out,
    output signed [255:0] C_flat // 16 packed int16: C[i][j] = C_flat[(i*4+j)*16 +: 16]
);
    // -----------------------------------------------------------------
    // Stage 1: Register inputs when valid_in is asserted
    // -----------------------------------------------------------------
    reg signed [127:0] A_reg;
    reg signed [127:0] B_reg;
    reg valid_d;
    always @(posedge clk) begin
        if (rst) begin
            A_reg <= 0;
            B_reg <= 0;
            valid_d <= 0;
        end else if (valid_in) begin
            A_reg <= A_flat;
            B_reg <= B_flat;
            valid_d <= 1;
        end else begin
            valid_d <= 0;
        end
    end
    // -----------------------------------------------------------------
    // Stage 2: Combinational MAC Tree
    // -----------------------------------------------------------------
    // Helper to extract int8
    function signed [7:0] get_A;
        input [3:0] i, j;
        begin
            get_A = A_reg[(i*4+j)*8 +: 8];
        end
    endfunction
    function signed [7:0] get_B;
        input [3:0] i, j;
        begin
            get_B = B_reg[(i*4+j)*8 +: 8];
        end
    endfunction
    // C[i][j] = sum_k ( A[i][k] * B[k][j] )
    // We compute all 16 outputs in parallel.
    // Each output is a sum of 4 products of int8.
    // Product is int16. Sum of 4 int16s can be int18, but we truncate/wrap to int16.
    wire signed [15:0] C_wire [0:15];
    genvar i, j, k;
    generate
        for (i = 0; i < 4; i = i + 1) begin : gen_i
            for (j = 0; j < 4; j = j + 1) begin : gen_j
                integer idx;
                initial idx = i * 4 + j;
                wire signed [15:0] prod [0:3];
                wire signed [17:0] sum; // Accumulate 4 int16s -> int18 max range ~ 4 * 127*127 ~ 64k, fits in 16 signed easily, but let's be safe with intermediate width
                for (k = 0; k < 4; k = k + 1) begin : gen_k
                    assign prod[k] = get_A(i, k) * get_B(k, j);
                end
                assign sum = prod[0] + prod[1] + prod[2] + prod[3];
                assign C_wire[idx] = sum[15:0]; // Wrap on overflow (truncate)
            end
        end
    endgenerate
    wire signed [255:0] C_comb_flat;
    genvar m;
    generate
        for (m = 0; m < 16; m = m + 1) begin : gen_pack
            assign C_comb_flat[m*16 +: 16] = C_wire[m];
        end
    endgenerate
    // -----------------------------------------------------------------
    // Stage 3: Register Output
    // -----------------------------------------------------------------
    reg signed [255:0] C_reg;
    reg valid_out_d;
    always @(posedge clk) begin
        if (rst) begin
            C_reg <= 0;
            valid_out_d <= 0;
        end else if (valid_d) begin
            C_reg <= C_comb_flat;
            valid_out_d <= 1;
        end else begin
            valid_out_d <= 0;
        end
    end
    assign valid_out = valid_out_d;
    assign C_flat = C_reg;
endmodule
```

```verilog
//+ BUDGET v1
//+ int8_unpack ffs=0 area_um2=0 dyn_uW=0
//+ mac_layer ffs=0 area_um2=96 dyn_uW=120
//+ reg_stage ffs=392 area_um2=470 dyn_uW=35
//+ pack_out ffs=0 area_um2=0 dyn_uW=0
//+ TOTAL ffs=392 area_um2=566 dyn_uW=155
//+ ARCH: Pipeline depth 1 (Compute -> Reg -> Out) to meet 400 MHz; single-cycle register stage for all 16 outputs; no hazard forwarding (raw_in/raw_out); async reset active high; combinational compute followed by registered output.
`default_nettype none
module matmul4x4 #(
    parameter WIDTH_A = 128,
    parameter WIDTH_B = 128,
    parameter WIDTH_C = 256
) (
    input wire clk,
    input wire rst,
    input wire valid_in,
    input wire signed [127:0] A_flat,
    input wire signed [127:0] B_flat,
    output wire valid_out,
    output wire signed [255:0] C_flat
);
    // Internal signal for the combinational result of C[i][j]
    wire signed [15:0] c_comb [0:15];
    wire valid_comb;
    // Register file for next-cycle output
    reg signed [255:0] c_reg;
    reg valid_reg;
    // Helper function to extract int8 from A_flat
    function signed [7:0] get_a;
        input [3:0] i;
        input [3:0] j;
        begin
            get_a = A_flat[(i*4+j)*8 +: 8];
        end
    endfunction
    // Helper function to extract int8 from B_flat
    function signed [7:0] get_b;
        input [3:0] i;
        input [3:0] j;
        begin
            get_b = B_flat[(i*4+j)*8 +: 8];
        end
    endfunction
    // Matrix Multiplication Logic
    // C[i][j] = sum(A[i][k] * B[k][j]) for k in 0..3
    // We calculate all 16 elements of C in parallel.
    assign c_comb[0] = $signed(get_a(0,0) * get_b(0,0)) + $signed(get_a(0,1) * get_b(1,0)) + $signed(get_a(0,2) * get_b(2,0)) + $signed(get_a(0,3) * get_b(3,0));
    assign c_comb[1] = $signed(get_a(0,0) * get_b(0,1)) + $signed(get_a(0,1) * get_b(1,1)) + $signed(get_a(0,2) * get_b(2,1)) + $signed(get_a(0,3) * get_b(3,1));
    assign c_comb[2] = $signed(get_a(0,0) * get_b(0,2)) + $signed(get_a(0,1) * get_b(1,2)) + $signed(get_a(0,2) * get_b(2,2)) + $signed(get_a(0,3) * get_b(3,2));
    assign c_comb[3] = $signed(get_a(0,0) * get_b(0,3)) + $signed(get_a(0,1) * get_b(1,3)) + $signed(get_a(0,2) * get_b(2,3)) + $signed(get_a(0,3) * get_b(3,3));
    assign c_comb[4] = $signed(get_a(1,0) * get_b(0,0)) + $signed(get_a(1,1) * get_b(1,0)) + $signed(get_a(1,2) * get_b(2,0)) + $signed(get_a(1,3) * get_b(3,0));
    assign c_comb[5] = $signed(get_a(1,0) * get_b(0,1)) + $signed(get_a(1,1) * get_b(1,1)) + $signed(get_a(1,2) * get_b(2,1)) + $signed(get_a(1,3) * get_b(3,1));
    assign c_comb[6] = $signed(get_a(1,0) * get_b(0,2)) + $signed(get_a(1,1) * get_b(1,2)) + $signed(get_a(1,2) * get_b(2,2)) + $signed(get_a(1,3) * get_b(3,2));
    assign c_comb[7] = $signed(get_a(1,0) * get_b(0,3)) + $signed(get_a(1,1) * get_b(1,3)) + $signed(get_a(1,2) * get_b(2,3)) + $signed(get_a(1,3) * get_b(3,3));
    assign c_comb[8] = $signed(get_a(2,0) * get_b(0,0)) + $signed(get_a(2,1) * get_b(1,0)) + $signed(get_a(2,2) * get_b(2,0)) + $signed(get_a(2,3) * get_b(3,0));
    assign c_comb[9] = $signed(get_a(2,0) * get_b(0,1)) + $signed(get_a(2,1) * get_b(1,1)) + $signed(get_a(2,2) * get_b(2,1)) + $signed(get_a(2,3) * get_b(3,1));
    assign c_comb[10] = $signed(get_a(2,0) * get_b(0,2)) + $signed(get_a(2,1) * get_b(1,2)) + $signed(get_a(2,2) * get_b(2,2)) + $signed(get_a(2,3) * get_b(3,2));
    assign c_comb[11] = $signed(get_a(2,0) * get_b(0,3)) + $signed(get_a(2,1) * get_b(1,3)) + $signed(get_a(2,2) * get_b(2,3)) + $signed(get_a(2,3) * get_b(3,3));
    assign c_comb[12] = $signed(get_a(3,0) * get_b(0,0)) + $signed(get_a(3,1) * get_b(1,0)) + $signed(get_a(3,2) * get_b(2,0)) + $signed(get_a(3,3) * get_b(3,0));
    assign c_comb[13] = $signed(get_a(3,0) * get_b(0,1)) + $signed(get_a(3,1) * get_b(1,1)) + $signed(get_a(3,2) * get_b(2,1)) + $signed(get_a(3,3) * get_b(3,1));
    assign c_
```

Best v3 sample per design family: model forecast vs. what yosys+GNN actually measured.

| alu16 | model forecast | measured (yosys+GNN) |
|---|---|---|
| area (µm²) | 1190 | 1448 |
| power (µW) | 312 | 299 |
| Fmax (MHz) | 100 (target) | 388.7 |

```verilog
//+ BUDGET v1
//+ mux_op ffs=0 area_um2=120 dyn_uW=40
//+ adder ffs=0 area_um2=180 dyn_uW=60
//+ suber ffs=0 area_um2=180 dyn_uW=60
//+ and_gate ffs=0 area_um2=60 dyn_uW=10
//+ or_gate ffs=0 area_um2=60 dyn_uW=10
//+ xor_gate ffs=0 area_um2=60 dyn_uW=10
//+ shl1 ffs=0 area_um2=40 dyn_uW=5
//+ shr1 ffs=0 area_um2=40 dyn_uW=5
//+ not_gate ffs=0 area_um2=40 dyn_uW=10
//+ pass_a ffs=0 area_um2=10 dyn_uW=2
//+ mux_out ffs=0 area_um2=150 dyn_uW=50
//+ overflow_gen ffs=0 area_um2=80 dyn_uW=20
//+ reg_ff ffs=17 area_um2=170 dyn_uW=40
//+ TOTAL ffs=17 area_um2=1190 dyn_uW=312
//+ ARCH: combinational datapath with synchronous register output, no pipeline, stall on n/a (pure combinational eval)
`timescale 1ns / 1ps
module alu16 (
    input wire clk,
    input wire rst,
    input wire [3:0] op,
    input wire signed [15:0] a,
    input wire signed [15:0] b,
    output wire signed [15:0] result,
    output wire overflow
);
    // Internal wires for each operation
    wire signed [15:0] w_add;
    wire signed [15:0] w_sub;
    wire signed [15:0] w_and;
    wire signed [15:0] w_or;
    wire signed [15:0] w_xor;
    wire signed [15:0] w_shl1;
    wire signed [15:0] w_shr1;
    wire signed [15:0] w_not;
    wire signed [15:0] w_pass;
    wire add_ov;
    wire sub_ov;
    // Combinational Logic
    assign w_add = a + b;
    assign w_sub = a - b;
    assign w_and = a & b;
    assign w_or = a | b;
    assign w_xor = a ^ b;
    assign w_shl1 = a << 1;
    assign w_shr1 = a >> 1; // Arithmetic shift for signed
    assign w_not = ~a;
    assign w_pass = a;
    // Overflow Logic
    // Addition overflow: sign(a) == sign(b) != sign(result)
    assign add_ov = (a[15] == b[15]) && (a[15] != w_add[15]);
    // Subtraction overflow: sign(a) != sign(b) != sign(result) -> sign(a) != sign(result) when signs differ
    // Standard formula: (a[15] ^ b[15]) & (a[15] ^ w_sub[15])
    assign sub_ov = (a[15] ^ b[15]) && (a[15] ^ w_sub[15]);
    // Combinational selection for output mux
    wire signed [15:0] comb_result;
    wire comb_overflow;
    always @(*) begin
        case (op)
            4'd0: begin comb_result = w_add; comb_overflow = add_ov; end
            4'd1: begin comb_result = w_sub; comb_overflow = sub_ov; end
            4'd2: begin comb_result = w_and; comb_overflow = 1'b0; end
            4'd3: begin comb_result = w_or; comb_overflow = 1'b0; end
            4'd4: begin comb_result = w_xor; comb_overflow = 1'b0; end
            4'd5: begin comb_result = w_shl1; comb_overflow = (a[14] ^ w_shl1[15]); end // Overflow if sign bit changes
            4'd6: begin comb_result = w_shr1; comb_overflow = 1'b0; end // Arithmetic shift doesn't overflow in standard sense, but sign extension preserves. Safe 0.
            4'd7: begin comb_result = w_not; comb_overflow = 1'b0; end
            default: begin comb_result = w_pass; comb_overflow = 1'b0; end
        endcase
    end
    // Synchronous Register
    reg signed [15:0] result_reg;
    reg overflow_reg;
    always @(posedge clk) begin
        if (rst) begin
            result_reg <= 16'sb0;
            overflow_reg <= 1'b0;
        end else begin
            result_reg <= comb_result;
            overflow_reg <= comb_overflow;
        end
    end
    assign result = result_reg;
    assign overflow = overflow_reg;
endmodule
```

| uart_tx | model forecast | measured (yosys+GNN) |
|---|---|---|
| area (µm²) | 450 | 646 |
| power (µW) | 150 | 166 |
| Fmax (MHz) | 200 (target) | 696.7 |

```verilog
//+ BUDGET v1
//+ shift_reg ffs=10 area_um2=120 dyn_uW=45
//+ clk_div ffs=10 area_um2=130 dyn_uW=50
//+ fsm ffs=2 area_um2=40 dyn_uW=10
//+ latch ffs=8 area_um2=100 dyn_uW=30
//+ control ffs=5 area_um2=60 dyn_uW=15
//+ TOTAL ffs=35 area_um2=450 dyn_uW=150
//+ ARCH: single-stage FSM, latch-based data hold, no pipeline, synchronous reset
module uart_tx (
    input wire clk,
    input wire rst,
    input wire start,
    input wire [7:0] data,
    output wire tx,
    output wire busy
);
    // UART TX FSM states
    localparam IDLE = 2'b00,
               TX = 2'b01,
               STOP = 2'b10;
    // Register declarations
    reg [1:0] state;
    reg [1:0] next_state;
    // Data latch (captures data on start pulse)
    reg [7:0] data_reg;
    // Shift register (10 bits: start + 8 data + stop)
    reg [9:0] shift_reg;
    // Bit counter for shifting
    reg [3:0] bit_cnt;
    // Clock divider for 115200 baud @ 50 MHz
    // Period = 50,000,000 / 115,200 ≈ 434.03 -> use 434
    localparam BAUD_DIV = 10'd434;
    reg [8:0] baud_cnt;
    reg baud_tick;
    // Shift control: shift every baud tick during TX state
    reg shift_en;
    // Internal tx signal
    reg tx_int;
    // Busy flag
    reg busy_int;
    // --- Clock Divider ---
    always @(posedge clk or posedge rst) begin
        if (rst) begin
            baud_cnt <= 9'd0;
            baud_tick <= 1'b0;
        end else begin
            if (baud_cnt == BAUD_DIV - 1) begin
                baud_cnt <= 9'd0;
                baud_tick <= 1'b1;
            end else begin
                baud_cnt <= baud_cnt + 1'b1;
                baud_tick <= 1'b0;
            end
        end
    end
    // --- FSM ---
    always @(posedge clk or posedge rst) begin
        if (rst) begin
            state <= IDLE;
        end else begin
            state <= next_state;
        end
    end
    // Next-state logic
    always @(*) begin
        case (state)
            IDLE: begin
                if (start)
                    next_state = TX;
                else
                    next_state = IDLE;
            end
            TX: begin
                if (bit_cnt == 10'd9) // After shift out last bit (stop bit)
                    next_state = STOP;
                else
                    next_state = TX;
            end
            STOP: begin
                next_state = IDLE;
            end
            default: next_state = IDLE;
        endcase
    end
    // --- Data Latch ---
    always @(posedge clk or posedge rst) begin
        if (rst) begin
            data_reg <= 8'b0;
        end else if (start && state == IDLE) begin
            data_reg <= data;
        end
    end
    // --- Shift Register and Bit Counter ---
    // Shift register holds: bit[0]=start(0), bits[1:8]=data, bit[9]=stop(1)
    always @(posedge clk or posedge rst) begin
        if (rst) begin
            shift_reg <= {1'b1, 8'b0, 1'b0}; // stop=1, data=0, start=0
            bit_cnt <= 4'd0;
        end else begin
            case (state)
                IDLE: begin
                    shift_reg <= {1'b1, data_reg, 1'b0};
                    bit_cnt <= 4'd0;
                end
                TX: begin
                    if (baud_tick) begin
                        shift_reg <= {1'b0, shift_reg[9:1]}; // Shift left, MSB goes to tx
                        bit_cnt <= bit_cnt + 1'b1;
                    end
                end
                STOP: begin
                    shift_reg <= {1'b1, 8'b0, 1'b0};
                    bit_cnt <= 4'd0;
                end
                default: begin
                    shift_reg <= {1'b1, 8'b0, 1'b0};
                    bit_cnt <= 4'd0;
                end
            endcase
        end
    end
    // --- TX Output ---
    // tx_int = shift_reg[
```

| prio_enc16 | model forecast | measured (yosys+GNN) |
|---|---|---|
| area (µm²) | 185 | 287 |
| power (µW) | 65 | 61 |
| Fmax (MHz) | 100 (target) | 1376.9 |

```verilog
//+ BUDGET v1
//+ comb_logic ffs=0 area_um2=140 dyn_uW=45
//+ ff_grant ffs=4 area_um2=36 dyn_uW=16
//+ ff_valid ffs=1 area_um2=9 dyn_uW=4
//+ TOTAL ffs=5 area_um2=185 dyn_uW=65
//+ ARCH: pipeline_depth=1 hazard_policy=none memory_model=combinational_logic
module prio_enc16 (
    input wire clk,
    input wire rst,
    input wire [15:0] req,
    output reg [3:0] grant,
    output reg valid
);
    // Combinational priority encoder logic
    // Highest index (15) has priority
    logic [3:0] next_grant;
    logic next_valid;
    always @(*) begin
        if (req[15]) begin
            next_grant = 4'b1111;
        end else if (req[14]) begin
            next_grant = 4'b1110;
        end else if (req[13]) begin
            next_grant = 4'b1101;
        end else if (req[12]) begin
            next_grant = 4'b1100;
        end else if (req[11]) begin
            next_grant = 4'b1011;
        end else if (req[10]) begin
            next_grant = 4'b1010;
        end else if (req[9]) begin
            next_grant = 4'b1001;
        end else if (req[8]) begin
            next_grant = 4'b1000;
        end else if (req[7]) begin
            next_grant = 4'b0111;
        end else if (req[6]) begin
            next_grant = 4'b0110;
        end else if (req[5]) begin
            next_grant = 4'b0101;
        end else if (req[4]) begin
            next_grant = 4'b0100;
        end else if (req[3]) begin
            next_grant = 4'b0011;
        end else if (req[2]) begin
            next_grant = 4'b0010;
        end else if (req[1]) begin
            next_grant = 4'b0001;
        end else if (req[0]) begin
            next_grant = 4'b0000;
        end else begin
            next_grant = 4'b0000; // Default, valid will be 0
        end
        next_valid = |req;
    end
    // Registered Output
    always @(posedge clk or posedge rst) begin
        if (rst) begin
            grant <= 4'b0000;
            valid <= 1'b0;
        end else begin
            grant <= next_grant;
            valid <= next_valid;
        end
    end
endmodule
```

| counter32 | model forecast | measured (yosys+GNN) |
|---|---|---|
| area (µm²) | 420 | 741 |
| power (µW) | 135 | 169 |
| Fmax (MHz) | 400 (target) | 738.3 |

```verilog
//+ BUDGET v1
//+ counter_reg ffs=32 area_um2=384 dyn_uW=120
//+ wrap_logic ffs=1 area_um2=12 dyn_uW=5
//+ control_mux ffs=0 area_um2=24 dyn_uW=10
//+ TOTAL ffs=33 area_um2=420 dyn_uW=135
//+ ARCH: combinational wrap detection, synchronous reset/clear, single pipeline stage
`timescale 1ns / 1ps
module counter32 (
    input wire clk,
    input wire rst,
    input wire en,
    input wire clear,
    output reg [31:0] q,
    output reg wrapped
);
    // Next state logic
    wire [31:0] next_q;
    wire next_wrapped;
    assign next_q = q + 1'b1;
    // Wrapped signal goes high when counter reaches MAX (all 1s) and is enabled.
    // It stays high for one cycle then goes low on the wrap-around to 0.
    // Or simpler: wrapped is high when the counter is about to wrap (i.e., q == 32'hFFFFFFFF) and en is high.
    // Let's define wrapped as: high for one clock cycle when the counter transitions from 32'hFFFFFFFF to 0.
    // However, standard practice is often: wrapped is high when the MSB changes from 1 to 0 due to increment, or simply when count == MAX.
    // Let's use: wrapped is high when q == 32'hFFFFFFFF and en is high. This indicates the counter is at its max value.
    // But "wrapped" usually implies the action of wrapping. Let's output a pulse when it wraps.
    // Pulse: high when current q is MAX and next q is 0. Since next q is q+1, if q is MAX, next q is 0.
    // So wrapped pulse is high when q == 32'hFFFFFFFF.
    assign next_wrapped = en && (q == 32'hFFFFFFFF);
    always @(posedge clk) begin
        if (rst) begin
            q <= 32'd0;
            wrapped <= 1'b0;
        end else if (clear) begin
            q <= 32'd0;
            wrapped <= 1'b0;
        end else begin
            if (en) begin
                q <= next_q;
                wrapped <= next_wrapped;
            end else begin
                q <= q;
                wrapped <= 1'b0;
            end
        end
    end
endmodule
```