Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions front/parser/src/verification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,15 @@ fn validate_node(
scopes.pop();
}

ASTNode::ExternFunction(ext) => {
if !ext.abi.eq_ignore_ascii_case("c") {
return Err(format!(
"unsupported extern ABI '{}' for function '{}': only extern(c) is currently supported",
ext.abi, ext.name
));
}
}

_ => {}
}

Expand Down
15 changes: 9 additions & 6 deletions llvm/src/codegen/abi_c.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ pub enum RetLowering<'ctx> {

#[derive(Clone)]
pub struct ExternCInfo<'ctx> {
pub llvm_name: String, // actual LLVM symbol name
pub wave_ret: WaveType, // Wave-level return type (needed when sret => llvm void)
pub ret: RetLowering<'ctx>,
pub params: Vec<ParamLowering<'ctx>>, // per-wave param
Expand Down Expand Up @@ -168,9 +169,9 @@ fn classify_param<'ctx>(
return ParamLowering::Direct(it.as_basic_type_enum());
}

// mixed small aggregate: safest is byval (conservative but correct)
let align = td.get_abi_alignment(&t) as u32;
return ParamLowering::ByVal { ty: t.as_any_type_enum(), align };
// mixed small aggregate: keep as direct aggregate value.
// Let LLVM's C ABI lowering split/register-assign correctly.
return ParamLowering::Direct(t);
}

// non-aggregate: direct
Expand Down Expand Up @@ -246,9 +247,9 @@ fn classify_ret<'ctx>(
}
}

// mixed small aggregate ret: safest sret
let align = td.get_abi_alignment(&t) as u32;
return RetLowering::SRet { ty: t.as_any_type_enum(), align };
// mixed small aggregate ret: keep direct aggregate value.
// Let LLVM's C ABI lowering pick mixed INTEGER/SSE return registers.
return RetLowering::Direct(t);
}

RetLowering::Direct(t)
Expand All @@ -261,6 +262,7 @@ pub fn lower_extern_c<'ctx>(
struct_types: &HashMap<String, inkwell::types::StructType<'ctx>>,
) -> LoweredExtern<'ctx> {
let llvm_name = ext.symbol.as_deref().unwrap_or(ext.name.as_str()).to_string();
let info_llvm_name = llvm_name.clone();

// wave types -> layout types
let wave_param_layout: Vec<BasicTypeEnum<'ctx>> = ext.params.iter()
Expand Down Expand Up @@ -311,6 +313,7 @@ pub fn lower_extern_c<'ctx>(
llvm_name,
fn_type,
info: ExternCInfo {
llvm_name: info_llvm_name,
wave_ret: ext.return_type.clone(),
ret,
params,
Expand Down
25 changes: 24 additions & 1 deletion llvm/src/codegen/address.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,10 @@ fn addr_and_ty<'ctx>(
);

if !slot_ty.is_pointer_type() {
panic!("Cannot deref non-pointer lvalue: {:?}", inner);
// Legacy compatibility:
// allow redundant `deref` on already-addressable lvalues
// like `deref q.rear` and `deref visited[x]`.
return (slot_ptr, slot_ty);
}

let pv = load_ptr_from_slot(context, builder, slot_ptr, "deref_target");
Expand Down Expand Up @@ -383,3 +386,23 @@ pub fn generate_address_ir<'ctx>(
)
.0
}

pub fn generate_address_and_type_ir<'ctx>(
context: &'ctx Context,
builder: &'ctx Builder<'ctx>,
expr: &Expression,
variables: &mut HashMap<String, VariableInfo<'ctx>>,
module: &'ctx Module<'ctx>,
struct_types: &HashMap<String, StructType<'ctx>>,
struct_field_indices: &HashMap<String, HashMap<String, u32>>,
) -> (PointerValue<'ctx>, BasicTypeEnum<'ctx>) {
addr_and_ty(
context,
builder,
expr,
variables,
module,
struct_types,
struct_field_indices,
)
}
38 changes: 31 additions & 7 deletions llvm/src/codegen/ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ use crate::codegen::abi_c::{
ExternCInfo, lower_extern_c, apply_extern_c_attrs,
};

fn is_implicit_i32_main(name: &str, return_type: &Option<WaveType>) -> bool {
name == "main" && matches!(return_type, None | Some(WaveType::Void))
}

fn is_supported_extern_abi(abi: &str) -> bool {
abi.eq_ignore_ascii_case("c")
}

