Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions ci/regression.sh.in
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,9 @@ synthesis()
PREFIX=build_base make -C hw/syn/yosys clean
PREFIX=build_base CONFIGS="-DDPI_DISABLE -DEXT_F_DISABLE -DNUM_WARPS=2 -DNUM_THREADS=2" make -C hw/syn/yosys synthesis

PREFIX=build_tcu_bhf make -C hw/syn/yosys clean
PREFIX=build_tcu_bhf CONFIGS="-DDPI_DISABLE -DEXT_F_DISABLE -DEXT_TCU_ENABLE -DTCU_BHF -DNUM_WARPS=2 -DNUM_THREADS=8" make -C hw/syn/yosys synthesis

echo "synthesis tests done!"
}

Expand Down Expand Up @@ -492,11 +495,17 @@ tensor()
CONFIGS="-DNUM_THREADS=8 -DEXT_TCU_ENABLE -DTCU_BHF" ./ci/blackbox.sh --driver=rtlsim --app=sgemm_tcu
CONFIGS="-DNUM_THREADS=8 -DEXT_TCU_ENABLE -DTCU_DSP" ./ci/blackbox.sh --driver=rtlsim --app=sgemm_tcu

make -C tests/regression/sgemm_tcu clean && CONFIGS="-DNUM_THREADS=8 -DITYPE=fp16 -DOTYPE=fp16" make -C tests/regression/sgemm_tcu
CONFIGS="-DNUM_THREADS=8 -DEXT_TCU_ENABLE -DTCU_BHF" ./ci/blackbox.sh --driver=rtlsim --app=sgemm_tcu

make -C tests/regression/sgemm_tcu clean && CONFIGS="-DNUM_THREADS=16 -DITYPE=bf16 -DOTYPE=fp32" make -C tests/regression/sgemm_tcu
CONFIGS="-DNUM_THREADS=16 -DEXT_TCU_ENABLE -DTCU_DPI" ./ci/blackbox.sh --driver=rtlsim --app=sgemm_tcu
CONFIGS="-DNUM_THREADS=16 -DEXT_TCU_ENABLE -DTCU_BHF" ./ci/blackbox.sh --driver=rtlsim --app=sgemm_tcu
CONFIGS="-DNUM_THREADS=16 -DEXT_TCU_ENABLE -DTCU_DSP" ./ci/blackbox.sh --driver=rtlsim --app=sgemm_tcu

make -C tests/regression/sgemm_tcu clean && CONFIGS="-DNUM_THREADS=16 -DITYPE=bf16 -DOTYPE=bf16" make -C tests/regression/sgemm_tcu
CONFIGS="-DNUM_THREADS=16 -DEXT_TCU_ENABLE -DTCU_BHF" ./ci/blackbox.sh --driver=rtlsim --app=sgemm_tcu

echo "tensor tests done!"
}

Expand Down
169 changes: 156 additions & 13 deletions hw/rtl/tcu/VX_tcu_fedp_bhf.sv
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
// limitations under the License.

`include "VX_define.vh"
`include "HardFloat_consts.vi"

