trohrbaugh committed on
Commit
dda6a6d
·
verified ·
1 Parent(s): c899cd1

Upload modeling_stable_diffcoder.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_stable_diffcoder.py +274 -0
modeling_stable_diffcoder.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 ByteDance Ltd. and/or its affiliates
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ import numpy as np
5
+ import torch
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+ from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, DynamicCache
9
+ from transformers.models.llama.modeling_llama import LlamaForCausalLM
10
+ from transformers.generation.utils import GenerationConfig
11
+
12
+
13
class StableDiffcoderForCausalLM(LlamaForCausalLM):
    def _get_num_transfer_tokens(self, mask_map, steps):
        """Split the masked-token budget evenly across `steps` denoising steps.

        Only batch size 1 is supported for now.

        Args:
            mask_map: boolean tensor marking which positions are still masked.
            steps: number of denoising steps the tokens are spread over.

        Returns:
            A 1-D long tensor of length `steps` whose entries sum to the
            total masked-token count.
        """
        total_masked = int(mask_map.sum().item())
        quota, extra = divmod(total_masked, steps)
        schedule = torch.full(
            (steps,), fill_value=quota, device=mask_map.device, dtype=torch.long
        )
        # The first `extra` steps absorb one additional token each so the
        # schedule sums exactly to `total_masked`.
        schedule[:extra] += 1
        return schedule
28
+
29
+ def _make_block_causal_mask(
30
+ self, seq_len, block_size=2, device=None, dtype=torch.bfloat16
31
+ ):
32
+ num_blocks = (seq_len + block_size - 1) // block_size
33
+ block_mask = torch.tril(
34
+ torch.ones((num_blocks, num_blocks), dtype=torch.bool, device=device)
35
+ )
36
+ local_block = torch.ones(
37
+ (block_size, block_size), dtype=torch.bool, device=device
38
+ )
39
+ mask = block_mask.kron(local_block)[:seq_len, :seq_len]
40
+
41
+ attention_mask = mask.float()
42
+ attention_mask.masked_fill_(~mask, -torch.inf)
43
+ attention_mask = attention_mask.unsqueeze(0).unsqueeze(0).to(dtype)
44
+ return attention_mask
45
+
46
    def _get_transfer_index(
        self,
        logits,
        temperature,
        remasking,
        mask_index,
        x,
        num_transfer_token,
        threshold=None,
        shift=False,
    ):
        """Pick which masked positions to unmask this step and their new tokens.

        Only batch index 0 is used throughout — batch size 1 is assumed,
        consistent with `generate_block`.

        Args:
            logits: model logits for the current block, shape (1, L, V) — TODO
                confirm shape against the caller.
            temperature: 0 means greedy argmax; > 0 applies Gumbel-style noise.
            remasking: "low_confidence" (softmax prob of the chosen token) or
                "random" (uniform random confidence).
            mask_index: boolean map of still-masked positions within the block.
            x: current token ids for the block (used to keep unmasked tokens).
            num_transfer_token: how many positions to unmask when `threshold`
                is None; ignored otherwise.
            threshold: if set, tentatively select ALL masked positions, then
                drop any (beyond the single most confident one) whose
                confidence falls below this value.
            shift: if True, treat logits as next-token predictions: shift
                predictions right by one and prepend the first input token.

        Returns:
            (x0, transfer_map): candidate token ids and a boolean map of the
            positions to actually commit this step.
        """
        def add_gumbel_noise(logits, temperature):
            # Gumbel-max-style sampling: argmax of exp(logits) / (-log u)^T
            # samples from the tempered softmax. float64 to limit the
            # numerical error of exp().
            if temperature == 0:
                return logits
            logits = logits.to(torch.float64)
            noise = torch.rand_like(logits, dtype=torch.float64)
            gumbel_noise = (-torch.log(noise)) ** temperature
            return logits.exp() / gumbel_noise

        logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
        x0 = torch.argmax(logits_with_noise, dim=-1)  # b, l
        if shift:
            # Align next-token predictions with their target positions; the
            # logits are shifted the same way so confidences stay paired
            # with the (shifted) predictions.
            x0 = torch.cat([x[:, :1], x0[:, :-1]], dim=-1)
            pad = torch.zeros_like(logits[:, :1])
            logits = torch.cat([pad, logits[:, :-1]], dim=1)
        if remasking == "low_confidence":
            # Confidence = softmax probability assigned to the chosen token.
            p = F.softmax(logits.to(torch.float64), dim=-1)
            x0_p = torch.squeeze(
                torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1
            )  # b, l
        elif remasking == "random":
            x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
        else:
            raise NotImplementedError(remasking)

        # Keep already-unmasked tokens; exclude them from selection by giving
        # them -inf confidence.
        x0 = torch.where(mask_index, x0, x)
        confidence = torch.where(mask_index, x0_p, -np.inf)

        transfer_map = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
        if threshold is not None:
            # Threshold mode: consider every masked position as a candidate.
            num_transfer_token = mask_index.sum(dim=1, keepdim=True)
        _, select_index = torch.topk(confidence[0], k=num_transfer_token)
        transfer_map[0, select_index] = True
        if threshold is not None:
            # Always keep the top-1 candidate (k starts at 1); prune the rest
            # that fall below the confidence threshold.
            for k in range(1, num_transfer_token):
                if confidence[0, select_index[k]] < threshold:
                    transfer_map[0, select_index[k]] = False
        return x0, transfer_map
