lit-llama代码解析

2024-09-04 05:52
文章标签 代码 解析 llama lit

本文主要是介绍lit-llama代码解析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

https://github.com/Lightning-AI/lit-llama/blob/main/README.md

下载的时候会报错误,因为网不行,一种方法就是多次尝试,另一种方法是终端连上代理下载

pycharm连接hugging face等网站_hugging face怎么连接-CSDN博客

根据指引下载权重

下载完权重运行:python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/open-llama/7B --model_size 7B

转化为.pth文件 

跟着readme/howto教程量化或进行其他操作

warning

UserWarning: 1Torch was not compiled with flash attention. (Triggered internally at ..\aten\src\ATen\native\transformers\cuda\sdp_utils.cpp:455.)
  y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

https://github.com/comfyanonymous/ComfyUI/issues/3202

分析generate

# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.import sys
import time
import warnings
from pathlib import Path
from typing import Optionalimport lightning as L
import torch
print(torch.cuda.is_available())
# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))from lit_llama import LLaMA, Tokenizer
from lit_llama.utils import lazy_load, llama_model_lookup, quantization@torch.no_grad()
def generate(model: LLaMA,idx: torch.Tensor,max_new_tokens: int,*,max_seq_length: Optional[int] = None,temperature: float = 1.0,top_k: Optional[int] = None,eos_id: Optional[int] = None,
) -> torch.Tensor:"""Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.The implementation of this function is modified from A. Karpathy's nanoGPT.Args:model: The model to use.idx: Tensor of shape (T) with indices of the prompt sequence.max_new_tokens: The number of new tokens to generate.max_seq_length: The maximum sequence length allowed.temperature: Scales the predicted logits by 1 / temperaturetop_k: If specified, only sample among the tokens with the k highest probabilitieseos_id: If specified, stop generating any more token once the <eos> token is triggered"""# create an empty tensor of the expected final shape and fill in the current tokensT = idx.size(0)T_new = T + max_new_tokensif max_seq_length is None:max_seq_length = min(T_new, model.config.block_size)device, dtype = idx.device, idx.dtype# create an empty tensor of the expected final shape and fill in the current tokensempty = torch.empty(T_new, dtype=dtype, device=device)empty[:T] = idxidx = emptyinput_pos = torch.arange(0, T, device=device)if idx.device.type == "xla":import torch_xla.core.xla_model as xmxm.mark_step()# generate max_new_tokens tokensfor _ in range(max_new_tokens):x = idx.index_select(0, input_pos).view(1, -1)# forwardlogits = model(x, max_seq_length, input_pos)logits = logits[0, -1] / temperature# optionally crop the logits to only the top k optionsif top_k is not None:v, _ = torch.topk(logits, min(top_k, logits.size(-1)))logits = torch.where(logits < v[[-1]], -float("Inf"), logits)probs = torch.nn.functional.softmax(logits, dim=-1)idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)# advanceinput_pos = input_pos[-1:] + 1if idx.device.type == "xla":xm.mark_step()# concatenate the new generationidx = idx.index_copy(0, input_pos, idx_next)# if <eos> token is triggered, return the output (stop generation)if idx_next == eos_id:return idx[:input_pos]  # include the EOS tokenreturn idxdef main(prompt: str = "Hello, my name is",*,num_samples: int = 1,max_new_tokens: int = 50,top_k: int = 200,temperature: float = 0.8,checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),quantize: Optional[str] = None,
) -> None:"""Generates text samples based on a pre-trained LLaMA model and tokenizer.Args:prompt: The prompt string to use for generating the samples.num_samples: The number of text samples to generate.(Its effect is overridden by `max_new_tokens`, if also set.)max_new_tokens: The number of generation steps to take.(number of generate tokens )top_k: The number of top most probable tokens to consider in the sampling process.temperature: A value controlling the randomness of the sampling process. Higher values result in more randomsamples.checkpoint_path: The checkpoint path to load.tokenizer_path: The tokenizer path to load.quantize: Whether to quantize the model and using which method:``"llm.int8"``: LLM.int8() mode,``"gptq.int4"``: GPTQ 4-bit mode."""assert checkpoint_path.is_file(), checkpoint_pathassert tokenizer_path.is_file(), tokenizer_pathprecision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true"fabric = L.Fabric(devices=1, precision=precision)print("Loading model ...", file=sys.stderr)t0 = time.time()with lazy_load(checkpoint_path) as checkpoint:name = llama_model_lookup(checkpoint)with fabric.init_module(empty_init=True), quantization(mode=quantize):model = LLaMA.from_name(name)model.load_state_dict(checkpoint)print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)model.eval()model = fabric.setup(model)tokenizer = Tokenizer(tokenizer_path)encoded = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)prompt_length = encoded.size(0)L.seed_everything(1234)for i in range(num_samples):t0 = time.perf_counter()y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k)t = time.perf_counter() - t0model.reset_cache()print(tokenizer.decode(y))tokens_generated = y.size(0) - prompt_lengthprint(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)if fabric.device.type == "cuda":print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)if __name__ == "__main__":from jsonargparse import CLItorch.set_float32_matmul_precision("high")warnings.filterwarnings(# Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31"ignore", message="ComplexHalf support is experimental and many operators don't support it yet")warnings.filterwarnings(# Triggered in bitsandbytes/autograd/_functions.py:298"ignore", message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization",)CLI(main)