module VX_tcu_fedp_bhf #(
module VX_tcu_fedp_bhf import VX_tcu_pkg::*; #(
parameter LATENCY = 1,
parameter N = 1
) (
Expand Down Expand Up @@ -41,8 +42,9 @@ module VX_tcu_fedp_bhf #(
localparam FMT_DELAY = FMUL_LATENCY + FRND_LATENCY;
localparam C_DELAY = (FMUL_LATENCY + FRND_LATENCY) + 1 + FRED_LATENCY;

`UNUSED_VAR ({fmt_d, c_val});

`ifdef XLEN_64
`UNUSED_VAR (c_val[63:32]);
`endif
wire [2:0] frm = '0; // RNE rounding mode

wire [TCK-1:0][15:0] a_row16;
Expand Down Expand Up @@ -120,9 +122,9 @@ module VX_tcu_fedp_bhf #(

logic [32:0] mult_result_mux;
always_comb begin
case(fmt_s_delayed)
3'd1: mult_result_mux = mult_result_fp16;
3'd2: mult_result_mux = mult_result_bf16;
case (fmt_s_delayed)
TCU_FP16_ID: mult_result_mux = mult_result_fp16;
TCU_BF16_ID: mult_result_mux = mult_result_bf16;
default: mult_result_mux = 'x;
endcase
end
Expand Down Expand Up @@ -173,17 +175,73 @@ module VX_tcu_fedp_bhf #(

// Accumulation input C recoding and delay handling

wire [32:0] c_rec, c_delayed;
wire [31:0] result;
wire [16:0] c_fp16_rec, c_bf16_rec;
wire [32:0] c_fp32_rec, c_fp16_to_fp32_rec, c_bf16_to_fp32_rec;
logic [32:0] c_rec;
wire [32:0] c_delayed;

fNToRecFN #(
.expWidth (8),
.sigWidth (24)
) conv_c (
) conv_c_fp32 (
.in (c_val[31:0]),
.out (c_rec)
.out (c_fp32_rec)
);

fNToRecFN #(
.expWidth (5),
.sigWidth (11)
) conv_c_fp16 (
.in (c_val[15:0]),
.out (c_fp16_rec)
);

// Match the BHF fadd/fmul HardFloat tininess policy.
wire control = `flControl_tininessAfterRounding; // IEEE 754-2008

recFNToRecFN #(
.inExpWidth (5),
.inSigWidth (11),
.outExpWidth (8),
.outSigWidth (24)
) widen_c_fp16 (
.control (control),
.in (c_fp16_rec),
.roundingMode (frm),
.out (c_fp16_to_fp32_rec),
`UNUSED_PIN (exceptionFlags)
);

fNToRecFN #(
.expWidth (8),
.sigWidth (8)
) conv_c_bf16 (
.in (c_val[15:0]),
.out (c_bf16_rec)
);

recFNToRecFN #(
.inExpWidth (8),
.inSigWidth (8),
.outExpWidth (8),
.outSigWidth (24)
) widen_c_bf16 (
.control (control),
.in (c_bf16_rec),
.roundingMode (frm),
.out (c_bf16_to_fp32_rec),
`UNUSED_PIN (exceptionFlags)
);

// Select the recoded accumulator operand according to fmt_d.
// All three candidates are 33-bit recoded values: the fp32 input is recoded
// directly, while the fp16 and bf16 inputs are recoded in their native width
// and then widened to the 8-bit-exponent / 24-bit-significand form.
// NOTE(review): fmt_d is presumably the destination/accumulator format — confirm.
always_comb begin
case (fmt_d)
TCU_FP32_ID: c_rec = c_fp32_rec;
TCU_FP16_ID: c_rec = c_fp16_to_fp32_rec;
TCU_BF16_ID: c_rec = c_bf16_to_fp32_rec;
// Any other format id drives X so that it is visible in simulation.
default: c_rec = 'x;
endcase
end

VX_pipe_register #(
.DATAW (33),
.DEPTH (C_DELAY)
Expand All @@ -195,12 +253,28 @@ module VX_tcu_fedp_bhf #(
.data_out(c_delayed)
);

wire [2:0] fmt_d_delayed;

// Delay the 3-bit fmt_d selector by TOTAL_LATENCY stages, gated by enable.
// The delayed copy (fmt_d_delayed) is what the final result mux selects on.
// NOTE(review): TOTAL_LATENCY is assumed to equal the latency of the datapath
// up to result_rec, so selector and data arrive together — confirm.
VX_pipe_register #(
.DATAW (3),
.DEPTH (TOTAL_LATENCY)
) pipe_fmt_d (
.clk (clk),
.reset (reset),
.enable (enable),
.data_in (fmt_d),
.data_out(fmt_d_delayed)
);

// Final accumulation

wire [32:0] result_rec;

VX_tcu_bhf_fadd #(
.IN_EXPW (8),
.IN_SIGW (23+1),
.IN_REC (1), // input in recoded format
.OUT_REC (0), // output in IEEE format
.OUT_REC (1), // output in recoded format
.ADD_LATENCY (FADD_LATENCY),
.RND_LATENCY (FRND_LATENCY)
) final_add (
Expand All @@ -210,10 +284,79 @@ module VX_tcu_fedp_bhf #(
.frm (frm),
.a (red_in[LEVELS][0]),
.b (c_delayed),
.y (result),
.y (result_rec),
`UNUSED_PIN(fflags)
);