94
+
95
    @torch.no_grad()
    def generate_block(
        self,
        input_ids: torch.LongTensor,
        steps=128,
        gen_length=128,
        block_length=4,
        temperature=0.0,
        remasking="low_confidence",
        tokenizer=None,
        mask_id=5,
        threshold=0.95,
        shift=False,
        eos_id=None,
    ):
        """Block-wise diffusion decoding with a growing KV cache.

        Appends `gen_length` mask tokens to the prompt and denoises them
        block by block: within the active block, repeated forward passes
        unmask the most confident positions (see `_get_transfer_index`);
        completed blocks are frozen into `past_key_values`.

        Only batch size 1 is supported.

        Args:
            input_ids: prompt token ids, shape (1, prompt_length).
            steps: total denoising steps, split evenly across blocks.
            gen_length: number of tokens to generate (must divide by
                `block_length`).
            block_length: size of each diffusion block.
            temperature / remasking / threshold / shift: forwarded to
                `_get_transfer_index`.
            tokenizer: unused here; kept for interface compatibility.
            mask_id: token id representing a masked position.
            eos_id: if set, generation stops early once a finished block
                contains this id; everything after the first EOS is
                overwritten with `eos_id`.

        Returns:
            (x, nfe): generated ids (possibly truncated at EOS) and the
            number of forward evaluations performed.
        """
        # Prompt followed by gen_length mask placeholders.
        x = torch.cat(
            [
                input_ids,
                torch.full(
                    (input_ids.shape[0], gen_length),
                    mask_id,
                    dtype=torch.long,
                    device=input_ids.device,
                ),
            ],
            dim=1,
        )

        assert gen_length % block_length == 0, (
            "gen_length must be divisible by block_length"
        )
        gen_blocks = gen_length // block_length

        assert steps % gen_blocks == 0, (
            "steps must be divisible by the number of generation blocks"
        )
        # Per-block step budget.
        steps = steps // gen_blocks

        assert x.shape[0] == 1, (
            "Only batch size of 1 is supported for block-wise generation currently."
        )

        prompt_length = input_ids.shape[1]
        gen_block_list = [block_length for _ in range(gen_blocks)]

        # Fix 3: Only handle residual blocks if the prompt length is NOT cleanly divisible
        # When the prompt does not end on a block boundary, the first
        # generation block is merged with the residual prompt tokens, and the
        # last block shrinks so the total generated length stays gen_length.
        remainder = prompt_length % block_length
        if remainder != 0:
            res_block = block_length - remainder
            gen_block_list = [res_block] + gen_block_list
            gen_block_list[-1] = block_length - res_block
            gen_blocks += 1
        # Cumulative generated-token counts; cum_block[i] marks the end of
        # block i relative to prompt_length.
        cum_block = [sum(gen_block_list[: i + 1]) for i in range(len(gen_block_list))]

        block_diffusion_attention_mask = self._make_block_causal_mask(
            prompt_length + gen_length,
            block_length,
            self.device,
            dtype=torch.bfloat16,
        )

        past_key_values = DynamicCache()

        nfe = 0  # number of forward evaluations
        final_flag = False  # set when EOS triggers early termination
        # Largest block-aligned prefix of the prompt; cached in one pass.
        prefill_length = prompt_length // block_length * block_length

        if prefill_length > 0:
            cur_attn_mask = block_diffusion_attention_mask[
                ..., :prefill_length, :prefill_length
            ]
            # Fix 1: Explicitly pass cache_position for newer transformers prefill
            # actually not necessary since transformers will automatically generate it for prefilling
            # if unspecified, but the official `generate` method does pass it,
            # so we follow that for consistency and to avoid potential issues in future transformers updates
            cache_pos = torch.arange(prefill_length, device=x.device)
            self(
                x[:, :prefill_length],
                past_key_values=past_key_values,
                attention_mask=cur_attn_mask,
                use_cache=True,
                cache_position=cache_pos,
            )

        for block_id, block_size in enumerate(gen_block_list):
            # The first block starts right after the cached prefix; later
            # blocks start where the previous one ended.
            block_start = (
                prompt_length + cum_block[block_id - 1]
                if block_id > 0
                else prefill_length
            )
            block_end = prompt_length + cum_block[block_id]

            block_mask_map = x[:, block_start:block_end] == mask_id
            num_transfer_tokens = self._get_num_transfer_tokens(block_mask_map, steps)

            # Boolean map of the active block over the full sequence; its
            # nonzero column indices double as cache_position for the block.
            replace_position = torch.zeros_like(x, dtype=torch.bool)
            replace_position[:, block_start:block_end] = True

            for token_count in num_transfer_tokens:
                if token_count > 0:
                    nfe += 1
                    mask_map = x[:, block_start:block_end] == mask_id
                    attention_mask = block_diffusion_attention_mask[
                        ..., block_start:block_end, :block_end
                    ]
                    output = self(
                        x[:, block_start:block_end],
                        attention_mask=attention_mask,
                        past_key_values=past_key_values,
                        use_cache=True,
                        cache_position=replace_position.nonzero(as_tuple=True)[1],
                    )
                    logits = output.logits

                    # The active block is still changing, so its KV entries
                    # must not persist; crop the cache back to block_start.
                    past_key_values.crop(block_start)

                    x0, transfer_map = self._get_transfer_index(
                        logits,
                        temperature,
                        remasking,
                        mask_map,
                        x[:, block_start:block_end],
                        token_count.item() if threshold is None else None,
                        threshold,
                        shift=shift,
                    )
                    # Commit the selected tokens in place.
                    x[:, block_start:block_end][transfer_map] = x0[transfer_map]

                if (x[:, block_start:block_end] == mask_id).sum() == 0:

                    # Fix 2: Calculate where the generated tokens ACTUALLY start in this block
                    # (the first block may also contain residual prompt tokens).
                    gen_start = max(block_start, prompt_length)

                    if (
                        eos_id is not None
                        and gen_start < block_end
                        and (x[:, gen_start:block_end] == eos_id).sum() > 0
                    ):
                        # EOS produced: truncate to this block, pad everything
                        # from the first EOS onward with eos_id, and stop.
                        final_flag = True
                        x = x[:, :block_end]
                        eos_pos = (x[:, gen_start:block_end] == eos_id).nonzero(as_tuple=True)[1][0].item() + gen_start
                        x[0, eos_pos:] = eos_id
                        break

                    # Block finished without EOS: one extra pass to write its
                    # final KV entries into the cache, then move on.
                    nfe += 1
                    self(
                        x[:, block_start:block_end],
                        attention_mask=block_diffusion_attention_mask[
                            ..., block_start:block_end, :block_end
                        ],
                        past_key_values=past_key_values,
                        use_cache=True,
                        cache_position=replace_position.nonzero(as_tuple=True)[1],
                    )
                    break

            if final_flag:
                break

        return x, nfe
255
+
256
+ @torch.no_grad()
257
+ def generate(
258
+ self,
259
+ input_ids=None,
260
+ generation_config: GenerationConfig = None,
261
+ **kwargs,
262
+ ):
263
+ if input_ids is None:
264
+ raise ValueError("input_ids must be provided")
265
+
266
+ if generation_config is None:
267
+ generation_config = self.generation_config
268
+
269
+ output_ids, nfe = self.generate_block(
270
+ input_ids=input_ids,
271
+ **kwargs,
272
+ )
273
+
274
+ return output_ids