Introduction

Single-cell RNA sequencing (scRNA-seq) technology reveals diverse cellular function and characteristics, making it crucial for in-depth research in fields such as developmental biology, cancer, immunology, and neuroscience1,2,3.

Standard processes for scRNA-seq data analysis can be divided into upstream and downstream analysis4,5,6. After single-cell isolation and library construction, raw sequencing reads are transformed into gene expression profiles using standard processing pipelines, such as CellRanger7. The upstream process involves preprocessing and visualising scRNA-seq data8, including normalisation and scaling, to eliminate technical noise and reduce data dimensionality9,10. Cells are then clustered based on similarities in gene expression patterns11. Downstream analysis, which includes cell-, gene-, and pathway-level investigations, plays a critical role in interpreting biological mechanisms. Accurate cell-type annotation is a key aspect of downstream analysis, essential for understanding the cellular composition of tissues, identifying novel cell subpopulations, and discovering potential biomarkers. At the cellular level, cell-type annotation classifies cells into specific types or states, providing insights into tissue composition; cell trajectory analysis and cell-to-cell communication are also important analytical tools. Gene-level analysis focuses on describing differentially expressed genes, constructing gene regulatory networks, and conducting gene enrichment analysis. At the pathway level, enrichment tools are typically used to identify pathways enriched in different cell types. Overall, cell-type annotation links upstream and downstream analysis and plays a vital role in understanding cell identity and function12,13.
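As a concrete illustration of the normalisation and scaling steps mentioned above, the following minimal numpy sketch implements a typical library-size normalisation, log transform, and per-gene scaling pipeline (the function name and the target sum of 10,000 are illustrative conventions, not details taken from this study):

```python
import numpy as np

def preprocess_counts(counts, target_sum=1e4):
    """Normalise a cells-x-genes count matrix, log-transform, and scale.

    counts: (n_cells, n_genes) array of raw UMI counts.
    Mirrors the usual normalise -> log1p -> z-score pipeline.
    """
    # Library-size normalisation: each cell is rescaled to `target_sum` counts.
    lib_size = counts.sum(axis=1, keepdims=True)
    norm = counts / np.maximum(lib_size, 1) * target_sum
    # Variance-stabilising log transform.
    logged = np.log1p(norm)
    # Per-gene z-scoring ("scaling") so genes are comparable across cells.
    mu = logged.mean(axis=0)
    sd = logged.std(axis=0)
    return (logged - mu) / np.maximum(sd, 1e-8)

rng = np.random.default_rng(0)
X = rng.poisson(2.0, size=(50, 20)).astype(float)  # toy count matrix
Z = preprocess_counts(X)
```

Toolkits such as Scanpy expose the same steps as ready-made functions; the sketch only makes the arithmetic explicit.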

Currently, cell-type annotation methods are mainly divided into three types:

  • Traditional clustering and annotation, which involves finding marker genes for known cells or using expert prior knowledge for annotation14,15.

  • Cell-type annotation using supervised16, semi-supervised, or unsupervised machine learning classification methods17.

  • Deep learning methods: with the significant progress of deep learning in computer vision and natural language processing, various deep learning models have been applied to the analysis of scRNA-seq data and have achieved promising results18,19.

In the field of cell-type annotation, the first approach, clustering followed by annotation, usually relies on the characteristic marker genes for each cell type identified in the published literature20,21. However, the selection of marker genes largely depends on a researcher's prior knowledge, leading to potential biases and inaccuracies in the annotation process. Additionally, this manual labelling method is time-consuming and laborious. Most cell types are characterised by multiple marker genes22; without suitable methods for integrating the expression information of these genes, accurately assigning each cluster to a specific cell type may be challenging23.

In recent years, the application of deep learning technologies to the analysis of scRNA-seq data has gradually increased and has demonstrated outstanding analytical performance24,25,26,27. The methods used include autoencoders28,29,30, graph neural networks31,32, and deep learning models based on the self-attention mechanism33,34. Notably, the Transformer model, which avoids dimensionality reduction35,36 and introduces a powerful self-attention mechanism along with the ability to integrate long-sequence information, has made significant progress in natural language processing, as demonstrated by models such as bidirectional encoder representations from transformers (BERT)37 and generative pre-trained transformers38. However, in the traditional Transformer, the self-attention mechanism compares each element in the input with all other elements to generate attention scores, so both time and space complexity grow quadratically with sequence length. This demands substantial computational resources and a relatively long training process.

To optimise computational resource usage, improve the model's speed, and ensure accuracy and broad applicability, this study proposes a new cell-type annotation model for scRNA-seq data: the general gated axial-attention model for accurate cell-type annotation of single-cell RNA-seq data (scGAA). First, we analysed the computational bottleneck of the self-attention mechanism when processing long-sequence data39,40. In the traditional Transformer, the computational complexity and memory requirements of self-attention increase quadratically with input sequence length, making it computationally expensive to apply the model directly to long sequences such as gene expression profiles in scRNA-seq data. To address this, the axial self-attention mechanism was introduced, which reduces computational demands by performing self-attention calculations separately in the horizontal and vertical directions. Second, the scGAA model incorporates six novel gating units designed to enhance its adaptability and performance across scRNA-seq datasets of varying sizes. These gating units dynamically adjust the query, key, and value matrices within the model's horizontal and vertical attention mechanisms, allowing it to flexibly focus on relevant information based on the specific characteristics of the dataset. This adaptability enables the model to prioritise genes critical for cell type or state identification while effectively filtering out background noise.
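The gating units are described here only at a high level; one plausible minimal form, sketched below entirely under our own assumptions, is a sigmoid-activated learnable parameter that rescales each projected query, key, or value matrix (scGAA's actual gating formulation may differ):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gated_projection(X, W, gate_param):
    """A gated linear projection: the gate rescales the projection output.

    One possible reading of the six gating units (one per Q/K/V matrix in
    each of the two axial directions): a learnable scalar, squashed through
    a sigmoid, controls how much of the projected signal passes through.
    This is an illustrative assumption, not the paper's exact design.
    """
    g = sigmoid(gate_param)  # in (0, 1): near 0 suppresses, near 1 passes
    return g * (X @ W)

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 8))   # toy gene-set features
W = rng.normal(size=(8, 8))   # toy projection matrix
Q_open = gated_projection(X, W, gate_param=10.0)   # gate ~ 1: full signal
Q_shut = gated_projection(X, W, gate_param=-10.0)  # gate ~ 0: suppressed
```

During training, `gate_param` would be learned jointly with `W`, letting the model attenuate uninformative projections per dataset.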

The proposed method focuses on exploring gene interactions and uses these relationships to annotate cell types. Our main contributions are as follows:

  • The scGAA model is the first to introduce a gating mechanism into the field of cell-type annotation. The gating mechanism controls information flow, effectively extracts key features, and significantly improves the prediction performance of the model.

  • The scGAA model adopts a multi-angle learning strategy by integrating the horizontal and vertical attention mechanisms. This effectively captures gene interactions while optimizing computational efficiency and reducing complexity.

  • The scGAA model combines axial attention with gating structures, and does not require any batch information, enabling scGAA to annotate scRNA-seq data without considering batch differences while retaining heterogeneity between cells.

  • The scGAA model also balances the dataset to avoid the problem of weak model generalization ability due to imbalanced data types, further enhancing the robustness of the model.
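The balancing step in the last point can be sketched as simple oversampling of minority cell types; this is an illustrative stand-in, as the exact balancing procedure is not specified at this point in the text:

```python
import numpy as np

def balance_by_oversampling(X, y, seed=0):
    """Equalise class counts by resampling minority classes with replacement.

    X: (n_cells, n_features) matrix; y: integer class labels per cell.
    A simple stand-in for the dataset-balancing step described above.
    """
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(y, return_counts=True)
    target = counts.max()
    idx = []
    for c in classes:
        members = np.flatnonzero(y == c)
        # Keep all original cells, then top up to the majority-class count.
        extra = rng.choice(members, size=target - members.size, replace=True)
        idx.extend(members)
        idx.extend(extra)
    idx = np.array(idx)
    return X[idx], y[idx]

X = np.arange(12).reshape(6, 2)
y = np.array([0, 0, 0, 0, 1, 1])   # imbalanced: 4 vs 2 cells
Xb, yb = balance_by_oversampling(X, y)
```

After balancing, every cell type contributes equally to the training loss, which is the property that protects generalisation on rare types.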

Results

Overview of scGAA

The scGAA (Fig. 1) is a cell-type annotation model based on the gated axial-attention mechanism, trained through supervised learning with the transformer model as its foundation. The model employs an axial self-attention mechanism and gating units to capture the interactions between any two genes, extracting global information without using positional information. By learning the interactions between genes and their expression patterns, the model maps each cell to its cell type.

Fig. 1

scGAA is a cell type annotation model based on gated axial-attention. After obtaining the raw data using traditional sequencing methods, the data is preprocessed to obtain the gene expression matrix. Subsequently, random masking is applied to mitigate the impact of potential “dropout” events in the scRNA-seq data. Next, all genes are randomly grouped into different gene sets, and these gene sets are embedded for each cell. A fully connected weight matrix is generated and used as input to the attention mechanism to compute the attention score matrix, which is then horizontally and vertically segmented. The segmented matrix is fed into a gated axial-attention mechanism for model training. Finally, the model’s prediction results are used in conjunction with a predefined threshold to determine the cell types.

Accuracy and robustness of scGAA models

Initially, we aimed to compare three models (scBERT41, TOSICA42, and scGAA) on their performance in predicting cell types in scRNA-seq data. To validate the accuracy and robustness of the scGAA model, we employed datasets from six different tissues (kidney, pancreas, liver, brain, lung, and heart) that together included over 130 cell types (Supplementary Table 1). We compared scGAA with the most recently released cell-type annotation models in this field (TOSICA and scBERT), computing the accuracy and macro F1 score of each model as evaluation metrics. The scGAA model demonstrated the highest accuracy on every dataset, whereas TOSICA and scBERT had their respective strengths across different datasets. We observed that all models had relatively low accuracy on the liver and kidney datasets, likely because these datasets contain more cell types but fewer cells. For instance, the kidney dataset contained 26 cell types but only 5,738 cells in total. Lower cell counts may have prevented the models from sufficiently learning the features of each cell type during training, indicating poor prediction of cell types in small, complex datasets. Even in this scenario, the scGAA model demonstrated the highest accuracy, 92.33%, as shown in Fig. 2a. We also compared the macro F1 scores of the models across the six datasets; the scGAA model's macro F1 score was substantially superior to that of the other two models (Fig. 2b). This suggests that the scGAA model has better generalisation capability, adapts to datasets of different tissue types, and provides reliable and accurate cell-type predictions. The specific results are illustrated in Fig. 2c. Heat maps of the accuracy of scGAA cell-type annotation on the other datasets are shown in Supplementary Figs. 1a, 2a, 3a, 4a, 5a, and 6a. The results are presented in Supplementary Table 2.
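The two evaluation metrics used above can be reproduced in a few lines of numpy; the function below is our own illustrative implementation, not code from the study:

```python
import numpy as np

def accuracy_and_macro_f1(y_true, y_pred):
    """Compute overall accuracy and the macro-averaged F1 score.

    Macro F1 averages per-class F1 with equal weight, so rare cell types
    count as much as abundant ones -- the property that makes it informative
    for imbalanced scRNA-seq datasets.
    """
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    acc = float((y_true == y_pred).mean())
    f1s = []
    for c in np.unique(y_true):
        tp = np.sum((y_pred == c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))
        fn = np.sum((y_pred != c) & (y_true == c))
        prec = tp / (tp + fp) if tp + fp else 0.0
        rec = tp / (tp + fn) if tp + fn else 0.0
        f1s.append(2 * prec * rec / (prec + rec) if prec + rec else 0.0)
    return acc, float(np.mean(f1s))

acc, f1 = accuracy_and_macro_f1([0, 0, 1, 1, 2], [0, 0, 1, 2, 2])
```

In this toy example the accuracy is 0.8 while the macro F1 is 7/9, showing how the two metrics diverge when errors concentrate in small classes.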

This study utilized the scBERT, TOSICA, and scGAA models to predict various cell types in the heart dataset; the results are presented in Fig. 2e. We calculated the average and standard deviation of the prediction accuracy for each cell type across the three models, as shown in Table 1. The scGAA model had the highest average prediction accuracy, reaching 0.971; the average prediction accuracies of the scBERT and TOSICA models were 0.959 and 0.951, respectively, both lower than that of the scGAA model. Concurrently, the scGAA model had the lowest standard deviation, 0.025, suggesting that its predictions are more stable and consistent. Notably, it demonstrated superior ability to recognize and distinguish cell types that are typically difficult to differentiate. The scGAA cell-type annotation results for the heart dataset are shown in Fig. 2f; annotation results for the other datasets are presented in Supplementary Figs. 1b, 2b, 3b, 4b, 5b, and 6b.

Table 1 The average and standard deviation of prediction accuracy for each cell type across three models.

Concurrently, we employed scGAA to perform a clustered heatmap analysis of the correlation matrix derived from its cell-type predictions on the liver dataset (Fig. 2d). This reveals the correlation between the model's predictions and the actual results and serves as a measure of prediction accuracy. Notably, many subtypes of similar major cell types were grouped together and exhibited relatively strong correlation. For instance, there is a certain correlation between MHCII-high CD14 monocytes and CD14 monocytes. This suggests that the scGAA model perceives these two cell types as sharing common attributes or similarities at the gene expression level, perhaps because both are monocyte subtypes or because they have overlapping biological functions. Subtypes from the same larger cell-type category, such as NK cells and CD56 NK cells, were clustered together and also displayed a high level of correlation. This indicates that the scGAA model can recognize the hierarchical organization of cells and aggregate closely related cell subtypes. Additionally, this reflects the advantage of the scGAA model's self-attention mechanism, which captures both long- and short-range dependencies between cells, thereby enhancing its resolution capability.
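The text does not spell out how the correlation matrix is constructed; one plausible reconstruction, sketched below under our own assumptions, correlates the rows of a true-versus-predicted contingency matrix so that cell types the model confuses in similar ways become highly correlated:

```python
import numpy as np

def prediction_correlation(y_true, y_pred, n_types):
    """Correlation matrix between cell types based on prediction patterns.

    Builds a contingency matrix (true type x predicted type), then correlates
    its rows: types whose cells are distributed over predictions in similar
    ways end up highly correlated and would cluster together in a heatmap.
    This reconstruction of the analysis is an assumption, not the paper's code.
    """
    cont = np.zeros((n_types, n_types))
    for t, p in zip(y_true, y_pred):
        cont[t, p] += 1
    return np.corrcoef(cont)

# Toy labels: types 0 and 1 are mutually confused, type 2 is distinct.
y_true = [0, 0, 0, 1, 1, 1, 2, 2, 2]
y_pred = [0, 0, 1, 1, 1, 0, 2, 2, 2]
C = prediction_correlation(y_true, y_pred, 3)
```

In the toy example, the confused pair (types 0 and 1) shows a higher mutual correlation than either does with the cleanly separated type 2.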

Fig. 2

Comparison of cell type annotation capabilities of different models. (a) Histogram of average accuracy of each cell type annotation model in different data sets. (b) Macro F1 Score heat map of each cell type annotation model in different data sets. (c) In different data sets, the box plot of the accuracy of predicting each cell type by each cell type annotation model. (d) Correlation matrix clustering heat map of predicted results and true results of scGAA model on liver data set. (e) Heatmap of prediction accuracy for all cell types in the heart dataset across different models. (f) Cell type annotation results of the scGAA model for the heart dataset.

The scGAA can discover new cell types

In this study, a preset threshold (95%) was adopted for cell-type prediction. Only when the probability of a cell being predicted as a specific type exceeds this threshold does the model classify the cell as that type; if the prediction probability is below the threshold, the cell is annotated as “unknown”. We simulated the loss of key cell types using the pancreas (GSE148073) and heart (GSE216019) datasets and tested the models' ability to handle them correctly. Additionally, to further validate the functionality of the gating structure in the scGAA model, ablation experiments were conducted, comparing the scGAA-axi (scGAA model without gating units), scGAA, scBERT, and TOSICA models. To evaluate performance, we calculated confusion matrices for each model, comparing predicted cell types to actual cell types, and visualized these matrices as heatmaps, where the x-axis represents the predicted cell type and the y-axis the actual cell type.
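The thresholding rule described above can be sketched as follows; the 95% cutoff matches the study's preset, while the function name and the softmax-over-logits formulation are our assumptions:

```python
import numpy as np

def annotate_with_threshold(logits, type_names, threshold=0.95):
    """Assign each cell its most probable type, or "unknown" below threshold.

    logits: (n_cells, n_types) array of raw model scores. The 95% threshold
    mirrors the preset used in the study; cells the model is unsure about
    fall through to "unknown", which is how removed cell types surface.
    """
    # Softmax over types, with max-subtraction for numerical stability.
    z = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
    best = probs.argmax(axis=1)
    labels = [type_names[i] if probs[c, i] >= threshold else "unknown"
              for c, i in enumerate(best)]
    return labels

logits = np.array([[8.0, 0.0, 0.0],   # confident -> named type
                   [1.0, 0.8, 0.9]])  # ambiguous -> "unknown"
labels = annotate_with_threshold(logits, ["alpha", "beta", "delta"])
```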

Fig. 3

scGAA demonstrates superior capability in identifying novel cell types. In the pancreas dataset, prediction accuracy for each cell type after deleting alpha cells, with analogous ablation experiments. Heatmaps depict the TOSICA model (a), the scBERT model (b), the scGAA-axi model (c), and the scGAA model (d). The x-axis represents the predicted cell types, and the y-axis the actual cell types.

In the pancreatic dataset, we simulated the loss of alpha cells by deleting them and tested the models' performance. The average prediction accuracy for non-alpha cells was 96.2%. Among cells predicted as “unknown”, the proportion of actual alpha cells was significantly higher than that of other cell types. In the scGAA-axi model, the proportion of alpha cells predicted as “unknown” was 27% (Fig. 3c); in the scGAA model, it was 36% (Fig. 3d); and in the scBERT model, it was 33% (Fig. 3b). Notably, in the TOSICA model, the cell type most often predicted as “unknown” was pericytes, at approximately 74%, with alpha cells at 61% (Figs. 3a, 4a). Thus, although TOSICA's predicted proportion of alpha cells was higher than that of the other methods, the largest proportion was erroneously assigned pericytes. Similarly, in the heart dataset, we deleted ventricular cardiomyocytes and performed the same ablation experiments. The average prediction accuracy for non-ventricular-cardiomyocyte cells was 92.7%, and among cells predicted as “unknown”, the proportion of ventricular cardiomyocytes was significantly higher than that of other cell types. In the scGAA-axi model, the proportion of ventricular cardiomyocytes predicted as “unknown” was 56% (Supplementary Fig. 7c); in the scGAA model, it was 64% (Supplementary Fig. 7d); in the scBERT model, 60.3% (Supplementary Fig. 7b); and in the TOSICA model, 51% (Supplementary Fig. 7a). The results are presented in Supplementary Table 3. Thus, in the analyses of the pancreatic and cardiac datasets, the scGAA model performed best at discovering new cell types.

Fig. 4

The ability to identify new cell types. (a) Heat map comparing the ability of different models to discover new cell types. (b) In the pancreas dataset, the alpha cell type is removed to evaluate the effect of the scGAA model in predicting other cell types. Then, UMAP is used to visualize the actual results (left panel) and the model prediction results (right panel).

To illustrate the prediction results more directly, we contrasted UMAP visualizations of the expert cell-type annotations of the original data with the new cell types discovered by the scGAA model, colored by cell type (Fig. 4b, Supplementary Fig. 8b). The visualizations clearly show that the scGAA model successfully identified the removed alpha cells and ventricular cardiomyocytes as the “unknown” cell type.

Furthermore, we examined the prediction scores (Fig. 5a) for each cell type in the pancreatic dataset during scGAA model training and compared the known cell types with those predicted by scGAA (Fig. 5b). The Sankey diagram shows that, except for alpha cells, cell types were accurately assigned to their corresponding types; alpha cells had the highest proportion assigned to the new (“unknown”) cell type. These results indicate that the gating mechanism facilitates the discovery of novel cell types not present in the original reference dataset.

Fig. 5

Attention scores and predicted cell distribution by cell type in the pancreas dataset. (a) The attention scores of scGAA for predicting different cell types in the pancreas dataset. (b) Sankey plot of flow between scGAA model predicted cell types and true cell types.

Interpretability analysis of the scGAA model

We conducted an in-depth analysis of gene attention scores within the scGAA model to identify the key genes contributing to cell-type prediction. First, using the pancreatic dataset as an example, we analysed the expression characteristics of different genes in pancreatic cells43. The attention score of each gene was calculated to reflect its importance in cell function. Attention scores were sorted from high to low, and the top five genes were selected for an expression heat map to demonstrate their importance (Fig. 6a,b). For instance, in acinar cells, the CELA3A gene had a higher attention score and, correspondingly, a higher expression level. Genes such as GCG, ARX, and CRYBA2 in alpha cells, the insulin-coding gene INS in beta cells, the transcription factor NKX6-1, which is closely related to the formation and function of beta cells, and the COL1A2 and COL6A2 genes in endothelial cells (Fig. 6d) also show higher expression levels and attention scores in their respective cell types44,45. These genes with high attention scores play a key role in identifying the biological characteristics of different cell types, showing a positive correlation between attention scores and expression levels in their respective cell types. This indicates that the scGAA model can effectively identify genes with higher expression levels in cells and assign them higher attention scores, reflecting their importance in cell function.

Second, we focused on genes that exhibited high attention scores in various cell types to further explore their importance in cell functions and related biological processes46. Using beta cells as an example, we utilised Gene Ontology (GO) and Kyoto Encyclopedia of Genes and Genomes (KEGG) pathway analysis to categorise, annotate, and functionally analyse these genes47,48. Enrichment analysis was performed on the top 100 high-scoring genes in each cell type from three aspects: biological processes, molecular functions, and cellular components. The results revealed that these genes play substantial roles in specific biological processes and cellular components, and the significantly enriched pathways are related to the corresponding cell types. For example (Fig. 6c), in beta cells, which are responsible for insulin production and secretion, biological processes related to insulin secretion and transport (“insulin processing” and “insulin secretion”) have significant scores and low P values in the enrichment analysis. This indicates that genes with higher attention scores are directly related to the synthesis, maturation, and secretion of insulin. In terms of cellular components, the synthesis and transport of insulin rely on transport vesicles and the endoplasmic reticulum lumen.
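The over-representation test underlying such enrichment analyses is typically a hypergeometric test; the sketch below shows the generic calculation for one pathway (the gene counts are made up for illustration, and real GO/KEGG tools apply multiple-testing corrections on top of this):

```python
from scipy.stats import hypergeom

def enrichment_pvalue(hits, top_k, pathway_size, n_genes):
    """Hypergeometric over-representation p-value for one pathway.

    P(X >= hits) when drawing `top_k` high-attention genes from `n_genes`
    total, of which `pathway_size` belong to the pathway. A generic stand-in
    for the GO/KEGG enrichment tests described above.
    """
    # sf(hits - 1) gives the upper tail P(X >= hits).
    return hypergeom.sf(hits - 1, n_genes, pathway_size, top_k)

# Say 12 of the top-100 genes fall in a 200-gene pathway, out of 20,000 genes;
# the expected overlap by chance is only 100 * 200 / 20000 = 1 gene.
p = enrichment_pvalue(hits=12, top_k=100, pathway_size=200, n_genes=20000)
```

A tiny p-value here means the overlap is far larger than chance would produce, which is what the "significant scores and low P values" above refer to.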

By analyzing the enrichment results of genes with higher scores in beta cells, we found that these genes were closely related to the function and structure of beta cells. Therefore, these higher-scoring genes play an important role in the annotation of cell types, indicating that they are critical for understanding the specific functional and structural properties of cells.

Fig. 6

Model interpretability. (a) Attention scores were ranked in descending order, and the top 5 genes were selected to draw the expression heat map. (b) A gene with a high attention score was selected from each cell type in the pancreas dataset for violin stacked plot presentation. (c) Bubble plot of enrichment analysis of the 100 genes with the highest attention scores in beta cells. (d) A gene with a high attention score in each cell type of the pancreas dataset was selected for t-SNE plot presentation.

Performance comparison

To validate the performance of the proposed scGAA model, this study conducted model training on four datasets (GSE135893, GSE222007, GSE216019, GSE136103), comparing training time, number of training rounds, accuracy, and training loss across the four datasets. The results of the scGAA model were also compared with those of the existing TOSICA model. The experimental results suggested that the scGAA model outperformed the TOSICA model on all indicators across the four datasets, strongly demonstrating its superiority.

Fig. 7

Comparison of computational efficiency and model performance of the TOSICA and scGAA methods. (a) Accuracy (Train_Acc) versus the number of training rounds (Epoch) for the TOSICA and scGAA models. (b) Training loss (Train_Loss) versus the number of training rounds (Epoch) for the TOSICA and scGAA models. (c) Accuracy (Train_Acc) versus training time (Time) for the TOSICA and scGAA models. (d) Training loss (Train_Loss) versus training time (Time) for the TOSICA and scGAA models.

Specifically, on the dataset GSE136103, the training loss of the scGAA model was gradually reduced from an initial 1.777 to 0.079, whereas its training accuracy increased from 34.61% to 97.83%. As the number of training rounds increased, the accuracy of both models increased. After the ninth round, the accuracy of the scGAA model was very close to its optimum; at this point, the accuracy of the TOSICA model was significantly lower and continued to increase gradually during subsequent training. Over the same period, the training loss of the TOSICA model decreased from 3.147 to 0.221, and its training accuracy increased from 4.45% to 93.78% (Fig. 7a,b). As training progressed, the training loss of both models decreased; the loss of the scGAA model stabilised after the ninth round, while that of the TOSICA model continued to decrease gradually. Therefore, although both models improved, the scGAA model consistently exhibited higher accuracy and lower loss throughout training. More importantly, at the end of the same training period, the scGAA model outperformed the TOSICA model in both training loss and accuracy. For example, at the last observation point (scGAA: 136.55 min, TOSICA: 238.75 min), the training loss of the scGAA method (0.079) was significantly lower than that of the TOSICA method (0.221), and the training accuracy of the scGAA method (97.83%) exceeded that of the TOSICA method (93.78%) (Fig. 7c,d). In addition, a performance comparison with scBERT is shown in Supplementary Fig. 9.

Discussion

In this study, we introduce scGAA, an innovative and efficient tool for cell-type annotation that leverages the gated axial-attention mechanism. The axial self-attention mechanism considerably reduces the computational complexity and memory requirements of the model by performing self-attention calculations in the horizontal and vertical directions of the sequence. Our study further optimised the transformer model to adapt it to scRNA-seq datasets of different sizes and improved the model's performance in cell type or state identification. To achieve this, we introduced six innovative gating units to dynamically adjust the query, key, and value matrices in the horizontal and vertical attention mechanisms of the model. The design of these gating units is based on the core concept that the model should be able to flexibly adjust its focus on information according to the characteristics and requirements of the dataset. This not only improves the accuracy of the model in cell annotation tasks but also markedly enhances its robustness, enabling it to maintain stable performance in a changing data environment (Figs. 3, 4 and Supplementary Fig. 9).

The application of scGAA across various scRNA-seq datasets has demonstrated its exceptional performance: it was applied to six datasets covering major organs (kidney, pancreas, liver, brain, lung, and heart) that collectively encompass more than 130 cell types, demonstrating its extensive capabilities and robustness. This study employed accuracy and the macro F1 score to assess performance through a comprehensive comparison with the latest methods (Fig. 2 and Supplementary Fig. 9). The results revealed that scGAA consistently outperformed the comparative methods in accuracy, both across individual datasets and in identifying each cell type. This evidence of the model's excellent generalisation ability and robustness provides biologists with a powerful tool, aiding a deeper understanding of cellular diversity and complexity.

The scGAA model has excellent interpretability and provides important insights for identifying key genes that drive cell-type prediction. By analysing the attention weights of genes, we can determine which genes play a key role in the prediction results and which have important interactions, thereby gaining a deeper understanding of cell function and state (Fig. 6 and Supplementary Fig. 10). In addition, this approach can identify the most characteristic genes of a specific cell type, enabling researchers to explore cellular mechanisms and potential biomarkers in depth.

Compared to the TOSICA and scBERT models, the scGAA model not only demonstrates higher computational efficiency and superior model performance but also achieves the best convergence speed and excellent generalization ability (Fig. 7 and Supplementary Fig. 9). This result validates the unique advantages of the gated axial-attention mechanism introduced in the design of the scGAA model for handling scRNA-seq data. Notably, the scGAA model also exhibited advantages in discovering novel cell types, further highlighting the exceptional scalability of our model in the field of cell-type annotation.

Although the scGAA model demonstrated strong performance, several factors may limit its applicability. For example, scRNA-seq data may contain dying cells, whose presence, although its biological significance remains to be investigated, is commonly assumed to introduce additional biological errors. In addition, owing to the limitations of sequencing technology, two or more cells may be captured in a single sample well, inflating the measured expression levels of the assayed genes, compromising the quality of the scRNA-seq data, and potentially hindering the derivation of meaningful biological results. Furthermore, interactions between genes usually exist in the form of networks, such as gene regulatory networks and biological signalling pathways, but the scGAA model does not explicitly incorporate this valuable prior knowledge. In contrast, graph neural networks applied to biological networks can better model the interactions between genes because they aggregate information from neighbouring genes. This idea could be applied to scRNA-seq analyses, for example by using scRNA-seq data to construct graph structures at the cellular level. From this perspective, the graph transformer will probably be the main direction for the future development of scGAA.

Methods

The scGAA model

The cell-type annotation method for scRNA-seq data based on the gated axial-attention mechanism (Fig. 8) involves the following steps:

First, the expression matrix comprising the top n highly variable genes of all cells is randomly divided into several gene sets, so that each gene set represents different feature information; the input tensor is then obtained by embedding each gene set. Using the weight matrix W (learned during training of the transformer model) and the mask matrix M (randomly composed of 0s and 1s, with the same dimensions as W), the corresponding positions of W and M are multiplied, and the result is multiplied with the embedding representation G of each gene set to obtain the feature information t of that gene set. The formula is as follows:

$$\begin{aligned} t = W \cdot M \times G. \end{aligned}$$
(1)

where \(\cdot\) is the element-wise (point) multiplication operation and \(\times\) is matrix multiplication. The feature-information calculation for each gene set is repeated m times to increase the dimensionality of the space, and the feature information from the m calculations is then merged to obtain the input tensor T of each gene set. The formula is as follows:

$$\begin{aligned} T = concat(t_1, t_2, \ldots , t_m). \end{aligned}$$
(2)

In this context, the concat() function performs the merge operation, and the shape of the input tensor T is (N, V, H, C), where N is the batch size, V is the height, H is the width, and C is the number of channels.
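Equations (1) and (2) can be sketched directly in numpy; all sizes below are illustrative, and we assume a fresh random mask per repetition, which the text does not state explicitly:

```python
import numpy as np

rng = np.random.default_rng(0)

d_model, n_genes, m = 8, 16, 4           # illustrative sizes only
G = rng.normal(size=(n_genes, 1))        # embedding of one gene set
W = rng.normal(size=(d_model, n_genes))  # learned weight matrix

def gene_set_feature(W, G, rng):
    """One feature vector t = (W * M) x G from Eq. (1).

    M is a random 0/1 mask with the same shape as W; `*` is the element-wise
    (point) product and `@` the matrix product.
    """
    M = rng.integers(0, 2, size=W.shape)
    return (W * M) @ G

# Eq. (2): repeat m times (here with a fresh mask each time) and concatenate.
T = np.concatenate([gene_set_feature(W, G, rng) for _ in range(m)], axis=1)
```

Each repetition contributes one column, so the concatenated T grows one dimension per mask draw, matching the "increase the dimension of the space" step above.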

Dividing the genes into several sets converts them into the smallest units for model processing and generates the input tensors. The experiments below show that this input tensor is effective at capturing gene-gene interactions.

In this study, the gene expression vectors of all cells are converted into a series of two-dimensional “images”, where each image contains \(H \times W\) “pixels” (each “pixel” corresponds to an element in the original vector). The input to self-attention is represented by the matrix X; the Query, Key, and Value matrices are computed via the linear transformation matrices \(W_{Q}\), \(W_{K}\), and \(W_{V}\), and the output of self-attention is calculated from the resulting matrices Q, K, and V.

Furthermore, this study introduces an axial self-attention mechanism in the gated axial-attention module of scGAA. The axial attention mechanism splits the attention module into two modules, a horizontal attention module (Horizontal-Attention) and a vertical attention module (Vertical-Attention), producing two output matrices, the horizontal-attention output (Row matrix) and the vertical-attention output (Column matrix), which are then merged into a single output matrix. By performing self-attention separately along the horizontal and vertical directions, axial attention effectively approximates the original self-attention mechanism while greatly reducing its computational complexity. In addition, this mechanism captures gene-gene interactions more effectively, which helps improve the model’s adaptability to different datasets.

The horizontal-attention output matrix is calculated in the horizontal attention module. First, the input tensor T is expanded along the row axis into a row input tensor \(T_{row}\), and each row is treated as an independent sequence, expressed as:

$$\begin{aligned} T_{row} = reshape(T, (N \cdot H, V, C)). \end{aligned}$$
(3)

In this context, reshape is the reshaping operation on the input tensor T, \(N \cdot H\) denotes the product of the batch size N and the width H, V is the height, and C is the number of channels. Then the query \(Q_{h}\), key \(K_{h}\) and value \(V_{h}\) are calculated:

$$\begin{aligned} Q_{h} = W_{Qh} \times T_{row}, \end{aligned}$$
(4)
$$\begin{aligned} K_{h} = W_{Kh} \times T_{row}, \end{aligned}$$
(5)
$$\begin{aligned} V_{h} = W_{Vh} \times T_{row}. \end{aligned}$$
(6)

In this context, \(W_{Qh}\), \(W_{Kh}\) and \(W_{Vh}\) are linear transformation matrices, and \(\times\) is matrix multiplication. The horizontal-attention similarity score matrix \(A_{h}\) is then calculated from the query \(Q_{h}\) and key \(K_{h}\):

$$\begin{aligned} A_{h} = softmax\left( \frac{Q_{h}K_{h}^T}{\sqrt{d_h}}\right) . \end{aligned}$$
(7)

In this context, softmax is the softmax activation function, \(d_h\) is the dimension of \(Q_{h}\) and \(K_{h}\), \(\sqrt{d_h}\) is a scaling factor that prevents the dot product \(Q_h K_h^T\) from becoming too large, and the superscript T denotes the matrix transpose. Then, the horizontal-attention output matrix \(O_{h}\) is calculated:

$$\begin{aligned} O_{h} = A_{h} \times V_{h}. \end{aligned}$$
(8)
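Eqs. (3)-(8) can be sketched in NumPy as follows; the shapes are illustrative assumptions, and in practice the axes are permuted before reshaping so that each sequence stays contiguous in memory:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
N, V, H, C = 2, 4, 5, 3            # illustrative batch/height/width/channels
T = rng.normal(size=(N, V, H, C))

# Eq. (3): fold the width axis into the batch, giving N*H sequences.
T_row = T.transpose(0, 2, 1, 3).reshape(N * H, V, C)

# Eqs. (4)-(6): project to Q_h, K_h, V_h.
W_Qh, W_Kh, W_Vh = (rng.normal(size=(C, C)) for _ in range(3))
Q_h, K_h, V_h = T_row @ W_Qh, T_row @ W_Kh, T_row @ W_Vh

# Eq. (7): scaled dot-product similarity within each sequence.
A_h = softmax(Q_h @ K_h.swapaxes(-1, -2) / np.sqrt(C))
# Eq. (8): horizontal-attention output.
O_h = A_h @ V_h
print(O_h.shape)  # (10, 4, 3)
```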

Similarly, the vertical-attention output matrix \(O_{v}\) is calculated in the vertical attention module. First, the input tensor T is expanded along the column axis into a column input tensor \(T_{col}\), and each column is treated as an independent sequence, expressed as:

$$\begin{aligned} T_{col} = reshape(T, (N \cdot V, H, C)). \end{aligned}$$
(9)

Then calculate the query \(Q_{v}\), key \(K_{v}\) and value \(V_{v}\):

$$\begin{aligned} Q_{v} = W_{Qv} \times T_{col}, \end{aligned}$$
(10)
$$\begin{aligned} K_{v} = W_{Kv} \times T_{col}, \end{aligned}$$
(11)
$$\begin{aligned} V_{v} = W_{Vv} \times T_{col}. \end{aligned}$$
(12)

In this context, \(W_{Qv}\), \(W_{Kv}\) and \(W_{Vv}\) are linear transformation matrices, and \(\times\) is matrix multiplication. The vertical-attention similarity score matrix \(A_{v}\) is then calculated from the query \(Q_{v}\) and key \(K_{v}\):

$$\begin{aligned} A_{v} = softmax\left( \frac{Q_{v}K_{v}^T}{\sqrt{d_v}}\right) . \end{aligned}$$
(13)

In this context, softmax is the softmax activation function, \(d_v\) is the dimension of \(Q_{v}\) and \(K_{v}\), \(\sqrt{d_v}\) is a scaling factor that prevents the dot product \(Q_v K_v^T\) from becoming too large, and the superscript T denotes the matrix transpose. Then, the vertical-attention output matrix \(O_{v}\) is calculated:

$$\begin{aligned} O_{v} = A_{v} \times V_{v}. \end{aligned}$$
(14)

On this basis, to further optimize the performance of the scGAA model, six gating units \(G_{Qh}\), \(G_{Kh}\), \(G_{Vh}\), \(G_{Qv}\), \(G_{Kv}\) and \(G_{Vv}\) were introduced to control the information of Q, K and V in horizontal-attention and vertical-attention respectively. These gating units are learnable parameters for extracting important features. Depending on whether the learned information is useful, each gating unit produces weights between 0 and 1 that control how much information passes through, so that the most important features are retained and the prediction accuracy of the model improves.

\(G_{Qh}\), \(G_{Kh}\) and \(G_{Vh}\) are added to \(Q_{h}\), \(K_{h}\) and \(V_{h}\) respectively, and the horizontal-attention similarity score matrix \(A_{h}\) is calculated as follows:

$$\begin{aligned} A_h = softmax \left( \frac{(Q_h \cdot G_{Qh})(K_h \cdot G_{Kh})^T}{\sqrt{d_h}} \right) . \end{aligned}$$
(15)

Finally, the horizontal-attention output matrix \(O_{h}\) is calculated:

$$\begin{aligned} O_h = A_h \times (V_h \cdot G_{Vh}). \end{aligned}$$
(16)

Similarly, \(G_{Qv}\), \(G_{Kv}\) and \(G_{Vv}\) are added to \(Q_{v}\), \(K_{v}\) and \(V_{v}\) respectively, and the vertical-attention similarity score matrix \(A_{v}\) is calculated as:

$$\begin{aligned} A_v = softmax \left( \frac{(Q_v \cdot G_{Qv})(K_v \cdot G_{Kv})^T}{\sqrt{d_v}} \right) . \end{aligned}$$
(17)

Finally, the vertical-attention output matrix \(O_{v}\) is calculated:

$$\begin{aligned} O_v = A_v \times (V_v \cdot G_{Vv}). \end{aligned}$$
(18)

The outputs of the horizontal-attention and vertical-attention modules are combined to better integrate global information, giving the output matrix O:

$$\begin{aligned} O = O_h + O_v . \end{aligned}$$
(19)
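Putting Eqs. (15)-(19) together, a minimal NumPy sketch of the gated axial attention looks like the following; the shapes are illustrative assumptions, and random values stand in for the learned weights and gates:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def gated_attention(T_seq, W_Q, W_K, W_V, G_Q, G_K, G_V):
    """Eqs. (15)-(18): gate Q, K and V element-wise, then apply
    scaled dot-product attention within each sequence."""
    Q = T_seq @ W_Q * G_Q
    K = T_seq @ W_K * G_K
    V = T_seq @ W_V * G_V
    A = softmax(Q @ K.swapaxes(-1, -2) / np.sqrt(Q.shape[-1]))
    return A @ V

rng = np.random.default_rng(2)
N, V, H, C = 2, 4, 5, 3            # illustrative shapes
T = rng.normal(size=(N, V, H, C))

def params():
    Ws = [rng.normal(size=(C, C)) for _ in range(3)]
    Gs = [rng.uniform(size=C) for _ in range(3)]   # gate weights in [0, 1]
    return Ws + Gs

# Horizontal branch (Eq. (3) layout): N*H sequences of length V.
T_row = T.transpose(0, 2, 1, 3).reshape(N * H, V, C)
O_h = gated_attention(T_row, *params()).reshape(N, H, V, C).transpose(0, 2, 1, 3)

# Vertical branch (Eq. (9) layout): N*V sequences of length H.
T_col = T.reshape(N * V, H, C)
O_v = gated_attention(T_col, *params()).reshape(N, V, H, C)

O = O_h + O_v                      # Eq. (19)
print(O.shape)  # (2, 4, 5, 3)
```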
Fig. 8

Gated axial-attention detailed structure. First, the model embeds all gene sets and applies the horizontal and vertical gated attention mechanisms. By computing the horizontal (\(Q_{h} \times K_{h}\)) and vertical (\(Q_{v} \times K_{v}\)) attention scores, the model is able to identify key gene sets. With these key gene sets, further studies such as differential analysis, enrichment analysis, and gene characterisation can be conducted, providing the basis for subsequent analysis and interpretation. The model then fuses the outputs of the horizontal and vertical gated attention and computes a score for each category through a linear layer. Finally, these scores are converted into the corresponding probabilities via a softmax function.

Data preprocessing

Since most scRNA-seq data are imperfect, quality control is performed to filter out low-quality cells. Here, genes expressed in fewer than 3 cells and cells expressing fewer than 200 genes are filtered out. Next, the data are normalised; in this paper, the analytic approximation of Pearson residuals is used, which preserves intercellular variation, including biological heterogeneity, and helps identify rare cell types49. We use the preprocessing function pp.preprocess provided in scanpy50, which can directly compute the Pearson residuals for normalisation51. First, the per-cell sums of gene expression and the median of these sums across all cells are calculated. The expression \(X_{ij}\) of each gene in each cell is then normalised:

$$\begin{aligned} X_{ij}^{{norm}} = X_{ij} \times \frac{S_{median}}{S_i}. \end{aligned}$$
(20)

where \(X_{ij}\) is the raw expression of gene j in cell i, \(S_{i}\) is the sum of the expression of all genes in cell i, \(S_{median}\) is the median of these per-cell sums across all cells, and \(X_{ij}^{{norm}}\) is the normalised gene expression. Afterwards, the normalised gene expression is log-transformed:

$$\begin{aligned} X_{ij}^{{log}} = \log (1 + X_{ij}^{{norm}}). \end{aligned}$$
(21)

\(X_{ij}^{{log}}\) is the final log-transformed gene expression. This two-step process made the data comparable across cells and helped reduce the impact of extreme values in the data on subsequent analyses.
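A minimal NumPy sketch of the two-step normalisation in Eqs. (20)-(21), using a toy count matrix as an illustrative input:

```python
import numpy as np

def median_normalize_log(X):
    """Eq. (20): scale each cell so its total matches the median total,
    then Eq. (21): log-transform. X is a cells x genes count matrix."""
    S = X.sum(axis=1, keepdims=True)   # per-cell totals S_i
    S_median = np.median(S)            # median total across all cells
    X_norm = X * (S_median / S)        # Eq. (20)
    return np.log1p(X_norm)            # Eq. (21): log(1 + x)

X = np.array([[2.0, 0.0, 2.0],        # toy counts: 3 cells x 3 genes
              [1.0, 1.0, 2.0],
              [4.0, 2.0, 2.0]])
X_log = median_normalize_log(X)
print(np.expm1(X_log).sum(axis=1))    # every cell now totals the median, 4.0
```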

The normalised data are then used as input to train the scGAA model. If one category in the dataset is much larger than the others, the model may be biased towards predicting the category with the most samples, leading to poorer performance on categories with fewer samples. Balancing the dataset by category exposes the model to a more even distribution of samples during training, which helps it learn the differences between categories and improves its prediction performance on unseen data. Then 80% of the dataset is randomly split into the training and test sets, and the remaining data forms the validation set.

Loss function

The model training process includes two stages: self-supervised learning on unlabeled data to obtain a pre-trained model, followed by supervised learning on the specific cell-type labeling task to obtain a fine-tuned model, which is used to predict the probability of each cell type. Cross-entropy loss is used as the cell-type label prediction loss, calculated as follows:

$$\begin{aligned} L = -\sum _{i=1}^{N} \sum _{c=1}^{C} y_{i,c} log(p_{i,c}). \end{aligned}$$
(22)

where L represents the value of the loss function, N represents the total number of cells in the dataset, C represents the total number of cell types, \(y_{i,c}\) is an indicator that is 1 when the true category of cell i is c and 0 otherwise, and \(p_{i,c}\) is the probability that the model predicts that cell i belongs to category c.

This formula quantifies the accuracy of the model’s prediction by calculating the logarithm of the predicted probability corresponding to the actual cell type and taking its negative value. As the model’s predicted probability gets closer to the actual category (i.e., \(p_{i,c}\) is close to 1 when \(y_{i,c}\) = 1), the cross-entropy loss is smaller, and vice versa.
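Eq. (22) on a toy example with two cells and three cell types:

```python
import numpy as np

def cross_entropy(p, y):
    """Eq. (22): L = -sum_i sum_c y_ic * log(p_ic)."""
    return -np.sum(y * np.log(p))

# Two cells, three cell types; rows of p are predicted probabilities.
p = np.array([[0.7, 0.2, 0.1],
              [0.1, 0.8, 0.1]])
y = np.array([[1.0, 0.0, 0.0],        # one-hot true labels
              [0.0, 1.0, 0.0]])
L = cross_entropy(p, y)
print(round(L, 4))  # -(log 0.7 + log 0.8) ≈ 0.5798
```

Note that the loss falls as the probability assigned to the true class rises, matching the intuition described above.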

We use Stochastic Gradient Descent (SGD) as the optimization algorithm and employ a cosine learning-rate decay strategy during training to prevent overly large update steps in the late stages of training. Accuracy and the macro F1 score are used to evaluate the performance of each method in cell-type annotation at both the cell level and the cell-type level.
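The cosine decay schedule can be sketched as a simple function of the training step; the total step count and learning-rate values are illustrative assumptions:

```python
import math

def cosine_lr(step, total_steps, lr_max, lr_min=0.0):
    """Cosine decay of the learning rate from lr_max down to lr_min."""
    cos_term = 0.5 * (1.0 + math.cos(math.pi * step / total_steps))
    return lr_min + (lr_max - lr_min) * cos_term

print(cosine_lr(0, 100, 0.1))     # lr_max at the start
print(cosine_lr(50, 100, 0.1))    # halfway: about 0.05
print(cosine_lr(100, 100, 0.1))   # decays to lr_min at the end
```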