assign d_val = `XLEN'(result);
wire [31:0] result_fp32;
wire [16:0] result_fp16_rec, result_bf16_rec;
wire [15:0] result_fp16, result_bf16;

recFNToFN #(
.expWidth (8),
.sigWidth (24)
) to_fp32 (
.in (result_rec),
.out (result_fp32)
);

recFNToRecFN #(
.inExpWidth (8),
.inSigWidth (24),
.outExpWidth (5),
.outSigWidth (11)
) narrow_result_fp16 (
.control (control),
.in (result_rec),
.roundingMode (frm),
.out (result_fp16_rec),
`UNUSED_PIN (exceptionFlags)
);

recFNToFN #(
.expWidth (5),
.sigWidth (11)
) to_fp16 (
.in (result_fp16_rec),
.out (result_fp16)
);

recFNToRecFN #(
.inExpWidth (8),
.inSigWidth (24),
.outExpWidth (8),
.outSigWidth (8)
) narrow_result_bf16 (
.control (control),
.in (result_rec),
.roundingMode (frm),
.out (result_bf16_rec),
`UNUSED_PIN (exceptionFlags)
);

recFNToFN #(
.expWidth (8),
.sigWidth (8)
) to_bf16 (
.in (result_bf16_rec),
.out (result_bf16)
);

logic [31:0] result;

// Select the 32-bit output according to the delayed format selector.
// The fp32 result is used as-is; the 16-bit fp16 and bf16 results are placed
// in the low half-word with the upper 16 bits zeroed.
always_comb begin
case (fmt_d_delayed)
TCU_FP32_ID: result = result_fp32;
TCU_FP16_ID: result = {16'b0, result_fp16};
TCU_BF16_ID: result = {16'b0, result_bf16};
// Any other format id drives X so that it is visible in simulation.
default: result = 'x;
endcase
end

// Drive the output port from the 32-bit result.
// When XLEN_64 is defined the upper 32 bits are filled with ones
// (presumably NaN-boxing of the narrower float value — confirm);
// otherwise the result is passed through unchanged.
`ifdef XLEN_64
assign d_val = {32'hffffffff, result};
`else
assign d_val = result;
`endif

endmodule
2 changes: 1 addition & 1 deletion hw/rtl/tcu/VX_tcu_top.sv
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ module VX_tcu_top import VX_gpu_pkg::*, VX_tcu_pkg::*; #(
);
VX_execute_if #(
.data_t (tcu_exe_t)
) VX_execute_if();
) execute_if();

VX_result_if #(
.data_t (tcu_res_t)
Expand Down
2 changes: 2 additions & 0 deletions hw/rtl/tcu/VX_tcu_uops.sv
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ module VX_tcu_uops import
assign ibuf_out.PC = ibuf_in.PC;
assign ibuf_out.ex_type = ibuf_in.ex_type;
assign ibuf_out.op_type = ibuf_in.op_type;
assign ibuf_out.op_args.tcu.__padding = '0;
assign ibuf_out.op_args.tcu.fmt_s = ibuf_in.op_args.tcu.fmt_s;
assign ibuf_out.op_args.tcu.fmt_d = ibuf_in.op_args.tcu.fmt_d;
assign ibuf_out.op_args.tcu.step_m = 4'(m_index);
Expand All @@ -99,6 +100,7 @@ module VX_tcu_uops import
`UNUSED_VAR (ibuf_in.rs1)
`UNUSED_VAR (ibuf_in.rs2)
`UNUSED_VAR (ibuf_in.rs3)
`UNUSED_VAR (ibuf_in.op_args.tcu.__padding)

