@@ -236,7 +236,7 @@ struct structType {
236236
237237std::vector<functionID*> functionIDs = std::vector<functionID*>();
238238std::unordered_map<std::string, structType*> structDefinitions = std::unordered_map<std::string, structType*>();
239- std::string currentStructName = " " ;
239+ std::stack<std:: string> currentStructName = std::stack<std::string>() ;
240240
241241structType* getStructTypeFromLLVMType (Type*& t)
242242{
@@ -409,8 +409,18 @@ Type* getLLVMTypeFromString(std::string typeName, int pointerLevelOffset, tokenP
409409 // Otherwise, look in struct definitions
410410 else if (structDefinitions.find (typeName) != structDefinitions.end ()) {
411411 // If the struct body hasn't been generated yet, generate it
412- if (structDefinitions[typeName]->structVal == nullptr )
413- aType = (Type*)(structDefinitions[typeName]->sourceNode ->*(structDefinitions[typeName]->sourceNode ->codegen ))(pass);
412+ if (structDefinitions[typeName]->structVal == nullptr ) {
413+ if (currentStructName.size () == 0 || currentStructName.top () != typeName)
414+ aType = (Type*)(structDefinitions[typeName]->sourceNode ->*(structDefinitions[typeName]->sourceNode ->codegen ))(pass);
415+ // If this type is inside of a struct and the type *is* the struct,
416+ // throw an error (nested structs aren't allowed)
417+ else {
418+ wasDefined = false ;
419+ printTokenError (token, " Cannot nest struct in self" );
420+ return nullptr ;
421+ // exit(1);
422+ }
423+ }
414424 else
415425 aType = (Type*)(structDefinitions[typeName]->structVal );
416426 }
@@ -949,7 +959,47 @@ void* ASTNode::generateReturn(int pass)
949959 return nullptr ;
950960 }
951961 Value* RetVal = (Value*)(exprNode->*(exprNode->codegen ))(pass);
952- Builder->CreateRet (RetVal);
962+ // Check if we're returning a struct
963+ Type* returnType = Builder->GetInsertBlock ()->getParent ()->getReturnType ();
964+ if (returnType->isStructTy ()) {
965+ // For struct returns, we need to handle this specially
966+ // Option 1: If the function uses sret, copy to the sret parameter
967+ Function* currentFunc = Builder->GetInsertBlock ()->getParent ();
968+ if (currentFunc->hasStructRetAttr ()) {
969+ // Get the sret parameter (first parameter)
970+ Value* sretPtr = &*currentFunc->arg_begin ();
971+
972+ // Copy the struct value to the sret location
973+ if (RetVal->getType ()->isPointerTy ()) {
974+ // If RetVal is a pointer to struct, memcpy from it
975+ Value* structSize = ConstantInt::get (Type::getInt64Ty (*TheContext),
976+ TheModule->getDataLayout ().getTypeAllocSize (returnType));
977+
978+ // Create memcpy call
979+ Function* memcpyFunc = Intrinsic::getDeclaration (TheModule.get (),
980+ Intrinsic::memcpy, {sretPtr->getType (), RetVal->getType (), Type::getInt64Ty (*TheContext)});
981+ Builder->CreateCall (memcpyFunc, {sretPtr, RetVal, structSize, ConstantInt::get (Type::getInt1Ty (*TheContext), 0 )});
982+ }
983+ else {
984+ // If RetVal is a struct value, store it
985+ Builder->CreateStore (RetVal, sretPtr);
986+ }
987+ Builder->CreateRetVoid ();
988+ }
989+ else {
990+ // Option 2: Direct struct return (for small structs)
991+ if (RetVal->getType ()->isPointerTy ()) {
992+ // Load the struct value from the pointer
993+ RetVal = Builder->CreateLoad (returnType, RetVal, " struct_ret_load" );
994+ }
995+ Builder->CreateRet (RetVal);
996+ }
997+ }
998+ else {
999+ // Non-struct return, handle normally
1000+ Builder->CreateRet (RetVal);
1001+ }
1002+
9531003 return nullptr ;
9541004}
9551005
@@ -1973,7 +2023,6 @@ void* ASTNode::generateIf(int pass)
19732023void * ASTNode::generateStruct (int pass)
19742024{
19752025 std::string structName = token->first ;
1976- currentStructName = structName;
19772026
19782027 // Do not create a struct with the same name
19792028 if (structDefinitions.find (structName) != structDefinitions.end ()) {
@@ -1983,73 +2032,86 @@ void* ASTNode::generateStruct(int pass)
19832032 }
19842033 }
19852034
1986- // On pass 0, only declare the struct without it's body
1987- if (pass == 0 ) {
1988- structDefinitions[structName] = new structType (structName, token, this );
1989- return nullptr ;
1990- }
2035+ // // On pass 0, only declare the struct without it's body
2036+ // if (pass == 0) {
2037+ // structDefinitions[structName] = new structType(structName, token, this);
2038+ // return nullptr;
2039+ // }
19912040
2041+ currentStructName.push (structName);
2042+ uint8_t generatingType = 0 ; // Generate all member variables first (0), then functions (1)
19922043 argumentList members = argumentList ();
1993- std::vector<Type*> fieldTypes;
1994- std::vector<std::string> fieldNames;
2044+ std::vector<Type*> fieldTypes = std::vector<Type*>() ;
2045+ std::vector<std::string> fieldNames = std::vector<std::string>() ;
19952046 std::vector<functionID*> memberFunctions = std::vector<functionID*>();
1996- std::unordered_map<std::string, uint16_t > memberNameIndexes;
2047+ std::unordered_map<std::string, uint16_t > memberNameIndexes = std::unordered_map<std::string, uint16_t >() ;
19972048 uint16_t i = 0 ;
1998- for (auto & fieldNode : childNodes[0 ]->childNodes ) {
1999- // If it is a member variable declaration
2000- if (fieldNode->nodeType == Identifier_Node) {
2001- if (fieldNode->childNodes .size () == 0 ) {
2002- printTokenError (fieldNode->token , " Member declaration must have type" );
2003- exit (1 );
2004- }
2005- std::string memberName = fieldNode->token ->first ;
2006- ASTNode* typeNode = fieldNode;
2007- std::string memberType = fieldNode->childNodes [0 ]->token ->first ;
2008- int pointerLevel = 0 ;
2049+ for (; generatingType < 2 ; generatingType++)
2050+ for (auto & fieldNode : childNodes[0 ]->childNodes ) {
2051+ // If it is a member variable declaration
2052+ if (fieldNode->nodeType == Identifier_Node && generatingType == 0 && pass > 0 ) {
2053+ if (fieldNode->childNodes .size () == 0 ) {
2054+ printTokenError (fieldNode->token , " Member declaration must have type" );
2055+ exit (1 );
2056+ }
2057+ std::string memberName = fieldNode->token ->first ;
2058+ ASTNode* typeNode = fieldNode;
2059+ std::string memberType = fieldNode->childNodes [0 ]->token ->first ;
2060+ int pointerLevel = 0 ;
20092061
2010- recurseAddMemberPointer:
2011- typeNode = typeNode->childNodes [0 ];
2062+ recurseAddMemberPointer:
2063+ typeNode = typeNode->childNodes [0 ];
20122064
2013- if (typeNode->token ->first == " *" ) {
2014- pointerLevel++;
2015- goto recurseAddMemberPointer;
2016- }
2065+ if (typeNode->token ->first == " *" ) {
2066+ pointerLevel++;
2067+ goto recurseAddMemberPointer;
2068+ }
20172069
2018- memberType = typeNode->token ->first ;
2070+ memberType = typeNode->token ->first ;
20192071
2020- bool wasDefined = true ;
2021- Type* fieldType = getLLVMTypeFromString (memberType, 0 , typeNode->token , wasDefined, pass);
2022- // if (wasDefined == false)
2023- // return nullptr;
2024- for (int i = 0 ; i < pointerLevel; i++)
2025- fieldType = fieldType->getPointerTo ();
2026- fieldTypes.push_back (fieldType);
2027- fieldNames.push_back (memberName);
2028- memberNameIndexes[memberName] = i;
2072+ bool wasDefined = true ;
2073+ Type* fieldType = getLLVMTypeFromString (memberType, 0 , typeNode->token , wasDefined, pass);
2074+ // if (wasDefined == false)
2075+ // return nullptr;
2076+ for (int i = 0 ; i < pointerLevel; i++)
2077+ fieldType = fieldType->getPointerTo ();
2078+ fieldTypes.push_back (fieldType);
2079+ fieldNames.push_back (memberName);
2080+ memberNameIndexes[memberName] = i;
20292081
2030- members.push_back (argType (memberType, getASTNodeTypeFromString (memberType), pointerLevel));
2031- i++;
2032- }
2033- // Else it is a member function definition
2034- else if (fieldNode->nodeType == Compiler_Define_Function) {
2035- if (pass == 0 )
2036- continue ;
2037- // Generate function
2038- Function* memberFunction = (Function*)(fieldNode->*(fieldNode->codegen ))(pass);
2039- // Get pointer to generated function from global
2040- functionID* fnID = getFunctionIDFromFunctionPointer (functionIDs, memberFunction);
2041- memberFunctions.push_back (fnID);
2082+ members.push_back (argType (memberType, getASTNodeTypeFromString (memberType), pointerLevel));
2083+ i++;
2084+ }
2085+ // Else it is a member function definition
2086+ else if (fieldNode->nodeType == Compiler_Define_Function && generatingType == 1 && pass > 1 ) {
2087+ // Generate function
2088+ Function* memberFunction = (Function*)(fieldNode->*(fieldNode->codegen ))(pass);
2089+ // Get pointer to generated function from global
2090+ functionID* fnID = getFunctionIDFromFunctionPointer (functionIDs, memberFunction);
2091+ memberFunctions.push_back (fnID);
2092+ }
20422093 }
2094+ // pass 0 declare struct name,
2095+ // pass 1 struct body and function prototypes,
2096+ // pass 2 function bodies
2097+ if (pass > 1 ) {
2098+ currentStructName.pop ();
2099+ structDefinitions[structName]->members = members;
2100+ structDefinitions[structName]->memberFunctions = memberFunctions;
2101+ structDefinitions[structName]->memberNameIndexes = memberNameIndexes;
2102+ structDefinitions[structName]->structVal ->setBody (fieldTypes, false );
2103+ return nullptr ;
20432104 }
20442105
2045- // // Make nothing node to not be regenerated
2046- // nodeType = Nothing_Node;
2106+ // Make node not be regenerated
2107+ // currentNodeDoneGenerating = true;
2108+ // nodeType = Fully_Defined;
2109+ // codegen = nullptr;
20472110
2048- StructType* structTy = StructType::create (*TheContext, fieldTypes, structName);
2111+ StructType* structTy = StructType::create (*TheContext, fieldTypes, " struct. " + structName);
20492112
2050- // structTy->setBody(fieldTypes, false);
20512113
2052- currentStructName = " " ;
2114+ currentStructName. pop () ;
20532115 structDefinitions[structName] = new structType (structName, token, structTy, members, memberFunctions, memberNameIndexes);
20542116
20552117 return structTy;
@@ -2184,9 +2246,9 @@ void* ASTNode::generatePrototype(int pass)
21842246 fnName = fnName + " ." + tokenAsString (childNodes[0 ]->childNodes [0 ]->token ->second );
21852247 mangledName = fnName + " ." + tokenAsString (childNodes[0 ]->childNodes [0 ]->token ->second );
21862248 }
2187- else if (currentStructName != " " ) {
2188- fnName = currentStructName + " ." + fnName;
2189- mangledName = currentStructName + " ." + mangledName;
2249+ else if (currentStructName. size () != 0 ) {
2250+ fnName = currentStructName. top () + " ." + fnName;
2251+ mangledName = currentStructName. top () + " ." + mangledName;
21902252 isStruct = true ;
21912253 }
21922254 // else
@@ -2217,8 +2279,8 @@ void* ASTNode::generatePrototype(int pass)
22172279
22182280 // Get function arguments
22192281 // If it is a struct, first add a "this" argument like: (this : ref structName, ...)
2220- if (isStruct ) {
2221- std::string typeStr = currentStructName;
2282+ if (currentStructName. size () > 0 ) {
2283+ std::string typeStr = currentStructName. top () ;
22222284 bool isReference = true ;
22232285 int pointerLevel = 1 ;
22242286
@@ -2373,8 +2435,6 @@ void* ASTNode::generatePrototype(int pass)
23732435// Function*
23742436void * ASTNode::generateFunction (int pass)
23752437{
2376- if (pass == 0 )
2377- return nullptr ;
23782438
23792439 // First, check for an existing function from a previous declaration.
23802440 // Function* theFunction = TheModule->getFunction(token->first);
@@ -2383,13 +2443,13 @@ void* ASTNode::generateFunction(int pass)
23832443 std::string functionName = token->first ;
23842444 bool isStruct = false ;
23852445
2386- if (currentStructName != " " ) {
2387- functionName = currentStructName + " ." + functionName;
2446+ if (currentStructName. size () > 0 ) {
2447+ functionName = currentStructName. top () + " ." + functionName;
23882448 isStruct = true ;
23892449 }
23902450
23912451 // if (!theFunctionID)
2392- theFunction = (Function*)generatePrototype ();
2452+ theFunction = (Function*)this -> generatePrototype (pass );
23932453 // else
23942454 // theFunction = theFunctionID->fnValue;
23952455
@@ -2405,7 +2465,7 @@ void* ASTNode::generateFunction(int pass)
24052465 exit (1 );
24062466 }
24072467
2408- if (pass = = 1 )
2468+ if (pass < = 1 )
24092469 return theFunction;
24102470
24112471 theFunctionID = getFunctionIDFromFunctionPointer (functionIDs, theFunction);
0 commit comments