main()

"""Generates text samples based on a pre-trained LLaMA model and tokenizer.Args:prompt: The prompt string to use for generating the samples.num_samples: The number of text samples to generate.(Its effect is overridden by `max_new_tokens`, if also set.)max_new_tokens: The number of generation steps to take.(number of generate tokens )top_k: The number of top most probable tokens to consider in the sampling process.temperature: A value controlling the randomness of the sampling process. Higher values result in more random samples.checkpoint_path: The checkpoint path to load.tokenizer_path: The tokenizer path to load.quantize: Whether to quantize the model and using which method:``"llm.int8"``: LLM.int8() mode,``"gptq.int4"``: GPTQ 4-bit mode.
"""


https://zhuanlan.zhihu.com/p/657886517

Fabric()

r"""Fabric accelerates your PyTorch training or inference code with minimal changes required.Fabric 加速你的 PyTorch 训练或推理代码,所需的更改最小。- Automatic placement of models and data onto the device.- 自动将模型和数据放置到设备上。- Automatic support for mixed and double precision (smaller memory footprint).- 自动支持混合精度和双精度(较小的内存占用)。- Seamless switching between hardware (CPU, GPU, TPU) and distributed training strategies(data-parallel training, sharded training, etc.).- 在硬件(CPU、GPU、TPU)和分布式训练策略(数据并行训练、分片训练等)之间无缝切换。- Automated spawning of processes, no launch utilities required.- 自动生成进程,无需启动工具。- Multi-node support.- 支持多节点训练。Args:accelerator: The hardware to run on. Possible choices are:``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.accelerator: 运行的硬件。可能的选择有:``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``。strategy: Strategy for how to run across multiple devices. Possible choices are:``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"fsdp"``.strategy: 跨多个设备运行的策略。可能的选择有:``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"fsdp"``。devices: Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``.The value applies per node.devices: 训练时使用的设备数量(``int``),或要训练的 GPU(``list`` 或 ``str``),或 ``"auto"``。该值适用于每个节点。num_nodes: Number of GPU nodes for distributed training.num_nodes: 用于分布式训练的 GPU 节点数量。precision: Double precision (``"64"``), full precision (``"32"``), half precision AMP (``"16-mixed"``),or bfloat16 precision AMP (``"bf16-mixed"``).precision: 双精度(``"64"``),全精度(``"32"``),半精度 AMP(``"16-mixed"``),或 bfloat16 精度 AMP(``"bf16-mixed"``)。plugins: One or several custom pluginsplugins: 一个或多个自定义插件callbacks: A single callback or a list of callbacks. A callback can contain any arbitrary methods thatcan be invoked through :meth:`~lightning.fabric.fabric.Fabric.call` by the user.callbacks: 单个回调或回调列表。回调可以包含任何用户可以通过 :meth:`~lightning.fabric.fabric.Fabric.call` 调用的任意方法。loggers: A single logger or a list of loggers. See :meth:`~lightning.fabric.fabric.Fabric.log` for moreinformation.loggers: 单个日志记录器或日志记录器列表。有关更多信息,请参见 :meth:`~lightning.fabric.fabric.Fabric.log`。
"""