reg busy;

Expand Down
13 changes: 13 additions & 0 deletions hw/syn/yosys/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,19 @@ endif
RTL_INCLUDE = -I$(RTL_DIR) -I$(RTL_DIR)/libs -I$(RTL_DIR)/interfaces -I$(RTL_DIR)/core -I$(RTL_DIR)/mem -I$(RTL_DIR)/cache
RTL_INCLUDE += $(FPU_INCLUDE)

# Add TCU extension sources
ifneq (,$(findstring -DEXT_TCU_ENABLE, $(CONFIGS)))
RTL_INCLUDE += -I$(RTL_DIR)/tcu
ifneq (,$(findstring -DTCU_DRL, $(CONFIGS)))
RTL_INCLUDE += -I$(RTL_DIR)/tcu/drl
endif
ifneq (,$(findstring -DTCU_BHF, $(CONFIGS)))
RTL_INCLUDE += -I$(RTL_DIR)/tcu/bhf
endif
RTL_INCLUDE += -J$(THIRD_PARTY_DIR)/hardfloat/source/RISCV
RTL_INCLUDE += -I$(THIRD_PARTY_DIR)/hardfloat/source
endif

# Debugging
ifdef DEBUG
CFLAGS += $(DBG_TRACE_FLAGS)
Expand Down
16 changes: 11 additions & 5 deletions tests/regression/sgemm_tcu/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,10 @@ class Comparator<vt::int32> {
}
};

// Returns the distance, in units in the last place (ULPs), between two 16-bit
// floating-point encodings that use a sign-magnitude layout (sign in bit 15).
//
// Subtracting the raw bit patterns is only meaningful when both values have
// the same sign: the encodings of negative numbers grow as the value shrinks,
// so +0 (0x0000) and -0 (0x8000) would otherwise appear 32768 ULPs apart.
// Each encoding is therefore mapped onto a monotonic signed integer line
// (negative values mirrored around zero) before taking the difference.
// For same-sign inputs the result is identical to a raw bit-pattern difference.
//
// NaN encodings are not treated specially; they are compared as bit patterns.
//
// @param a first encoding
// @param b second encoding
// @return non-negative ULP distance (at most 65534, so it always fits int32_t)
static int32_t ulp_distance16(uint16_t a, uint16_t b) {
  const auto to_ordered = [](uint16_t bits) -> int32_t {
    const int32_t magnitude = static_cast<int32_t>(bits & 0x7FFF);
    return (bits & 0x8000) ? -magnitude : magnitude;
  };
  return std::abs(to_ordered(a) - to_ordered(b));
}

template <>
class Comparator<vt::fp16> {
public:
Expand All @@ -207,9 +211,10 @@ class Comparator<vt::fp16> {
return rv_ftoh_s(bit_cast<uint32_t>(fvalue), 0, nullptr);
}
static bool compare(uint16_t a, uint16_t b, int index, int errors) {
if (a != b) {
auto d = ulp_distance16(a, b);
if (d > FLOAT_ULP) {
if (errors < MAX_ERRORS) {
printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a);
printf("*** error: [%d] expected=0x%x, actual=0x%x, ulp=%d\n", index, b, a, d);
}
return false;
}
Expand All @@ -225,9 +230,10 @@ class Comparator<vt::bf16> {
return rv_ftob_s(bit_cast<uint32_t>(fvalue), 0, nullptr);
}
static bool compare(uint16_t a, uint16_t b, int index, int errors) {
if (a != b) {
auto d = ulp_distance16(a, b);
if (d > FLOAT_ULP) {
if (errors < MAX_ERRORS) {
printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a);
printf("*** error: [%d] expected=0x%x, actual=0x%x, ulp=%d\n", index, b, a, d);
}
return false;
}
Expand Down Expand Up @@ -686,4 +692,4 @@ int main(int argc, char *argv[]) {
std::cout << "PASSED!" << std::endl;

return 0;
}
}
Loading