pub unsafe fn generate_ir(ast_nodes: &[ASTNode], opt_flag: &str) -> String {
let context: &'static Context = Box::leak(Box::new(Context::create()));
let module: &'static _ = Box::leak(Box::new(context.create_module("main")));
Expand Down Expand Up @@ -189,11 +197,15 @@ pub unsafe fn generate_ir(ast_nodes: &[ASTNode], opt_flag: &str) -> String {
.map(|p| wave_type_to_llvm_type(context, &p.param_type, &struct_types, TypeFlavor::AbiC).into())
.collect();

let fn_type = match return_type {
None | Some(WaveType::Void) => context.void_type().fn_type(&param_types, false),
Some(wave_ret_ty) => {
let llvm_ret_type = wave_type_to_llvm_type(context, wave_ret_ty, &struct_types, TypeFlavor::AbiC);
llvm_ret_type.fn_type(&param_types, false)
let fn_type = if is_implicit_i32_main(name, return_type) {
context.i32_type().fn_type(&param_types, false)
} else {
match return_type {
None | Some(WaveType::Void) => context.void_type().fn_type(&param_types, false),
Some(wave_ret_ty) => {
let llvm_ret_type = wave_type_to_llvm_type(context, wave_ret_ty, &struct_types, TypeFlavor::AbiC);
llvm_ret_type.fn_type(&param_types, false)
}
}
};

Expand All @@ -202,6 +214,13 @@ pub unsafe fn generate_ir(ast_nodes: &[ASTNode], opt_flag: &str) -> String {
}

for ext in &extern_functions {
if !is_supported_extern_abi(&ext.abi) {
panic!(
"unsupported extern ABI '{}' for function '{}': only extern(c) is currently supported",
ext.abi, ext.name
);
}

let lowered = lower_extern_c(context, td, ext, &struct_types);

let f = module.add_function(&lowered.llvm_name, lowered.fn_type, None);
Expand Down Expand Up @@ -263,13 +282,17 @@ pub unsafe fn generate_ir(ast_nodes: &[ASTNode], opt_flag: &str) -> String {

let current_block = builder.get_insert_block().unwrap();
if current_block.get_terminator().is_none() {
let implicit_i32_main = is_implicit_i32_main(&func_node.name, &func_node.return_type);
let is_void_like = match &func_node.return_type {
None => true,
Some(WaveType::Void) => true,
_ => false,
};

if is_void_like {
if implicit_i32_main {
let zero = context.i32_type().const_zero();
builder.build_return(Some(&zero)).unwrap();
} else if is_void_like {
builder.build_return(None).unwrap();
} else {
panic!("Non-void function '{}' is missing a return statement", func_node.name);
Expand All @@ -295,6 +318,7 @@ fn pipeline_from_opt_flag(opt_flag: &str) -> &'static str {
"-O3" => "default<O3>",
"-Os" => "default<Os>",
"-Oz" => "default<Oz>",
"-Ofast" => "default<O3>",
other => panic!("unknown opt flag for LLVM passes: {}", other),
}
}
Expand Down Expand Up @@ -522,4 +546,4 @@ fn add_enum_consts_to_globals(

next += 1;
}
}
}
2 changes: 1 addition & 1 deletion llvm/src/codegen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub mod legacy;
pub mod plan;
pub mod abi_c;

pub use address::generate_address_ir;
pub use address::{generate_address_and_type_ir, generate_address_ir};
pub use format::{wave_format_to_c, wave_format_to_scanf};
pub use ir::generate_ir;
pub use types::{wave_type_to_llvm_type, VariableInfo};
Expand Down
43 changes: 42 additions & 1 deletion llvm/src/expression/rvalue/assign.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,27 @@ fn wave_to_basic<'ctx, 'a>(env: &ExprGenEnv<'ctx, 'a>, wt: &WaveType) -> BasicTy
wave_type_to_llvm_type(env.context, wt, env.struct_types, TypeFlavor::Value)
}

fn basic_to_wave<'ctx, 'a>(env: &ExprGenEnv<'ctx, 'a>, bt: BasicTypeEnum<'ctx>) -> Option<WaveType> {
match bt {
BasicTypeEnum::IntType(it) => {
let bw = it.get_bit_width() as u16;
if bw == 1 {
Some(WaveType::Bool)
} else {
Some(WaveType::Int(bw))
}
}
BasicTypeEnum::FloatType(ft) => Some(WaveType::Float(ft.get_bit_width() as u16)),
BasicTypeEnum::PointerType(_) => Some(WaveType::Pointer(Box::new(WaveType::Byte))),
BasicTypeEnum::ArrayType(at) => {
let elem = basic_to_wave(env, at.get_element_type())?;
Some(WaveType::Array(Box::new(elem), at.len()))
}
BasicTypeEnum::StructType(st) => Some(WaveType::Struct(resolve_struct_key(env, st))),
_ => None,
}
}

fn wave_type_of_lvalue<'ctx, 'a>(env: &ExprGenEnv<'ctx, 'a>, e: &Expression) -> Option<WaveType> {
match e {
Expression::Variable(name) => env.variables.get(name).map(|vi| vi.ty.clone()),
Expand All @@ -54,7 +75,7 @@ fn wave_type_of_lvalue<'ctx, 'a>(env: &ExprGenEnv<'ctx, 'a>, e: &Expression) ->
match inner_ty {
WaveType::Pointer(t) => Some(*t),
WaveType::String => Some(WaveType::Byte),
_ => None,
other => Some(other),
}
}
Expression::IndexAccess { target, .. } => {
Expand All @@ -66,6 +87,26 @@ fn wave_type_of_lvalue<'ctx, 'a>(env: &ExprGenEnv<'ctx, 'a>, e: &Expression) ->
_ => None,
}
}
Expression::FieldAccess { object, field } => {
let object_ty = wave_type_of_lvalue(env, object)?;
let struct_name = match object_ty {
WaveType::Struct(name) => name,
WaveType::Pointer(inner) => match *inner {
WaveType::Struct(name) => name,
_ => return None,
},
_ => return None,
};

let st = *env.struct_types.get(&struct_name)?;
let field_index = env
.struct_field_indices
.get(&struct_name)
.and_then(|m| m.get(field))
.copied()?;
let field_bt = st.get_field_type_at_index(field_index)?;
basic_to_wave(env, field_bt)
}
_ => None,
}
}
Expand Down
2 changes: 1 addition & 1 deletion llvm/src/expression/rvalue/calls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ pub(crate) fn gen_function_call<'ctx, 'a>(

if let Some(info) = env.extern_c_info.get(name) {
let function = env.module
.get_function(name)
.get_function(&info.llvm_name)
.unwrap_or_else(|| panic!("Extern function '{}' not found in module (symbol alias?)", name));

if args.len() != info.params.len() {
Expand Down
Loading