lazy_load()

定义了一个名为 lazy_load 的类,它用于延迟加载和管理一个 PyTorch 文件:lazy_load 类
__init__ 方法
python
def __init__(self, fn):self.zf = torch._C.PyTorchFileReader(str(fn))with BytesIO(self.zf.get_record("data.pkl")) as pkl:mup = LazyLoadingUnpickler(pkl, self)self.sd = mup.load()
self.zf = torch._C.PyTorchFileReader(str(fn)):创建一个 PyTorchFileReader 实例,用于读取指定文件 (fn) 的内容。这个文件是 PyTorch 保存的文件,通常是 .pt 或 .pth 文件。
str(fn) 确保文件路径被正确转换为字符串。
with BytesIO(self.zf.get_record("data.pkl")) as pkl::从 PyTorchFileReader 中提取名为 "data.pkl" 的记录,并用 BytesIO 创建一个内存中的字节流对象 pkl。
BytesIO 用于在内存中读写二进制数据。
mup = LazyLoadingUnpickler(pkl, self):创建一个 LazyLoadingUnpickler 实例 mup,它负责处理 pkl 中的数据。这里假设 LazyLoadingUnpickler 是自定义的类,用于延迟加载和解码 Pickle 数据。
self.sd = mup.load():调用 mup.load() 方法来加载数据,并将结果存储在 self.sd 属性中。这个过程可能会涉及到数据的反序列化。
__enter__ 方法
python
def __enter__(self):return self.sd
这个方法允许 lazy_load 实例在上下文管理器(with 语句)中使用。__enter__ 返回 self.sd,使得 with 语句块内部可以直接访问加载的数据。
__exit__ 方法
python
def __exit__(self, exc_type, exc_val, exc_tb):del self.zf  # I don't think there is a way to force closing...self.zf = None
这个方法用于处理退出上下文管理器时的清理工作。
del self.zf: 尝试删除 self.zf 对象。由于 self.zf 是一个 PyTorchFileReader 实例,删除对象的作用是释放相关资源。
self.zf = None: 另一种释放资源的方式,将 self.zf 设置为 None,以确保它不再被引用。
总结
这个类的设计用于懒加载 PyTorch 文件中的数据。它实现了上下文管理协议,使得数据可以在 with 语句块中方便地访问,并且在退出时尝试释放相关资源。

LazyLoadingUnpickler()

