Instructions to use BucketOfFish/simplified_phi2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use BucketOfFish/simplified_phi2 with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline
pipe = pipeline("text-generation", model="BucketOfFish/simplified_phi2", trust_remote_code=True)
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("BucketOfFish/simplified_phi2", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use BucketOfFish/simplified_phi2 with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm
# Start the vLLM server:
vllm serve "BucketOfFish/simplified_phi2"
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{ "model": "BucketOfFish/simplified_phi2", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'
Use Docker
docker model run hf.co/BucketOfFish/simplified_phi2
- SGLang
How to use BucketOfFish/simplified_phi2 with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang
# Start the SGLang server:
python3 -m sglang.launch_server \
  --model-path "BucketOfFish/simplified_phi2" \
  --host 0.0.0.0 \
  --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{ "model": "BucketOfFish/simplified_phi2", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'
Use Docker images
docker run --gpus all \
  --shm-size 32g \
  -p 30000:30000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=<secret>" \
  --ipc=host \
  lmsysorg/sglang:latest \
  python3 -m sglang.launch_server \
    --model-path "BucketOfFish/simplified_phi2" \
    --host 0.0.0.0 \
    --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{ "model": "BucketOfFish/simplified_phi2", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'
- Docker Model Runner
How to use BucketOfFish/simplified_phi2 with Docker Model Runner:
docker model run hf.co/BucketOfFish/simplified_phi2
| from dataclasses import dataclass, field | |
| from einops import rearrange, repeat | |
| import math | |
| import torch | |
| from torch.amp.autocast_mode import autocast | |
| import torch.nn as nn | |
| from transformers.activations import ACT2FN | |
| from typing import cast | |
# flash_attn is optional. When it imports, its kernels may replace the reference implementations defined
# below; otherwise every flash name is bound to None so callers can test `X is not None`.
try:
    from flash_attn.bert_padding import pad_input, unpad_input
    from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
    from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
    from flash_attn.ops.fused_dense import FusedDense
except ImportError:
    print("flash_attn not found, using default implementations")
    # every fallback name must match the imported name exactly: MHA checks `FlashCrossAttention is not None`
    pad_input = unpad_input = FlashRotaryEmbedding = FlashCrossAttention = FlashSelfAttention = FusedDense = None
| class RotaryEmbedding(nn.Module): | |
| """Rotary positional embedding (RoPE). See https://www.youtube.com/watch?v=C6rV8BsrrCc""" | |
| def __init__( | |
| self, | |
| d_rotary: int, | |
| rotary_base: float = 10000.0, | |
| initial_cos_sin_cache_len: int = 2048, | |
| device: torch.device | None = None, | |
| ) -> None: | |
| super().__init__() | |
| self.d_rotary = d_rotary | |
| self.rotary_base = rotary_base | |
| self.device = device | |
| self.dtype = torch.float32 | |
| self._update_cos_sin_cache(seqlen=initial_cos_sin_cache_len) | |
| def _update_cos_sin_cache( | |
| self, | |
| seqlen: int, | |
| device: str | None = None, | |
| dtype: torch.dtype | None = None, | |
| ) -> None: | |
| # only call this function when seqlen is larger than _max_seqlen | |
| self._max_seqlen = seqlen | |
| # m * theta_i = m * base^(-2i/d) = m * (1 / base^(2i/d)), where i in [1, d/2] | |
| m = torch.arange( | |
| seqlen, | |
| device=device, | |
| dtype=torch.float32, | |
| ) | |
| theta_i = 1.0 / ( | |
| self.rotary_base ** ( | |
| torch.arange( | |
| start=0, | |
| end=self.d_rotary, | |
| step=2, | |
| device=device, | |
| dtype=torch.float32, | |
| ) / self.d_rotary | |
| ) | |
| ) | |
| # torch.outer, since torch.einsum converts from fp32 to fp16 if used with torch.amp | |
| # TODO: does this matter if I'm disabling torch.autocast? | |
| m_theta_i = torch.outer(m, theta_i) | |
| self._cos_cached = torch.cos(m_theta_i).to(dtype) | |
| self._sin_cached = torch.sin(m_theta_i).to(dtype) | |
| # TODO: scale_base caching is labelled as not yet done in Phi2 | |
| """ | |
| if scale_base is not None: | |
| scale = ( | |
| torch.arange( | |
| start=0, | |
| end=self.d_rotary, | |
| step=2, | |
| device=self.device, | |
| dtype=torch.float32, | |
| ) + 0.4 * self.d_rotary | |
| ) / (1.4 * self.d_rotary) | |
| power = ( | |
| torch.arange(seqlen, dtype=scale.dtype, device=scale.device) - seqlen // 2 | |
| ) / scale_base | |
| scale = scale.to(device=power.device) ** rearrange(power, "s -> s 1") | |
| self._cos_cached = (torch.cos(m_theta_i) * scale).to(dtype) | |
| self._sin_cached = (torch.sin(m_theta_i) * scale).to(dtype) | |
| """ | |
| def _apply_rotary_emb_qkv( | |
| self, | |
| x: torch.FloatTensor, # dim: (batch_size, seqlen, Optional[n_qkv], n_heads, d_head) | |
| cos: torch.FloatTensor, # dim: (_max_seqlen, d_rotary) | |
| sin: torch.FloatTensor, # dim: (_max_seqlen, d_rotary) | |
| ) -> torch.FloatTensor: | |
| seqlen = x.shape[1] | |
| x_to_rotate = x[..., :self.d_rotary] | |
| x_to_keep_unrotated = x[..., self.d_rotary:] | |
| x1, x2 = x_to_rotate.chunk(2, dim=-1) # dim: (batch_size, seqlen, Optional[n_qkv], n_heads, d_rotary/2) | |
| broadcast_rearrange = "s d -> s 1 d" if x1.ndim == 4 else "s d -> s 1 1 d" | |
| c, s = rearrange(cos[:seqlen], broadcast_rearrange), rearrange(sin[:seqlen], broadcast_rearrange) | |
| x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]] # make sure rotary embedding is in float32 | |
| x_rotated = cast( | |
| torch.FloatTensor, | |
| torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1).to(x.dtype) | |
| ) | |
| return torch.cat([x_rotated, x_to_keep_unrotated], axis=-1) | |
| def forward( | |
| self, | |
| x: torch.FloatTensor, # dim: (batch_size, seqlen, Optional[n_qkv], n_heads, d_head) | |
| seqlen_offset: int = 0, # each sequence is shifted by this amount - used in inference with KV cache | |
| ) -> torch.FloatTensor: | |
| if ( | |
| not self._max_seqlen | |
| or self._max_seqlen < x.shape[1] + seqlen_offset | |
| or self._cos_cached.device != x.device | |
| or self._cos_cached.dtype != x.dtype | |
| or (self.training and self._cos_cached.is_inference()) | |
| ): | |
| self._update_cos_sin_cache(seqlen=x.shape[1] + seqlen_offset, device=x.device, dtype=x.dtype) | |
| return self._apply_rotary_emb_qkv( | |
| x, | |
| cast(torch.FloatTensor, self._cos_cached[seqlen_offset:]), | |
| cast(torch.FloatTensor, self._sin_cached[seqlen_offset:]), | |
| ) | |
| class SelfAttention(nn.Module): | |
| def __init__( | |
| self, | |
| qk_scale: float | None = None, # will use 1/sqrt(d) if set to None | |
| attention_dropout: float = 0.0, | |
| ) -> None: | |
| super().__init__() | |
| self.qk_scale = qk_scale | |
| self.dropout = nn.Dropout(attention_dropout) | |
| # autocast is manually disabled to avoid `torch.einsum` using float16, which might lead to overflow | |
| def forward( | |
| self, | |
| qkv: torch.FloatTensor, # dim: (batch_size, seqlen, 3, n_heads, d_head) | |
| causal: bool = True, | |
| key_padding_mask: torch.BoolTensor | None = None, | |
| ) -> torch.FloatTensor: | |
| batch_size, seqlen = qkv.shape[0], qkv.shape[1] | |
| q, k, v = qkv.unbind(dim=2) | |
| q = q.to(torch.float32) | |
| k = k.to(torch.float32) | |
| qk_scale = self.qk_scale or 1.0 / math.sqrt(q.shape[-1]) | |
| scores = torch.einsum("bthd,bshd->bhts", q, k * qk_scale) | |
| if key_padding_mask: | |
| padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device) | |
| padding_mask.masked_fill_(key_padding_mask, 0.0) | |
| scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") | |
| if causal: | |
| causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) | |
| scores = scores + causal_mask.to(dtype=scores.dtype) | |
| attention = torch.softmax(scores, dim=-1).to(v.dtype) | |
| attention = self.dropout(attention) | |
| output = torch.einsum("bhts,bshd->bthd", attention, v) # dim: (batch_size, seqlen, n_heads, d_head) | |
| return cast(torch.FloatTensor, output) | |
| class CrossAttention(nn.Module): | |
| def __init__( | |
| self, | |
| qk_scale: float | None = None, # will use 1/sqrt(d) if set to None | |
| attention_dropout: float = 0.0, | |
| ) -> None: | |
| super().__init__() | |
| self.qk_scale = qk_scale | |
| self.dropout = nn.Dropout(attention_dropout) | |
| # autocast is manually disabled to avoid `torch.einsum` using float16, which might lead to overflow | |
| def forward( | |
| self, | |
| q: torch.FloatTensor, # dim: (batch_size, seqlen_q, n_heads, d_head) | |
| kv: torch.FloatTensor, # dim: (batch_size, seqlen_kv, 2, n_heads, d_head) | |
| causal: bool = True, | |
| key_padding_mask: torch.BoolTensor | None = None, | |
| ) -> torch.FloatTensor: | |
| batch_size, seqlen_q = q.shape[0], q.shape[1] | |
| seqlen_k = kv.shape[1] | |
| if kv.shape[3] != q.shape[2]: # repeat kv n_heads dim to match q n_heads | |
| kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3]) | |
| k, v = kv.unbind(dim=2) | |
| q = cast(torch.FloatTensor, q.to(torch.float32)) | |
| k = k.to(torch.float32) | |
| qk_scale = self.qk_scale or 1.0 / math.sqrt(q.shape[-1]) | |
| scores = torch.einsum("bthd,bshd->bhts", q, k * qk_scale) | |
| if key_padding_mask: | |
| padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device) | |
| padding_mask.masked_fill_(key_padding_mask, 0.0) | |
| scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") | |
| if causal: | |
| rows = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1") | |
| cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long) | |
| causal_mask = cols > rows + seqlen_k - seqlen_q | |
| scores = scores.masked_fill(causal_mask, -10000.0) | |
| attention = torch.softmax(scores, dim=-1).to(v.dtype) | |
| attention = self.dropout(attention) | |
| output = torch.einsum("bhts,bshd->bthd", attention, v) # dim: (batch_size, seqlen_q, n_heads, d_head) | |
| return cast(torch.FloatTensor, output) | |
class MLP(nn.Module):
    """Two-layer feed-forward network: fc1 -> activation -> fc2.

    Args:
        d_embedding: input and output feature size.
        act_fn: key into transformers' ``ACT2FN`` activation registry.
        n_inner: hidden size; ``None`` (the default) keeps the previous hard-coded ``4 * d_embedding``.
    """

    def __init__(
        self,
        d_embedding: int,
        act_fn: str = "gelu_new",
        n_inner: int | None = None,
    ) -> None:
        super().__init__()
        if n_inner is None:
            n_inner = 4 * d_embedding
        self.fc1 = nn.Linear(d_embedding, n_inner)
        self.act = ACT2FN[act_fn]
        self.fc2 = nn.Linear(n_inner, d_embedding)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Map (..., d_embedding) -> (..., d_embedding) through the hidden layer."""
        hidden = self.act(self.fc1(x))
        return self.fc2(hidden)
| class KVCache: | |
| """Options for model to calculate and store context during inference.""" | |
| max_seqlen: int | |
| max_batch_size: int | |
| seqlen_offset: int | |
| batch_size_offset: int | |
| kv_block_map: dict[int, torch.Tensor] = field(default_factory=dict) | |
| lengths_per_sample: torch.Tensor | None = None | |
class MHA(nn.Module):
    """Multi-head attention block.

    ``Wqkv`` produces q, k and v for every head in one projection. Rotary embeddings are applied to q and
    k. Attention then runs over the current sequence alone when no ``kv_cache`` is given, or over the keys
    and values accumulated in ``kv_cache`` otherwise. ``fc_out`` projects the concatenated heads back to
    ``d_embedding``.
    """

    def __init__(
        self,
        d_embedding: int,
        n_attn_heads: int,
        block_n: int,
        initial_cos_sin_cache_len: int,  # length of cache for rotary embedding
        attn_pdrop: float,
        use_flash_rotary: bool,  # use flash rotary embedding if possible
        use_flash_attn: bool,  # use flash attention if possible
        use_fused_dense: bool,  # use fused dense layer if possible
        checkpointing: bool,  # torch.utils.checkpoint
    ) -> None:
        super().__init__()
        # rotary embedding
        rotary_cls = (
            FlashRotaryEmbedding
            if use_flash_rotary and FlashRotaryEmbedding is not None
            else RotaryEmbedding
        )
        # NOTE(review): these keyword arguments match the local RotaryEmbedding; confirm flash_attn's
        # RotaryEmbedding accepts `d_rotary` / `initial_cos_sin_cache_len` before relying on use_flash_rotary.
        self.rotary_emb = rotary_cls(
            # d_rotary=math.ceil((d_embedding // n_attn_heads) / 2),  # d_rotary is half of d_head
            d_rotary=32,  # TODO: figure out why Phi2 uses this (presumably its partial rotary size; confirm)
            initial_cos_sin_cache_len=initial_cos_sin_cache_len,
        )
        # self attention (used when no KV cache is given)
        self_attn_cls = (
            FlashSelfAttention
            if use_flash_attn and FlashSelfAttention is not None
            else SelfAttention
        )
        self.inner_self_attn = self_attn_cls(attention_dropout=attn_pdrop)
        # cross attention (used with a KV cache)
        cross_attn_cls = (
            FlashCrossAttention
            if use_flash_attn and FlashCrossAttention is not None
            else CrossAttention
        )
        self.inner_cross_attn = cross_attn_cls(attention_dropout=attn_pdrop)
        # projections
        self.n_attn_heads = n_attn_heads
        self.d_head = d_embedding // n_attn_heads
        linear_cls = (
            FusedDense
            if use_fused_dense and FusedDense is not None
            else nn.Linear
        )
        self.Wqkv = linear_cls(
            d_embedding,
            self.d_head * (3 * self.n_attn_heads),  # calculating q, k, v for all heads in block simultaneously
        )
        self.fc_out = linear_cls(d_embedding, d_embedding)
        # settings
        self.using_flash_attn = self_attn_cls is FlashSelfAttention
        self.block_n = block_n
        self.checkpointing = checkpointing

    @staticmethod
    def _checkpointed(fn, *args, **kwargs):
        """Call ``fn(*args, **kwargs)`` under activation checkpointing.

        ``use_reentrant=False`` is needed because every call site passes keyword arguments, which the
        reentrant implementation rejects.
        """
        from torch.utils.checkpoint import checkpoint  # explicit submodule import; `import torch` may not load it

        return checkpoint(fn, *args, use_reentrant=False, **kwargs)

    def _forward_self_attn(
        self,
        qkv: torch.FloatTensor,  # dim: (batch_size, seqlen, 3, n_heads, d_head)
        key_padding_mask: torch.BoolTensor | None,
    ) -> torch.FloatTensor:
        """Rotate q/k from position 0 and run causal self-attention over the whole input sequence."""
        qkv = cast(
            torch.FloatTensor,
            torch.cat(
                [
                    self.rotary_emb(qkv[:, :, :2, :, :]),  # q and k, rotated
                    qkv[:, :, 2:, :, :],  # v; `2:` keeps the n_qkv axis so both pieces are 5-D for torch.cat
                ],
                dim=2,
            ),
        )
        if self.using_flash_attn and unpad_input and pad_input:  # flash-attention path
            batch_size, seqlen = qkv.shape[0], qkv.shape[1]
            cu_seqlens, max_seqlen, indices = None, None, None
            # unpad input and retrieve `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
            if key_padding_mask is not None:
                qkv, indices, cu_seqlens, max_seqlen = unpad_input(qkv, key_padding_mask)
            if self.checkpointing:
                attn_output = self._checkpointed(
                    self.inner_self_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
                )
            else:
                attn_output = self.inner_self_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen).to(qkv.device)
            # repad output
            if key_padding_mask is not None:
                return pad_input(attn_output, indices, batch_size, seqlen)
            return attn_output
        if self.checkpointing:
            return self._checkpointed(self.inner_self_attn, qkv, key_padding_mask=key_padding_mask)
        return self.inner_self_attn(qkv, key_padding_mask=key_padding_mask)

    def _update_kv_cache(
        self,
        kv: torch.FloatTensor,  # dim: (batch_size, seqlen, 2, n_heads, d_head)
        kv_cache: KVCache,
        block_n: int,
    ) -> torch.FloatTensor:
        """Write ``kv`` into this block's cache buffer and return every cached entry up to the new end.

        The buffer for ``block_n`` is allocated on first use with dim
        (max_batch_size, max_seqlen, 2, n_heads, d_head). The returned view has dim
        (batch_size, seqlen_offset + seqlen, 2, n_heads, d_head).
        """
        if block_n not in kv_cache.kv_block_map:
            kv_cache.kv_block_map[block_n] = torch.empty(
                kv_cache.max_batch_size,
                kv_cache.max_seqlen,
                2,
                kv.shape[-2],  # n_heads
                kv.shape[-1],  # d_head
                dtype=kv.dtype,
                device=kv.device,
            )
        cache = kv_cache.kv_block_map[block_n]
        batch_start = kv_cache.batch_size_offset
        batch_end = batch_start + kv.shape[0]
        sequence_start = kv_cache.seqlen_offset
        sequence_end = sequence_start + kv.shape[1]
        # Grow the buffer along the sequence axis when this write would run past its end, e.g. when
        # generating beyond max_seqlen. Only the missing positions are appended, so the buffer does not grow
        # on every later step, and the padding uses the buffer's own batch size.
        missing = sequence_end - cache.shape[1]
        if missing > 0:
            cache = torch.cat((cache, cache.new_zeros((cache.shape[0], missing, *cache.shape[2:]))), dim=1)
            kv_cache.kv_block_map[block_n] = cache
        cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
        return cast(torch.FloatTensor, cache[batch_start:batch_end, :sequence_end, ...])

    def _forward_cross_attn(
        self,
        qkv: torch.FloatTensor,  # dim: (batch_size, seqlen, 3, n_heads, d_head)
        kv_cache: KVCache,
        key_padding_mask: torch.BoolTensor | None,
    ) -> torch.FloatTensor:
        """Rotate q/k at the cache offset, append k/v to the cache, and attend from q over the cached k/v."""
        qk = self.rotary_emb(qkv[:, :, :2, :, :], seqlen_offset=kv_cache.seqlen_offset)
        q = qk[:, :, 0, :, :]
        kv = torch.cat([qk[:, :, 1:2, :, :], qkv[:, :, 2:, :, :]], dim=2)  # rotated k and raw v, dim: (b, s, 2, h, d)
        kv = self._update_kv_cache(kv, kv_cache, self.block_n)
        # NOTE(review): a causal mask is applied only when writing at offset 0, so a multi-token step at a
        # later offset attends without one.
        causal = kv_cache.seqlen_offset == 0
        if self.using_flash_attn and unpad_input and pad_input:  # flash-attention path
            batch_size, seqlen_q = q.shape[0], q.shape[1]
            seqlen_k = kv.shape[1]
            cu_seqlens_q = cu_seqlens_k = max_seqlen_q = max_seqlen_k = indices_q = None
            # unpad input and retrieve `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
            if key_padding_mask is not None:
                kv, _, cu_seqlens_k, max_seqlen_k = unpad_input(kv, key_padding_mask)
                if seqlen_q == 1:
                    key_padding_mask = cast(
                        torch.BoolTensor, torch.ones(batch_size, 1, dtype=torch.bool, device=q.device)
                    )
                elif seqlen_q != seqlen_k:
                    key_padding_mask = cast(torch.BoolTensor, key_padding_mask[:, -seqlen_q:])
                q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, key_padding_mask)
            attn_kwargs = dict(
                causal=causal,
                cu_seqlens=cu_seqlens_q,
                max_seqlen=max_seqlen_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_k=max_seqlen_k,
            )
            if self.checkpointing:
                attn_output = self._checkpointed(self.inner_cross_attn, q, kv, **attn_kwargs)
            else:
                attn_output = self.inner_cross_attn(q, kv, **attn_kwargs)
            if key_padding_mask is not None:
                # re-pad to the full query length: indices_q index a (batch_size, seqlen_q) layout
                return pad_input(attn_output, indices_q, batch_size, seqlen_q)
            return attn_output
        if self.checkpointing:
            return self._checkpointed(
                self.inner_cross_attn,
                q,
                kv,
                key_padding_mask=key_padding_mask,
                causal=causal,
            )
        return self.inner_cross_attn(q, kv, key_padding_mask=key_padding_mask, causal=causal)

    def forward(
        self,
        x: torch.FloatTensor,  # dim: (batch_size, seqlen, d_embedding)
        kv_cache: KVCache | None = None,
        key_padding_mask: torch.LongTensor | torch.BoolTensor | None = None,
    ) -> torch.FloatTensor:
        """Return the attention output projected back to dim (batch_size, seqlen, d_embedding)."""
        if key_padding_mask is not None:
            key_padding_mask = cast(torch.BoolTensor, key_padding_mask.bool())  # make sure it's bool and not int
        qkv = self.Wqkv(x)  # dim: (batch_size, seqlen, 3*n_heads*d_head)
        qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.d_head)  # dim: (batch_size, seqlen, 3, n_heads, d_head)
        if kv_cache is None:
            attn_output = self._forward_self_attn(qkv, key_padding_mask)
        else:
            attn_output = self._forward_cross_attn(qkv, kv_cache, key_padding_mask)
        output = rearrange(attn_output, "... h d -> ... (h d)")
        return self.fc_out(output)
class ParallelAttentionBlock(nn.Module):
    """Transformer block whose attention and MLP branches read the same normalized input side by side.

    forward computes ``dropout(attention(LN(x)) + mlp(LN(x))) + x``: one shared LayerNorm and one
    residual connection around the summed branches.
    """

    def __init__(
        self,
        resid_pdrop: float,  # dropout probability applied to the summed branch outputs before the residual add
        layer_norm_epsilon: float,
        d_embedding: int,
        n_attn_heads: int,
        block_n: int,
        initial_cos_sin_cache_len: int,  # length of cache for rotary embedding
        attn_pdrop: float,
        use_flash_rotary: bool = True,  # use flash rotary embedding if possible
        use_flash_attn: bool = True,  # use flash attention if possible
        use_fused_dense: bool = True,  # use fused dense layer if possible
        checkpointing: bool = False,  # torch.utils.checkpoint
    ) -> None:
        super().__init__()
        # submodules are created in a fixed order: it determines parameter-init RNG use and state_dict order
        self.layer_norm = nn.LayerNorm(d_embedding, eps=layer_norm_epsilon)
        self.block_n = block_n
        kernel_options = dict(
            use_flash_rotary=use_flash_rotary,
            use_flash_attn=use_flash_attn,
            use_fused_dense=use_fused_dense,
            checkpointing=checkpointing,
        )
        self.multi_head_attention = MHA(
            d_embedding=d_embedding,
            n_attn_heads=n_attn_heads,
            block_n=block_n,
            initial_cos_sin_cache_len=initial_cos_sin_cache_len,
            attn_pdrop=attn_pdrop,
            **kernel_options,
        )
        self.mlp = MLP(d_embedding)
        self.dropout = nn.Dropout(resid_pdrop)

    def forward(
        self,
        x: torch.FloatTensor,  # dim: (batch_size, seq_len, d_embedding)
        kv_cache: KVCache | None = None,
        key_padding_mask: torch.BoolTensor | None = None,
    ) -> torch.FloatTensor:
        """Return the block output, with the same shape as ``x``."""
        normed = self.layer_norm(x)  # each token (dim: d_embedding) is normalized individually
        attention_branch = self.multi_head_attention(
            normed,
            kv_cache=kv_cache,
            key_padding_mask=key_padding_mask,
        )
        mlp_branch = self.mlp(normed)
        return self.dropout(attention_branch + mlp_branch) + x