configs

2024-06-18 03:44
文章标签 configs

本文主要是介绍configs,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

configs 部分

```python
import os  # 导入os模块,用于系统级操作

emotion = ["Valence"]  # 定义情绪列表,只包含情绪维度"Valence"

# 配置参数字典
config = {
    "extract_class_label": 1,  # 是否提取类别标签
    "extract_continuous_label": 1,  # 是否提取连续标签
    "extract_eeg": 1,  # 是否提取EEG数据
    "eeg_folder": "eeg",  # 存放EEG数据的文件夹名称
    "eeg_config": {  # EEG数据处理的详细配置
        "sampling_frequency": 256,  # 采样频率
        "window_sec": 2,  # 窗口长度(秒)
        "hop_sec": 0.25,  # 跳跃长度(秒)
        "buffer_sec": 5,  # 缓冲区长度(秒)
        "num_electrodes": 32,  # 电极数量
        "interest_bands": [(0.3, 4), (4, 8), (8, 12), (12, 18), (18, 30), (30, 45)],  # 感兴趣频段
        "f_trans_interest_bands": [(0.1, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2)],  # 感兴趣频段的过渡频率
        "channel_slice": {'eeg': slice(0, 32), 'ecg': slice(32, 35), 'misc': slice(35, -1)},  # 通道切片
        "features": ["eeg_bandpower"],  # 特征
        "filter_type": 'cheby2',  # 滤波器类型
        "filter_order": 4  # 滤波器阶数
    },
    "save_npy": 1,  # 是否保存为.npy格式的数据
    "npy_folder": "compacted_48",  # 存放.npy数据的文件夹名称
    "dataset_name": "mahnob",  # 数据集的名称
    "emotion_list": emotion,  # 情绪列表
    "root_directory": r"D:\DingYi\Dataset\MAHNOB-O",  # 原始数据集的根目录路径
    "output_root_directory": r"D:\DingYi\Dataset\MAHNOB-P-R",  # 处理后数据的输出根目录路径
    "raw_data_folder": "Sessions",  # 原始数据存放的文件夹名称
    "multiplier": {  # 不同数据类型的倍增因子
        "video": 16,
        "eeg_raw": 1,
        "eeg_bandpower": 1,
        "eeg_DE": 1,
        "eeg_RP": 1,
        "eeg_Hjorth": 1,
        "continuous_label": 1
    },
    "feature_dimension": {  # 不同特征的维度信息
        "eeg_raw": (16384,),
        "eeg_bandpower": (192,),
        "eeg_DE": (192,),
        "eeg_RP": (192,),
        "eeg_Hjorth": (96,),
        "continuous_label": (1,),
        "class_label": (1,)
    },
    "max_epoch": 15,  # 最大的训练周期数
    "min_epoch": 0,  # 最小的训练周期数
    "model_name": "2d1d",  # 模型的名称
    "backbone": {  # 模型的骨干网络配置
        "state_dict": "res50_ir_0.887",
        "mode": "ir"
    },
    "early_stopping": 10,  # 提前停止训练的步数
    "load_best_at_each_epoch": 1,  # 是否在每个周期加载最佳模型
    "time_delay": 0,  # 时间延迟
    "metrics": ["rmse", "pcc", "ccc"],  # 评估指标
    "save_plot": 0  # 是否保存图形结果
}
```

这段代码是一个Python字典,包含了各种配置参数,用于处理和分析一个名为MAHNOB的数据集,主要用于情绪识别研究。以下是每行代码的解释:

1. `import os`: 导入Python的os模块,用于操作文件路径等系统级操作。

2. `emotion = ["Valence"]`: 定义一个情绪列表,只包含情绪维度"Valence"。

3. `config = { ... }`: 定义一个名为config的字典,包含了各种配置参数。

4. `"extract_class_label": 1`: 是否提取类别标签,这里设为1表示是。

5. `"extract_continuous_label": 1`: 是否提取连续标签,这里设为1表示是。

6. `"extract_eeg": 1`: 是否提取EEG数据,这里设为1表示是。

7. `"eeg_folder": "eeg"`: 存放EEG数据的文件夹名称。

