当前位置:   article > 正文

深入理解GLM4模型的流式文本生成:从代码到概念_glm4流式输出

glm4流式输出

自然语言处理(NLP)领域,大型语言模型(LLM)的应用越来越广泛。今天,我们将深入探讨GLM4模型的流式文本生成过程,通过分析一段Python代码来理解其工作原理和关键概念。

1. 代码概览

首先,让我们看一下核心代码:

@torch.inference_mode()
async def generate_stream_glm4(params):
    # ... (参数处理代码)
    
    messages = process_messages(messages, tools=tools, tool_choice=tool_choice)
    inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    
    # ... (采样参数设置代码)
    
    async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"):
        # ... (输出处理代码)
        yield ret
    
    gc.collect()
    torch.cuda.empty_cache()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

这段代码定义了一个异步生成函数,用于流式生成GLM4模型的输出。让我们逐步分解这段代码,了解其中的关键概念和技术。

2. 推理模式和异步编程

2.1 推理模式

@torch.inference_mode()
  • 1

这个装饰器表示函数在PyTorch的推理模式下运行。在推理模式下:

  • 禁用梯度计算,减少内存使用
  • 某些操作会使用优化的实现,提高性能

推理模式适用于不需要反向传播的场景,如模型评估和预测。

2.2 异步编程

async def generate_stream_glm4(params):
    # ...
    async for output in engine.generate(...):
        # ...
  • 1
  • 2
  • 3
  • 4

使用asyncawait关键字进行异步编程,可以提高I/O密集型任务的效率。在这里,异步编程允许在生成文本时同时处理其他任务,提升overall性能。

3. 输入处理和模型准备

3.1 消息处理

messages = process_messages(messages, tools=tools, tool_choice=tool_choice)
  • 1

这一步处理输入消息,可能包括:

  • 格式化消息
  • 应用工具选择
  • 处理特殊指令

3.2 应用聊天模板

inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
  • 1

这一步将处理后的消息应用到预定义的聊天模板中,为模型输入做准备。

4. 采样参数设置

params_dict = {
    "temperature": temperature,
    "top_p": top_p,
    "repetition_penalty": repetition_penalty,
    # ... (其他参数)
}
sampling_params = SamplingParams(**params_dict)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

这里设置了文本生成的关键参数:

  • temperature: 控制输出的随机性
  • top_p: 使用核采样控制输出的多样性
  • repetition_penalty: 降低重复文本的概率

正确设置这些参数对生成高质量、多样化的文本至关重要。

5. 流式生成

async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"):
    # ... (输出处理代码)
    yield ret
  • 1
  • 2
  • 3

使用async for进行流式生成,允许实时返回部分生成结果。这对于需要快速响应的应用场景(如聊天机器人)非常有用。

6. 资源管理

gc.collect()
torch.cuda.empty_cache()
  • 1
  • 2

在生成完成后,进行垃圾回收和GPU内存清理,确保资源得到有效释放。

7. 代码示例:简化版流式生成

下面是一个简化的流式生成示例,展示了核心概念:

import asyncio
from transformers import AutoTokenizer, AutoModelForCausalLM

async def generate_stream(prompt, model, tokenizer, max_length=50):
    inputs = tokenizer(prompt, return_tensors="pt")
    
    for i in range(max_length):
        outputs = model.generate(**inputs, max_length=inputs["input_ids"].shape[1] + 1, do_sample=True)
        next_token = outputs[0][-1]
        yield tokenizer.decode([next_token])
        
        inputs["input_ids"] = torch.cat([inputs["input_ids"], next_token.unsqueeze(0)], dim=-1)

async def main():
    model_name = "gpt2"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    
    prompt = "Once upon a time"
    async for token in generate_stream(prompt, model, tokenizer):
        print(token, end='', flush=True)

asyncio.run(main())
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

这个示例展示了如何使用异步编程和流式生成实现简单的文本生成。

8. 完整代码

@torch.inference_mode()
async def generate_stream_glm4(params):
    messages = params["messages"]
    tools = params["tools"]
    tool_choice = params["tool_choice"]
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    max_new_tokens = int(params.get("max_tokens", 8192))

    messages = process_messages(messages, tools=tools, tool_choice=tool_choice)
    inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    params_dict = {
        "n": 1,
        "best_of": 1,
        "presence_penalty": 1.0,
        "frequency_penalty": 0.0,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": -1,
        "repetition_penalty": repetition_penalty,
        "use_beam_search": False,
        "length_penalty": 1,
        "early_stopping": False,
        "stop_token_ids": [151329, 151336, 151338],
        "ignore_eos": False,
        "max_tokens": max_new_tokens,
        "logprobs": None,
        "prompt_logprobs": None,
        "skip_special_tokens": True,
    }
    sampling_params = SamplingParams(**params_dict)
    async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"):
        output_len = len(output.outputs[0].token_ids)
        input_len = len(output.prompt_token_ids)
        ret = {
            "text": output.outputs[0].text,
            "usage": {
                "prompt_tokens": input_len,
                "completion_tokens": output_len,
                "total_tokens": output_len + input_len
            },
            "finish_reason": output.outputs[0].finish_reason,
        }
        yield ret
    gc.collect()
    torch.cuda.empty_cache()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47

结论

通过分析GLM4模型的流式生成代码,我们深入了解了大型语言模型在实际应用中的工作原理。从异步编程到采样策略,从输入处理到资源管理,每个环节都在追求更高效、更灵活的文本生成体验。随着NLP技术的不断发展,我们可以期待更多创新应用的出现。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/973515
推荐阅读
相关标签
  

闽ICP备14008679号