fix sequence length in santacoder and introduce new model type
#23
opened by mayank-mishra
- config.json +1 -1
- modeling_gpt2_mq.py +200 -23
config.json CHANGED
@@ -14,7 +14,7 @@
   "eos_token_id": 50256,
   "initializer_range": 0.02,
   "layer_norm_epsilon": 1e-05,
-  "model_type": "
+  "model_type": "santacoder",
   "n_embd": 2048,
   "n_head": 16,
   "n_inner": 8192,
modeling_gpt2_mq.py CHANGED
@@ -1,39 +1,21 @@
 """PyTorch OpenAI GPT-2 model modified with MultiQuery attention"""
 
 
-import math
-import os
-from dataclasses import dataclass
 from typing import Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.cuda.amp import autocast
-
-
-from transformers.activations import ACT2FN
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPastAndCrossAttentions,
-    CausalLMOutputWithCrossAttentions,
-    SequenceClassifierOutputWithPast,
-    TokenClassifierOutput,
-)
-from transformers.modeling_utils import PreTrainedModel, SequenceSummary
+
+from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
 from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
 
-from transformers.utils import (
-    ModelOutput,
-    add_code_sample_docstrings,
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-    logging,
-    replace_return_docstrings,
-)
-from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
+from transformers.utils import logging
 from transformers.models.gpt2.modeling_gpt2 import GPT2Model, GPT2Block, GPT2PreTrainedModel, GPT2LMHeadModel
-from .configuration_gpt2_mq import GPT2CustomConfig, MULTI_QUERY
+from .configuration_gpt2_mq import GPT2CustomConfig, MULTI_QUERY
 
+logger = logging.get_logger(__name__)
 
 
 class GPT2MQAttention(nn.Module):
@@ -329,6 +311,201 @@ class GPT2CustomModel(GPT2Model):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+            batch_size = input_ids.shape[0]
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+            batch_size = inputs_embeds.shape[0]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+        if token_type_ids is not None:
+            token_type_ids = token_type_ids.view(-1, input_shape[-1])
+        if position_ids is not None:
+            position_ids = position_ids.view(-1, input_shape[-1])
+
+        if past_key_values is None:
+            past_length = 0
+            past_key_values = tuple([None] * len(self.h))
+        else:
+            # this is different from GPT2
+            past_length = past_key_values[0][0].size(-1)
+        if position_ids is None:
+            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+
+        # GPT2Attention mask.
+        if attention_mask is not None:
+            if batch_size <= 0:
+                raise ValueError("batch_size has to be defined and > 0")
+            attention_mask = attention_mask.view(batch_size, -1)
+            # We create a 3D attention mask from a 2D tensor mask.
+            # Sizes are [batch_size, 1, 1, to_seq_length]
+            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+            # this attention mask is more simple than the triangular masking of causal attention
+            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+            attention_mask = attention_mask[:, None, None, :]
+
+            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+            # masked positions, this operation will create a tensor which is 0.0 for
+            # positions we want to attend and the dtype's smallest value for masked positions.
+            # Since we are adding it to the raw scores before the softmax, this is
+            # effectively the same as removing these entirely.
+            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
+            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+
+        # If a 2D or 3D attention mask is provided for the cross-attention
+        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if self.config.add_cross_attention and encoder_hidden_states is not None:
+            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            if encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+            encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+        else:
+            encoder_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # head_mask has shape n_layer x batch x n_heads x N x N
+        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.wte(input_ids)
+        position_embeds = self.wpe(position_ids)
+        hidden_states = inputs_embeds + position_embeds
+
+        if token_type_ids is not None:
+            token_type_embeds = self.wte(token_type_ids)
+            hidden_states = hidden_states + token_type_embeds
+
+        hidden_states = self.drop(hidden_states)
+
+        output_shape = input_shape + (hidden_states.size(-1),)
+
+        presents = () if use_cache else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+        all_hidden_states = () if output_hidden_states else None
+        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+
+            # Model parallel
+            if self.model_parallel:
+                torch.cuda.set_device(hidden_states.device)
+                # Ensure layer_past is on same device as hidden_states (might not be correct)
+                if layer_past is not None:
+                    layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
+                # Ensure that attention_mask is always on the same device as hidden_states
+                if attention_mask is not None:
+                    attention_mask = attention_mask.to(hidden_states.device)
+                if isinstance(head_mask, torch.Tensor):
+                    head_mask = head_mask.to(hidden_states.device)
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+
+                if use_cache:
+                    logger.warning(
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                    )
+                    use_cache = False
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs, use_cache, output_attentions)
+
+                    return custom_forward
+
+                outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    None,
+                    attention_mask,
+                    head_mask[i],
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                )
+            else:
+                outputs = block(
+                    hidden_states,
+                    layer_past=layer_past,
+                    attention_mask=attention_mask,
+                    head_mask=head_mask[i],
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                )
+
+            hidden_states = outputs[0]
+            if use_cache is True:
+                presents = presents + (outputs[1],)
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+                if self.config.add_cross_attention:
+                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
+
+            # Model Parallel: If it's the last layer for that device, put things on the next device
+            if self.model_parallel:
+                for k, v in self.device_map.items():
+                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
+                        hidden_states = hidden_states.to("cuda:" + str(k + 1))
+
+        hidden_states = self.ln_f(hidden_states)
+
+        hidden_states = hidden_states.view(output_shape)
+        # Add last hidden state
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
+                if v is not None
+            )
+
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=presents,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
+        )
+
 
 class GPT2LMHeadCustomModel(GPT2LMHeadModel):
     config_class = GPT2CustomConfig
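For reference on why the overridden forward reads past_length with size(-1): in this multi-query variant the cached key no longer has the per-head layout that the stock GPT2Model.forward assumes, and its sequence dimension sits last. A small shape sketch; the exact multi-query cache layout below is assumed from the size(-1) call, not spelled out in the diff:

import torch

batch, num_heads, head_dim, past_len = 2, 16, 128, 7

# Vanilla GPT-2 caches one key per head as (batch, num_heads, past_len, head_dim),
# so GPT2Model.forward reads the past length from size(-2).
gpt2_key = torch.empty(batch, num_heads, past_len, head_dim)
assert gpt2_key.size(-2) == past_len

# Multi-query attention shares a single key across heads; assuming it is cached
# with the sequence dimension last, e.g. (batch, head_dim, past_len), size(-2)
# would return head_dim rather than the number of cached tokens, so the
# overridden forward above reads size(-1) instead.
mq_key = torch.empty(batch, head_dim, past_len)
assert mq_key.size(-1) == past_len

Since past_length feeds the torch.arange call that builds position_ids, reading the wrong dimension during cached generation would produce wrong positions; that is the sequence-length issue this override addresses, alongside the new model_type in config.json.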