8. `"eeg_config": { ... }`: EEG数据处理的详细配置,包括采样频率、窗口长度、跳跃长度、通道数量等参数。

9. `"save_npy": 1`: 是否保存为.npy格式的数据,这里设为1表示是。

10. `"npy_folder": "compacted_48"`: 存放.npy数据的文件夹名称。

11. `"dataset_name": "mahnob"`: 数据集的名称。

12. `"emotion_list": emotion`: 情绪列表,使用了之前定义的emotion变量。

13. `"root_directory": r"D:\DingYi\Dataset\MAHNOB-O"`: 原始数据集的根目录路径。

14. `"output_root_directory": r"D:\DingYi\Dataset\MAHNOB-P-R"`: 处理后数据的输出根目录路径。

15. `"raw_data_folder": "Sessions"`: 原始数据存放的文件夹名称。

16. `"multiplier": { ... }`: 不同数据类型的倍增因子,用于数据增强或者调整数据量。

17. `"feature_dimension": { ... }`: 不同特征的维度信息,用于数据处理和模型输入。

18. `"max_epoch": 15`: 最大的训练周期数。

19. `"min_epoch": 0`: 最小的训练周期数。

20. `"model_name": "2d1d"`: 模型的名称,这里只是命名用途,实际上没有使用。

21. `"backbone": { ... }`: 模型的骨干网络配置,包括状态字典和模式。

22. `"early_stopping": 10`: 提前停止训练的步数。

23. `"load_best_at_each_epoch": 1`: 是否在每个周期加载最佳模型。

24. `"time_delay": 0`: 时间延迟,用于连续标签在数据点中的移动。

25. `"metrics": ["rmse", "pcc", "ccc"]`: 评估指标,包括均方根误差、皮尔逊相关系数和一致性相关系数。

26. `"save_plot": 0`: 是否保存图形结果,这里设为0表示否。

这些配置参数用于设置数据预处理、模型训练和评估过程中的各种选项和参数,确保流程能够顺利进行和有效执行。

from base.preprocessing import GenericDataPreprocessing  # 导入基础数据预处理类
from base.utils import expand_index_by_multiplier, load_pickle, save_to_pickle, get_filename_from_a_folder_given_extension, ensure_dir  # 导入一些辅助函数和工具
from base.label_config import *  # 导入标签配置

import os  # 导入os模块,用于系统级操作
import scipy.io as sio  # 导入scipy.io模块,用于读取.mat文件

import pandas as pd  # 导入pandas库,用于数据处理和分析
import numpy as np  # 导入numpy库,用于数值计算

import xml.etree.ElementTree as et  # 导入xml.etree.ElementTree模块,用于解析XML文件

generate_dataset.py

from base.preprocessing import GenericDataPreprocessing  # 导入基础数据预处理类
from base.utils import expand_index_by_multiplier, load_pickle, save_to_pickle, get_filename_from_a_folder_given_extension, ensure_dir  # 导入一些辅助函数和工具
from base.label_config import *  # 导入标签配置

import os  # 导入os模块,用于系统级操作
import scipy.io as sio  # 导入scipy.io模块,用于读取.mat文件

import pandas as pd  # 导入pandas库,用于数据处理和分析
import numpy as np  # 导入numpy库,用于数值计算

import xml.etree.ElementTree as et  # 导入xml.etree.ElementTree模块,用于解析XML文件