定义了一个 LazyLoadingUnpickler 类,继承自 pickle.Unpickler,用于处理 PyTorch 对象的延迟加载。以下是对每个部分的详细解释:__init__ 方法
python
def __init__(self, file, zipfile_context):super().__init__(file)self.zipfile_context = zipfile_context
file: 传入的文件对象(通常是一个字节流),用于反序列化。
zipfile_context: 额外的上下文信息,用于延迟加载的实现。这通常是一个包含 PyTorch 文件读取信息的对象。
super().__init__(file): 调用父类 pickle.Unpickler 的初始化方法,传入文件对象。
self.zipfile_context: 保存额外的上下文信息,用于稍后延迟加载。
find_class 方法
python
def find_class(self, module, name):res = super().find_class(module, name)if module == "torch._utils" and name == "_rebuild_tensor_v2":return functools.partial(NotYetLoadedTensor.rebuild_tensor_v2, archiveinfo=self)elif module == "torch._tensor" and name == "_rebuild_from_type_v2":return functools.partial(NotYetLoadedTensor.rebuild_from_type_v2, archiveinfo=self)elif module == "torch._utils" and name == "_rebuild_parameter":return functools.partial(NotYetLoadedTensor.rebuild_parameter, archiveinfo=self)return res
super().find_class(module, name): 调用父类的 find_class 方法,查找并返回指定模块和类名的类。
模块和类名检查:
当模块是 "torch._utils" 且类名是 "_rebuild_tensor_v2" 时,返回一个 functools.partial 对象,部分应用 NotYetLoadedTensor.rebuild_tensor_v2 方法,并传入 archiveinfo=self。
当模块是 "torch._tensor" 且类名是 "_rebuild_from_type_v2" 时,返回一个 functools.partial 对象,部分应用 NotYetLoadedTensor.rebuild_from_type_v2 方法。
当模块是 "torch._utils" 且类名是 "_rebuild_parameter" 时,返回一个 functools.partial 对象,部分应用 NotYetLoadedTensor.rebuild_parameter 方法。
functools.partial: 允许创建一个新的函数,其中一些参数已经预先指定,这里是为了在实际调用时延迟具体的处理逻辑。
返回值: 如果模块和类名不匹配,返回父类的结果。
persistent_load 方法
python
def persistent_load(self, pid):name, cls, fn, device, size = pidwith warnings.catch_warnings():warnings.simplefilter("ignore")s = torch.storage.TypedStorage(dtype=cls().dtype, device="meta")s.archiveinfo = pidreturn s
pid: 一个包含多个信息的元组 (name, cls, fn, device, size),用于标识持久化数据的加载信息。
warnings.catch_warnings(): 捕获并管理警告信息。
warnings.simplefilter("ignore"): 忽略警告信息,以便在加载过程中不会产生干扰。
torch.storage.TypedStorage(dtype=cls().dtype, device="meta"): 创建一个 TypedStorage 对象,指定数据类型和设备。device="meta" 表示数据存储在元数据设备中,实际上并没有分配真实的存储空间。
s.archiveinfo = pid: 将持久化标识信息存储到 TypedStorage 对象中。
返回值: 返回创建的 TypedStorage 对象。
总结
LazyLoadingUnpickler 主要用于在反序列化 PyTorch 对象时实现延迟加载。这种方法使得在加载大数据文件时可以更高效地管理内存和计算资源。find_class 方法用于动态创建用于延迟加载的对象,而 persistent_load 方法则用于处理持久化存储数据的加载。

llama_model_lookup() 

init_module() 

def init_module(self, empty_init: Optional[bool] = None) -> ContextManager:"""Instantiate the model and its parameters under this context manager to reduce peak memory usage.
在这个上下文管理器下实例化模型及其参数,以减少峰值内存使用。The parameters get created on the device and with the right data type right away without wasting memory being allocated unnecessarily.
参数会直接在设备上创建,并且使用正确的数据类型,从而避免了不必要的内存分配浪费。Args:
参数:empty_init: Whether to initialize the model with empty weights (uninitialized memory).
empty_init: 是否使用空权重(未初始化的内存)来初始化模型。If ``None``, the strategy will decide. Some strategies may not support all options.
如果``None``,则策略将决定。一些策略可能不支持所有选项。Set this to ``True`` if you are loading a checkpoint into a large model.
如果你正在将检查点加载到大型模型中,将其设置为``True``。"""self._validate_launched()return self._strategy.module_init_context(empty_init=empty_init)
module_init_context()  
 def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:"""A context manager wrapping the model instantiation.
一个包装模型实例化的上下文管理器。Here, the strategy can control how the parameters of the model get created (device, dtype) and or apply other patches to the model.
在这里,策略可以控制模型参数的创建方式(设备、数据类型)或对模型应用其他修补。Args:
参数:empty_init: Whether to initialize the model with empty weights (uninitialized memory).
empty_init: 是否使用空权重(未初始化的内存)来初始化模型。If ``None``, the strategy will decide. Some strategies may not support all options.
如果``None``,则策略将决定。一些策略可能不支持所有选项。"""precision_module_ctx = self.precision.module_init_context()stack = ExitStack()stack.enter_context(self.root_device)stack.enter_context(_EmptyInit(enabled=bool(empty_init)))stack.enter_context(precision_module_ctx)return stack

quantization() 