class Preprocessing(GenericDataPreprocessing):
    def __init__(self, config):
        super().__init__(config)

    def generate_iterator(self):
        # 生成迭代器,返回按照文件名排序的文件路径列表
        path = os.path.join(self.config['root_directory'], self.config['raw_data_folder'])
        iterator = [os.path.join(path, file) for file in sorted(os.listdir(path), key=float)]
        return iterator

    def generate_per_trial_info_dict(self):
        # 生成每个试验的信息字典
        per_trial_info_path = os.path.join(self.config['output_root_directory'], "processing_records.pkl")
        if os.path.isfile(per_trial_info_path):
            per_trial_info = load_pickle(per_trial_info_path)
        else:
            per_trial_info = {}
            pointer = 0

            sub_trial_having_continuous_label = self.get_sub_trial_info_for_continuously_labeled()
            all_continuous_labels = self.read_all_continuous_label()

            iterator = self.generate_iterator()

            for idx, file in enumerate(iterator):
                kwargs = {}
                this_trial = {}
                print(file)

                time_stamp_file = get_filename_from_a_folder_given_extension(file, "tsv", "All-Data")[0]
                video_trim_range = self.read_start_end_from_mahnob_tsv(time_stamp_file)
                if video_trim_range is not None:
                    this_trial['video_trim_range'] = self.read_start_end_from_mahnob_tsv(time_stamp_file)
                else:
                    this_trial['discard'] = 1
                    continue

                this_trial['has_continuous_label'] = 0
                session = int(file.split(os.sep)[-1])
                subject_no, trial_no = session // 130 + 1, session % 130

                if subject_no == sub_trial_having_continuous_label[pointer][0] and trial_no == sub_trial_having_continuous_label[pointer][1]:
                    this_trial['has_continuous_label'] = 1

                this_trial['continuous_label'] = None
                this_trial['annotated_index'] = None
                annotated_index = np.arange(this_trial['video_trim_range'][0][1])
                if this_trial['has_continuous_label']:
                    raw_continuous_label = all_continuous_labels[pointer]
                    this_trial['continuous_label'] = raw_continuous_label
                    annotated_index = self.process_continuous_label(raw_continuous_label)
                    this_trial['annotated_index'] = annotated_index
                    pointer += 1

                this_trial['has_eeg'] = 1
                eeg_path = get_filename_from_a_folder_given_extension(file, "bdf")
                if len(eeg_path) == 1:
                    this_trial['eeg_path'] = eeg_path[0].split(os.sep)
                else:
                    this_trial['eeg_path'] = None
                    this_trial['has_eeg'] = 0

                this_trial['audio_path'] = ""

                this_trial['subject_no'] = subject_no
                this_trial['trial_no'] = trial_no
                this_trial['trial'] = "P{}-T{}".format(str(subject_no), str(trial_no))

                this_trial['target_fps'] = 64

                kwargs['feature'] = "video"
                kwargs['has_continuous_label'] = this_trial['has_continuous_label']
                this_trial['video_annotated_index'] = self.get_annotated_index(annotated_index, **kwargs)

                this_trial['class_label'] = get_filename_from_a_folder_given_extension(file, "xml")[0]
                per_trial_info[idx] = this_trial

        ensure_dir(per_trial_info_path)
        save_to_pickle(per_trial_info_path, per_trial_info)
        self.per_trial_info = per_trial_info

    def generate_dataset_info(self):
        # 生成数据集信息
        class_label = {}
        for idx, record in self.per_trial_info.items():
            self.dataset_info['trial'].append(record['processing_record']['trial'])
            self.dataset_info['trial_no'].append(record['trial_no'])
            self.dataset_info['subject_no'].append(record['subject_no'])
            self.dataset_info['has_continuous_label'].append(record['has_continuous_label'])
            self.dataset_info['has_eeg'].append(record['has_eeg'])

            if record['has_continuous_label']:
                self.dataset_info['length'].append(len(record['continuous_label']))
            else:
                self.dataset_info['length'].append(len(record['video_annotated_index']) // 16)

            if self.config['extract_class_label']:
                class_label.update({record['processing_record']['trial']: self.extract_class_label_fn(record)})

        self.dataset_info['multiplier'] = self.config['multiplier']
        self.dataset_info['data_folder'] = self.config['npy_folder']

        path = os.path.join(self.config['output_root_directory'], 'dataset_info.pkl')
        save_to_pickle(path, self.dataset_info)

        if self.config['extract_class_label']:
            path = os.path.join(self.config['output_root_directory'], 'class_label.pkl')
            save_to_pickle(path, class_label)

    def extract_class_label_fn(self, record):
        # 提取类别标签
        class_label = {}
        if record['has_eeg']:
            xml_file = et.parse(record['class_label']).getroot()
            felt_emotion = xml_file.find('.').attrib['feltEmo']
            felt_arousal = xml_file.find('.').attrib['feltArsl']
            felt_valence = xml_file.find('.').attrib['feltVlnc']

            arousal = 0 if float(felt_arousal) <= 5 else 1
            valence = 0 if float(felt_valence) <= 5 else 1

            class_label = {
                "Arousal": arousal,
                "Valence": valence,
                "Arousal_3cls": arousal_class_to_number[emotion_tag_to_arousal_class[number_to_emotion_tag_dict[felt_emotion]]],
                "Valence_3cls": valence_class_to_number[emotion_tag_to_valence_class[number_to_emotion_tag_dict[felt_emotion]]]
            }

        return class_label

    def extract_continuous_label_fn(self, idx, npy_folder):
        # 提取连续标签
        if self.per_trial_info[idx]["has_continuous_label"]:
            raw_continuous_label = self.per_trial_info[idx]['continuous_label']

            if self.config['save_npy']:
                filename = os.path.join(npy_folder, "continuous_label.npy")
                if not os.path.isfile(filename):
                    ensure_dir(filename)
                    np.save(filename, raw_continuous_label)

    def load_continuous_label(self, path, **kwargs):
        # 加载连续标签
        cols = [emotion.lower() for emotion in self.config['emotion_list']]

        if os.path.isfile(path):
            continuous_label = pd.read_csv(path, sep=";",
                                           skipinitialspace=True, usecols=cols,
                                           index_col=False).values.squeeze()
        else:
            continuous_label = 0

        return continuous_label

    def get_annotated_index(self, annotated_index, **kwargs):
        # 获取标注索引
        feature = kwargs['feature']
        multiplier = self.config['multiplier'][feature]

        if kwargs['has_continuous_label']:
            annotated_index = expand_index_by_multiplier(annotated_index, multiplier)
        else:
            pass

        return annotated_index

    def get_sub_trial_info_for_continuously_labeled(self):
        # 获取连续标签的子试验信息
        label_file = os.path.join(self.config['root_directory'], "lable_continous_Mahnob.mat")
        mat_content = sio.loadmat(label_file)
        sub_trial_having_continuous_label = mat_content['trials_included']

        return sub_trial_having_continuous_label

    @staticmethod
    def read_start_end_from_mahnob_tsv(tsv_file):
        # 从Mahnob的tsv文件中读取起始和结束时间
        if os.path.isfile(tsv_file):
            data = pd.read_csv(tsv_file, sep='\t', skiprows=23)
            end = data[data['Event'] == 'MovieEnd'].index[0]
            start_end = [(0, end)]
        else:
            start_end = None
        return start_end

    def read_all_continuous_label(self):
        # 读取所有连续标签
        label_file = os.path.join(self.config['root_directory'], "lable_continous_Mahnob.mat")
        mat_content = sio.loadmat(label_file)
        annotation_cell = np.squeeze(mat_content['labels'])

        label_list = []
        for index in range(len(annotation_cell)):
            label_list.append(annotation_cell[index].T)
        return label_list

    @staticmethod
    def init_dataset_info():
        # 初始化数据集信息
        dataset_info = {
            "trial": [],
            "subject_no": [],
            "trial_no": [],
            "length": [],
            "has_continuous_label": [],
            "has_eeg": [],
        }
        return dataset_info


if __name__ == "__main__":
    from configs import config

    pre = Preprocessing(config)
    pre.generate_per_trial_info_dict()
    pre.prepare_data()

这段代码定义了一个名为Preprocessing的类,继承自GenericDataPreprocessing类,用于数据预处理。它包含了一些方法和函数,用于生成每个试验的信息字典、生成数据集信息、提取类别标签、提取连续标签等操作。在if __name__ == "__main__":部分,创建了Preprocessing对象,并调用了相关方法进行数据预处理。

main.py

from base.preprocessing import GenericDataPreprocessing  # 导入自定义的GenericDataPreprocessing类
from base.utils import expand_index_by_multiplier, load_pickle, save_to_pickle, get_filename_from_a_folder_given_extension, ensure_dir  # 导入一些辅助函数和工具
from base.label_config import *  # 导入标签配置

import os  # 导入os模块,用于文件和目录操作
import scipy.io as sio  # 导入scipy.io模块,用于读取MATLAB文件

import pandas as pd  # 导入pandas库,用于数据处理
import numpy as np  # 导入numpy库,用于数值计算

import xml.etree.ElementTree as et  # 导入xml.etree.ElementTree模块,用于解析XML文件


class Preprocessing(GenericDataPreprocessing):
    def __init__(self, config):
        super().__init__(config)

    def generate_iterator(self):
        path = os.path.join(self.config['root_directory'], self.config['raw_data_folder'])
        iterator = [os.path.join(path, file) for file in sorted(os.listdir(path), key=float)]
        return iterator

    def generate_per_trial_info_dict(self):
        # 生成每个试验的信息字典

        per_trial_info_path = os.path.join(self.config['output_root_directory'], "processing_records.pkl")
        if os.path.isfile(per_trial_info_path):
            per_trial_info = load_pickle(per_trial_info_path)
        else:
            per_trial_info = {}
            pointer = 0

            sub_trial_having_continuous_label = self.get_sub_trial_info_for_continuously_labeled()
            all_continuous_labels = self.read_all_continuous_label()

            iterator = self.generate_iterator()

            for idx, file in enumerate(iterator):
                kwargs = {}
                this_trial = {}
                print(file)

                time_stamp_file = get_filename_from_a_folder_given_extension(file, "tsv", "All-Data")[0]
                video_trim_range = self.read_start_end_from_mahnob_tsv(time_stamp_file)
                if video_trim_range is not None:
                    this_trial['video_trim_range'] = self.read_start_end_from_mahnob_tsv(time_stamp_file)
                else:
                    this_trial['discard'] = 1
                    continue

                this_trial['has_continuous_label'] = 0
                session = int(file.split(os.sep)[-1])
                subject_no, trial_no = session // 130 + 1, session % 130

                if subject_no == sub_trial_having_continuous_label[pointer][0] and trial_no == sub_trial_having_continuous_label[pointer][1]:
                    this_trial['has_continuous_label'] = 1

                this_trial['continuous_label'] = None
                this_trial['annotated_index'] = None
                annotated_index = np.arange(this_trial['video_trim_range'][0][1])
                if this_trial['has_continuous_label']:
                    raw_continuous_label = all_continuous_labels[pointer]
                    this_trial['continuous_label'] = raw_continuous_label
                    annotated_index = self.process_continuous_label(raw_continuous_label)
                    this_trial['annotated_index'] = annotated_index
                    pointer += 1

                this_trial['has_eeg'] =  1
                eeg_path = get_filename_from_a_folder_given_extension(file, "bdf")
                if len(eeg_path) == 1:
                    this_trial['eeg_path'] = eeg_path[0].split(os.sep)
                else:
                    this_trial['eeg_path'] = None
                    this_trial['has_eeg'] = 0

                this_trial['audio_path'] = ""

                this_trial['subject_no'] = subject_no
                this_trial['trial_no'] = trial_no
                this_trial['trial'] = "P{}-T{}".format(str(subject_no), str(trial_no))

                this_trial['target_fps'] = 64

                kwargs['feature'] = "video"
                kwargs['has_continuous_label'] = this_trial['has_continuous_label']
                this_trial['video_annotated_index'] = self.get_annotated_index(annotated_index, **kwargs)

                this_trial['class_label'] = get_filename_from_a_folder_given_extension(file, "xml")[0]
                per_trial_info[idx] = this_trial

        ensure_dir(per_trial_info_path)
        save_to_pickle(per_trial_info_path, per_trial_info)
        self.per_trial_info = per_trial_info

    def generate_dataset_info(self):
        # 生成数据集信息

        class_label = {}
        for idx, record in self.per_trial_info.items():
            self.dataset_info['trial'].append(record['processing_record']['trial'])
            self.dataset_info['trial_no'].append(record['trial_no'])
            self.dataset_info['subject_no'].append(record['subject_no'])
            self.dataset_info['has_continuous_label'].append(record['has_continuous_label'])
            self.dataset_info['has_eeg'].append(record['has_eeg'])

            if record['has_continuous_label']:
                self.dataset_info['length'].append(len(record['continuous_label']))
            else:
                self.dataset_info['length'].append(len(record['video_annotated_index']) // 16)

            if self.config['extract_class_label']:
                class_label.update({record['processing_record']['trial']: self.extract_class_label_fn(record)})

        self.dataset_info['multiplier'] = self.config['multiplier']
        self.dataset_info['data_folder'] = self.config['npy_folder']

        path = os.path.join(self.config['output_root_directory'], 'dataset_info.pkl')
        save_to_pickle(path, self.dataset_info)

        if self.config['extract_class_label']:
            path = os.path.join(self.config['output_root_directory'], 'class_label.pkl')
            save_to_pickle(path, class_label)

    def extract_class_label_fn(self, record):
        # 提取类别标签的函数

        class_label = {}
        if record['has_eeg']:
            xml_file = et.parse(record['class_label']).getroot()
            felt_emotion = xml_file.find('.').attrib['feltEmo']
            felt_arousal = xml_file.find('.').attrib['feltArsl']
            felt_valence = xml_file.find('.').attrib['feltVlnc']

            arousal = 0 if float(felt_arousal) <= 5 else 1
            valence = 0 if float(felt_valence) <= 5 else 1

            class_label = {
                "Arousal": arousal,
                "Valence": valence,
                "Arousal_3cls": arousal_class_to_number[emotion_tag_to_arousal_class[number_to_emotion_tag_dict[felt_emotion]]],
                "Valence_3cls": valence_class_to_number[emotion_tag_to_valence_class[number_to_emotion_tag_dict[felt_emotion]]]
            }

        return class_label

    def extract_continuous_label_fn(self, idx, npy_folder):
        # 提取连续标签的函数

        if self.per_trial_info[idx]["has_continuous_label"]:
            raw_continuous_label = self.per_trial_info[idx]['continuous_label']

            if self.config['save_npy']:
                filename = os.path.join(npy_folder, "continuous_label.npy")
                if not os.path.isfile(filename):
                    ensure_dir(filename)
                    np.save(filename, raw_continuous_label)

    def load_continuous_label(self, path, **kwargs):
        # 加载连续标签

        cols = [emotion.lower() for emotion in self.config['emotion_list']]

        if os.path.isfile(path):
            continuous_label = pd.read_csv(path, sep=";",
                                           skipinitialspace=True, usecols=cols,
                                           index_col=False).values.squeeze()
        else:
            continuous_label = 0

        return continuous_label

    def get_annotated_index(self, annotated_index, **kwargs):
        # 获取标注索引

        feature = kwargs['feature']
        multiplier = self.config['multiplier'][feature]

        if kwargs['has_continuous_label']:
            annotated_index = expand_index_by_multiplier(annotated_index, multiplier)
        else:
            pass

        return annotated_index

    def get_sub_trial_info_for_continuously_labeled(self):
        # 获取具有连续标签的子试验信息

        label_file = os.path.join(self.config['root_directory'], "lable_continous_Mahnob.mat")
        mat_content = sio.loadmat(label_file)
        sub_trial_having_continuous_label = mat_content['trials_included']

        return sub_trial_having_continuous_label

    @staticmethod
    def read_start_end_from_mahnob_tsv(tsv_file):
        # 从Mahnob TSV文件中读取起始和结束时间

        if os.path.isfile(tsv_file):
            data = pd.read_csv(tsv_file, sep='\t', skiprows=23)
            end = data[data['Event'] == 'MovieEnd'].index[0]
            start_end = [(0, end)]
        else:
            start_end = None
        return start_end

    def read_all_continuous_label(self):
        # 读取所有连续标签

        label_file = os.path.join(self.config['root_directory'], "lable_continous_Mahnob.mat")
        mat_content = sio.loadmat(label_file)
        annotation_cell = np.squeeze(mat_content['labels'])

        label_list = []
        for index in range(len(annotation_cell)):
            label_list.append(annotation_cell[index].T)
        return label_list

    @staticmethod
    def init_dataset_info():
        # 初始化数据集信息

        dataset_info = {
            "trial": [],
            "subject_no": [],
            "trial_no": [],
            "length": [],
            "has_continuous_label": [],
            "has_eeg": [],
        }
        return dataset_info


if __name__ == "__main__":
    from configs import config  # 导入配置文件

    pre = Preprocessing(config)  # 创建Preprocessing对象,传入配置文件
    pre.generate_per_trial_info_dict()  # 生成每个试验的信息字典
    pre.prepare_data()  # 准备数据
这段代码是一个数据预处理的类Preprocessing,继承自GenericDataPreprocessing。它包含了一些方法用于生成每个试验的信息字典、生成数据集信息、提取类别标签和连续标签等操作。在__main__函数中,创建了一个Preprocessing对象,并调用了相关方法进行数据预处理。

加入其他数据集中

```python
from model import MASA_TCN  # 从model模块中导入MASA_TCN模型

data = torch.randn(1, 1, 192, 96)  # 生成一个随机张量作为输入数据,形状为(batch_size=1, cnn_channel=1, EEG_channel*feature=32*6, data_sequence=96)

# 对于回归任务,输出形状为(batch_size, data_sequence, 1)。
net = MASA_TCN(
        cnn1d_channels=[128, 128, 128],  # 1维卷积层的通道数列表
        cnn1d_kernel_size=[3, 5, 15],  # 1维卷积层的核大小列表
        cnn1d_dropout_rate=0.1,  # 1维卷积层的dropout率
        num_eeg_chan=32,  # EEG通道数
        freq=6,  # 特征频率
        output_dim=1,  # 输出维度
        early_fusion=True,  # 是否使用早期融合
        model_type='reg')  # 模型类型为回归
preds = net(data)  # 对输入数据进行预测

# 对于分类任务,输出形状为(batch_size, num_classes)。注意:output_dim应该是类别的数量。
net = MASA_TCN(
        cnn1d_channels=[128, 128, 128],  # 1维卷积层的通道数列表
        cnn1d_kernel_size=[3, 5, 15],  # 1维卷积层的核大小列表
        cnn1d_dropout_rate=0.1,  # 1维卷积层的dropout率
        num_eeg_chan=32,  # EEG通道数
        freq=6,  # 特征频率
        output_dim=2,  # 输出维度
        early_fusion=True,  # 是否使用早期融合
        model_type='cls')  # 模型类型为分类
preds = net(data)  # 对输入数据进行预测
```

这段代码首先导入了MASA_TCN模型,然后创建了一个随机输入数据,并使用MASA_TCN模型进行了回归和分类任务的预测。注释已经在代码中添加了。

这篇关于configs的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1071218

相关文章

promethus 的 relabel_configs 和 metric_relabel_configs

很多童鞋在群里面反馈 relabel_configs 和 metric_relabel_configs 两个配置使用区别。都是relabel 譬如relabel_configs的relabel如下: - source_labels: [__meta_kubernetes_pod_annotation_prometheus_io_scrape]separator: ;regex: "true

extjs对象的configs和properties有什么区别?

转自:https://zhidao.baidu.com/question/497243792.html 区别在于两者出现的时间,config是create时的参数,不一定全部都赋值;properties是属性,是create后生成的,但有些config能直接从对象中获得,所以容易混淆

uboot移植之uboot/include/configs/mini2440.h

此文件是设置uboot的一些参数的主要地方,比较常用的用/***&&&****/标记了一下 /** (C) Copyright 2002* Sysgo Real-Time Solutions, GmbH <www.elinos.com>* Marius Groeger <mgroeger@sysgo.de>* Gary Jennejohn <garyj@denx.de>* David Mue

python爬取微博热门消息(二)—— configs中参数的设置及程序执行过程

这一节,主要讲述配置文件configs.py中参数的含义,以及cookie的获取方式。 感兴趣的小伙伴可以 收藏 + 关注 哦! 另外,关于本项目的效果展示,以及教程,点击一下链接即可。 python爬取微博热门消息(一)——效果展示 python爬取微博热门消息(三)—— 爬取微博热门信息的功能函数 python爬取微博热门消息(四)—— 完整代码 目录 一、常用参数 二