@contextmanager
def quantization(mode: str = None):quantized_linear_cls = Noneif mode == 'llm.int8':from .quantization import Linear8bitLtquantized_linear_cls = Linear8bitLtelif mode == 'gptq.int4':from .quantization import ColBlockQuantizedLinearquantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=4, tile_cols=-1)elif mode == 'gptq.int8':from .quantization import ColBlockQuantizedLinearquantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=8, tile_cols=-1)elif mode is not None:raise ValueError(f"Unknown quantization mode: {mode}")enabled = mode is not Nonetorch_linear_cls = torch.nn.Linearif enabled:torch.nn.Linear = quantized_linear_clsyieldif enabled:torch.nn.Linear = torch_linear_cls

model 

setup() 

    def setup(self,module: nn.Module,*optimizers: Optimizer,move_to_device: bool = True,_reapply_compile: bool = True,) -> Any:  # no specific return because the way we want our API to look does not play well with mypyr"""Set up a model and its optimizers for accelerated training.
为加速训练设置模型及其优化器。Args:
参数:module: A :class:`torch.nn.Module` to set up
module: 要设置的 :class:`torch.nn.Module`*optimizers: The optimizer(s) to set up (no optimizers is also possible)
*optimizers: 要设置的优化器(也可以不设置优化器)move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
move_to_device: 如果设置为``True``(默认值),则将模型移动到正确的设备。设置为``False`` and alternatively use :meth:`to_device` manually.并可以手动使用 :meth:`to_device`。_reapply_compile: If ``True`` (default), and the model was ``torch.compile``d before, the
_reapply_compile: 如果``True``(默认值),且模型之前已``torch.compile``,则corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the相应的 :class:`~torch._dynamo.OptimizedModule` 包装器将被移除,并在模型被策略设置好后重新应用same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP,相同的设置(例如,模型被 DDP、FSDP 等包装之后)。如果编译 DDP/FSDP 造成问题,设置为``False``。Returns:
返回:The tuple containing wrapped module and the optimizers, in the same order they were passed in.
一个包含包装的模块和优化器的元组,顺序与传入时相同。"""

tokenizer

 

    def encode(self,string: str,bos: bool = True,eos: bool = False,max_length: int = -1,pad: bool = False,device: Optional[torch.device] = None) -> torch.Tensor:tokens = self.processor.encode(string)if bos:tokens = [self.bos_id] + tokensif eos:tokens = tokens + [self.eos_id]if max_length > 0:tokens = tokens[:max_length]if pad and len(tokens) < max_length:tokens += [self.pad_id] * (max_length - len(tokens))return torch.tensor(tokens, dtype=torch.int, device=device)def decode(self, tokens: torch.Tensor) -> str:return self.processor.decode(tokens.tolist())

 generate()

@torch.no_grad()
def generate(model: LLaMA,idx: torch.Tensor,max_new_tokens: int,*,max_seq_length: Optional[int] = None,temperature: float = 1.0,top_k: Optional[int] = None,eos_id: Optional[int] = None,
) -> torch.Tensor:"""Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
接收一个条件序列(提示)作为输入,并继续生成所请求的数量的标记。The implementation of this function is modified from A. Karpathy's nanoGPT.
此函数的实现改编自 A. Karpathy 的 nanoGPT。Args:
参数:model: The model to use.
model: 要使用的模型。idx: Tensor of shape (T) with indices of the prompt sequence.
idx: 形状为 (T) 的张量,其中包含提示序列的索引。max_new_tokens: The number of new tokens to generate.
max_new_tokens: 要生成的新分词数量。max_seq_length: The maximum sequence length allowed.
max_seq_length: 允许的最大序列长度。temperature: Scales the predicted logits by 1 / temperature
temperature: 通过 1 / temperature 对预测的 logits 进行缩放。top_k: If specified, only sample among the tokens with the k highest probabilities
top_k: 如果指定,只从概率最高的 k 个标记中进行采样。eos_id: If specified, stop generating any more token once the <eos> token is triggered
eos_id: 如果指定,一旦触发 <eos> 标记,停止生成更多标记。"""

 https://pytorch.ac.cn/xla/release/2.1/index.htmlXLA 设备上的 PyTorch

model

    def build_rope_cache(self, idx: torch.Tensor) -> RoPECache:return build_rope_cache(seq_len=self.config.block_size,n_elem=self.config.n_embd // self.config.n_head,dtype=idx.dtype,device=idx.device,)

temperature

温度越低,结果的差距越大,会使概率分布更加尖锐,从而使得模型更倾向于选择最高概率的类别。

topk()  

def topk(input: Tensor, k: Union[_int, SymInt], dim: _int = -1, largest: _bool = True, sorted: _bool = True, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.topk: r"""topk(input, k, dim=None, largest=True, sorted=True, *, out=None) -> (Tensor, LongTensor)返回给定 input 张量在指定维度上最大的 k 个元素。如果没有给定 dim,则选择 input 张量的最后一个维度。如果 largest 设置为 False,则返回 k 个最小元素。函数返回一个命名元组 (values, indices),其中 values 和 indices 分别是输入张量在指定维度 dim 上最大的 k 个元素及其索引。布尔选项 sorted 如果为 True,则确保返回的 k 个元素按顺序排列。参数:input (Tensor): 输入张量。
k (int): "top-k" 中的 k 值。
dim (int, optional): 排序的维度。
largest (bool, optional): 控制是否返回最大还是最小元素。
sorted (bool, optional): 控制是否返回排序后的元素。
关键字参数:out (tuple, optional): 可选的输出元组 (Tensor, LongTensor),可以作为输出缓冲区使用。
示例:python
>>> x = torch.arange(1., 6.)
>>> x
tensor([ 1.,  2.,  3.,  4.,  5.])
>>> torch.topk(x, 3)
torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2]))"""

torch.multinomial

def multinomial(input: Tensor, num_samples: _int, replacement: _bool = False, *, generator: Optional[Generator] = None, out: Optional[Tensor] = None) -> Tensor: r"""def multinomial(input: Tensor, num_samples: _int, replacement: _bool = False, *, generator: Optional[Generator] = None, out: Optional[Tensor] = None) -> Tensor:r"""multinomial(input, num_samples, replacement=False, *, generator=None, out=None) -> LongTensor返回一个张量,其中每一行包含 :attr:`num_samples` 个从对应行的多项分布中采样的索引。更严格地说,是从多元分布中采样,更多细节请参考 torch.distributions.multinomial.Multinomial。.. note:::attr:`input` 的行不需要和为 1(在这种情况下,我们使用值作为权重),但必须是非负的、有穷的,并且和不为零。索引按从左到右的顺序排列,依据每个索引被采样的顺序(第一个样本放在第一列)。如果 :attr:`input` 是一个向量,:attr:`out` 是一个大小为 :attr:`num_samples` 的向量。如果 :attr:`input` 是一个有 `m` 行的矩阵,则 :attr:`out` 是一个形状为:math:`(m \times \text{num\_samples})` 的矩阵。如果 `replacement` 为 ``True``,则样本是有放回的。如果不是,则样本是无放回的,这意味着一旦为某行绘制了一个样本索引,在该行中不能再次绘制相同的索引。.. note::当无放回采样时,:attr:`num_samples` 必须小于 :attr:`input` 中非零元素的数量(如果 `input` 是矩阵,则为每行的非零元素的最小数量)。Args:input (Tensor): 包含概率的输入张量num_samples (int): 要绘制的样本数量replacement (bool, optional): 是否允许重复抽样关键字参数:generator (:class:`torch.Generator`, optional): 用于采样的伪随机数生成器out (Tensor, optional): 输出张量。示例::>>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # 创建一个权重张量>>> torch.multinomial(weights, 2)tensor([1, 2])>>> torch.multinomial(weights, 4) # 错误!RuntimeError: invalid argument 2: invalid multinomial distribution (with replacement=False,not enough non-negative category to sample) at ../aten/src/TH/generic/THTensorRandom.cpp:320>>> torch.multinomial(weights, 4, replacement=True)tensor([ 2,  1,  1,  1])""""""

 model.reset_cache()

 

Pytorch清空显存缓冲区(torch.cuda.empty_cache)_pytorch 清空显存-CSDN博客 
Pytorch 如何在使用模型后清除GPU内存|极客教程

这篇关于lit-llama代码解析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1135179

相关文章

Linux中shell解析脚本的通配符、元字符、转义符说明

《Linux中shell解析脚本的通配符、元字符、转义符说明》:本文主要介绍shell通配符、元字符、转义符以及shell解析脚本的过程,通配符用于路径扩展,元字符用于多命令分割,转义符用于将特殊... 目录一、linux shell通配符(wildcard)二、shell元字符(特殊字符 Meta)三、s

python实现pdf转word和excel的示例代码

《python实现pdf转word和excel的示例代码》本文主要介绍了python实现pdf转word和excel的示例代码,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价... 目录一、引言二、python编程1,PDF转Word2,PDF转Excel三、前端页面效果展示总结一

在MyBatis的XML映射文件中<trim>元素所有场景下的完整使用示例代码

《在MyBatis的XML映射文件中<trim>元素所有场景下的完整使用示例代码》在MyBatis的XML映射文件中,trim元素用于动态添加SQL语句的一部分,处理前缀、后缀及多余的逗号或连接符,示... 在MyBATis的XML映射文件中,<trim>元素用于动态地添加SQL语句的一部分,例如SET或W

使用C#代码计算数学表达式实例

《使用C#代码计算数学表达式实例》这段文字主要讲述了如何使用C#语言来计算数学表达式,该程序通过使用Dictionary保存变量,定义了运算符优先级,并实现了EvaluateExpression方法来... 目录C#代码计算数学表达式该方法很长,因此我将分段描述下面的代码片段显示了下一步以下代码显示该方法如

python多进程实现数据共享的示例代码

《python多进程实现数据共享的示例代码》本文介绍了Python中多进程实现数据共享的方法,包括使用multiprocessing模块和manager模块这两种方法,具有一定的参考价值,感兴趣的可以... 目录背景进程、进程创建进程间通信 进程间共享数据共享list实践背景 安卓ui自动化框架,使用的是

使用Python实现批量访问URL并解析XML响应功能

《使用Python实现批量访问URL并解析XML响应功能》在现代Web开发和数据抓取中,批量访问URL并解析响应内容是一个常见的需求,本文将详细介绍如何使用Python实现批量访问URL并解析XML响... 目录引言1. 背景与需求2. 工具方法实现2.1 单URL访问与解析代码实现代码说明2.2 示例调用

SSID究竟是什么? WiFi网络名称及工作方式解析

《SSID究竟是什么?WiFi网络名称及工作方式解析》SID可以看作是无线网络的名称,类似于有线网络中的网络名称或者路由器的名称,在无线网络中,设备通过SSID来识别和连接到特定的无线网络... 当提到 Wi-Fi 网络时,就避不开「SSID」这个术语。简单来说,SSID 就是 Wi-Fi 网络的名称。比如

SpringBoot生成和操作PDF的代码详解

《SpringBoot生成和操作PDF的代码详解》本文主要介绍了在SpringBoot项目下,通过代码和操作步骤,详细的介绍了如何操作PDF,希望可以帮助到准备通过JAVA操作PDF的你,项目框架用的... 目录本文简介PDF文件简介代码实现PDF操作基于PDF模板生成,并下载完全基于代码生成,并保存合并P

SpringCloud配置动态更新原理解析

《SpringCloud配置动态更新原理解析》在微服务架构的浩瀚星海中,服务配置的动态更新如同魔法一般,能够让应用在不重启的情况下,实时响应配置的变更,SpringCloud作为微服务架构中的佼佼者,... 目录一、SpringBoot、Cloud配置的读取二、SpringCloud配置动态刷新三、更新@R

使用Java解析JSON数据并提取特定字段的实现步骤(以提取mailNo为例)

《使用Java解析JSON数据并提取特定字段的实现步骤(以提取mailNo为例)》在现代软件开发中,处理JSON数据是一项非常常见的任务,无论是从API接口获取数据,还是将数据存储为JSON格式,解析... 目录1. 背景介绍1.1 jsON简介1.2 实际案例2. 准备工作2.1 环境搭建2.1.1 添加