Text Generation
Transformers
Safetensors
English
echo_hybrid
trl
fft
rnn
ssm
conversational
custom_code
Instructions to use mrs83/Kurtis-EON1-Hybrid-0.7B-v0.1.1 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use mrs83/Kurtis-EON1-Hybrid-0.7B-v0.1.1 with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="mrs83/Kurtis-EON1-Hybrid-0.7B-v0.1.1", trust_remote_code=True)
messages = [
    {"role": "user", "content": "Who are you?"},
]
pipe(messages)

# Load model directly
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("mrs83/Kurtis-EON1-Hybrid-0.7B-v0.1.1", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use mrs83/Kurtis-EON1-Hybrid-0.7B-v0.1.1 with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "mrs83/Kurtis-EON1-Hybrid-0.7B-v0.1.1"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "mrs83/Kurtis-EON1-Hybrid-0.7B-v0.1.1",
        "messages": [
            {
                "role": "user",
                "content": "What is the capital of France?"
            }
        ]
    }'
Use Docker
docker model run hf.co/mrs83/Kurtis-EON1-Hybrid-0.7B-v0.1.1
- SGLang
How to use mrs83/Kurtis-EON1-Hybrid-0.7B-v0.1.1 with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "mrs83/Kurtis-EON1-Hybrid-0.7B-v0.1.1" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "mrs83/Kurtis-EON1-Hybrid-0.7B-v0.1.1",
        "messages": [
            {
                "role": "user",
                "content": "What is the capital of France?"
            }
        ]
    }'
Use Docker images
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "mrs83/Kurtis-EON1-Hybrid-0.7B-v0.1.1" \
        --host 0.0.0.0 \
        --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "mrs83/Kurtis-EON1-Hybrid-0.7B-v0.1.1",
        "messages": [
            {
                "role": "user",
                "content": "What is the capital of France?"
            }
        ]
    }'
- Docker Model Runner
How to use mrs83/Kurtis-EON1-Hybrid-0.7B-v0.1.1 with Docker Model Runner:
docker model run hf.co/mrs83/Kurtis-EON1-Hybrid-0.7B-v0.1.1
| import torch | |
| import triton | |
| import triton.language as tl | |
| # ────────────────────────────────────────────────────────────── | |
| # FORWARD PASS KERNELS | |
| # ────────────────────────────────────────────────────────────── | |
@triton.jit
def fwd_accumulate_kernel(
    a_ptr,
    b_ptr,
    chunk_a_ptr,
    chunk_c_ptr,
    T,
    D,
    stride_a_b,
    stride_a_t,
    stride_a_d,
    stride_b_b,
    stride_b_t,
    stride_b_d,
    BLOCK_SIZE_D: tl.constexpr,
    BLOCK_SIZE_T: tl.constexpr,
):
    """Forward pass 1: summarize each time chunk as an affine map.

    Over one chunk, the recurrence ``c_t = a_t * c_{t-1} + b_t`` maps an
    incoming carry ``c_in`` to ``A * c_in + C``. ``A`` is the product of the
    chunk's ``a`` values. ``C`` is the chunk's final state when it starts from
    zero.

    Each program handles one tile:
    - program id 0 selects the batch row;
    - id 1 selects a block of ``BLOCK_SIZE_D`` features;
    - id 2 selects a chunk of ``BLOCK_SIZE_T`` time steps.

    ``A`` and ``C`` are accumulated in float32. They are written to
    ``chunk_a_ptr`` / ``chunk_c_ptr`` at offset
    ``(b * num_chunks + chunk) * D + d``, i.e. a contiguous
    ``[B, num_chunks, D]`` layout.
    """
    pid_b = tl.program_id(0)
    pid_d = tl.program_id(1)
    pid_t = tl.program_id(2)
    d_offsets = pid_d * BLOCK_SIZE_D + tl.arange(0, BLOCK_SIZE_D)
    d_mask = d_offsets < D

    # First time step covered by this chunk.
    t_start = pid_t * BLOCK_SIZE_T

    # Identity of the affine composition: A = 1, C = 0.
    a_acc = tl.full((BLOCK_SIZE_D,), 1.0, dtype=tl.float32)
    c_acc = tl.zeros((BLOCK_SIZE_D,), dtype=tl.float32)

    a_base = a_ptr + pid_b * stride_a_b + d_offsets * stride_a_d
    b_base = b_ptr + pid_b * stride_b_b + d_offsets * stride_b_d
    for i in range(BLOCK_SIZE_T):
        t = t_start + i
        # The last chunk may be partial; steps at or beyond T are skipped.
        if t < T:
            a = tl.load(a_base + t * stride_a_t, mask=d_mask, other=1.0).to(tl.float32)
            b = tl.load(b_base + t * stride_b_t, mask=d_mask, other=0.0).to(tl.float32)
            # Compose (a_acc, c_acc) with step (a, b): (a * a_acc, a * c_acc + b).
            c_acc = a * c_acc + b
            a_acc = a * a_acc

    # Store the chunk summary into the [B, num_chunks, D] workspace.
    num_chunks = (T + BLOCK_SIZE_T - 1) // BLOCK_SIZE_T
    summary_idx = pid_b * (num_chunks * D) + pid_t * D + d_offsets
    tl.store(chunk_a_ptr + summary_idx, a_acc, mask=d_mask)
    tl.store(chunk_c_ptr + summary_idx, c_acc, mask=d_mask)
@triton.jit
def fwd_global_scan_kernel(
    chunk_a_ptr,
    chunk_c_ptr,
    chunk_carries_ptr,
    c_0_ptr,
    num_chunks,
    D,
    stride_c0_b,
    stride_c0_d,
    HAS_C_0: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    """Forward pass 2: sequential scan over the per-chunk summaries.

    There is one program per (batch row, feature block); program ids 0 and 1
    select them. The program walks the chunks in order.

    For every chunk ``j``, it records the state entering that chunk into
    ``chunk_carries_ptr``. That is the state after the last step of chunk
    ``j - 1``; for ``j == 0`` it is ``c_0``, or zero when ``HAS_C_0`` is
    false. The carry then advances with chunk ``j``'s summary:
    ``carry = A_j * carry + C_j``.

    Summaries and carries use the same ``[B, num_chunks, D]`` offsets as
    ``fwd_accumulate_kernel`` writes.
    """
    pid_b = tl.program_id(0)
    pid_d = tl.program_id(1)
    d_offsets = pid_d * BLOCK_SIZE_D + tl.arange(0, BLOCK_SIZE_D)
    d_mask = d_offsets < D

    # State before the first time step: the caller's c_0, or zeros.
    carry = tl.zeros((BLOCK_SIZE_D,), dtype=tl.float32)
    if HAS_C_0:
        c0_ptrs = c_0_ptr + pid_b * stride_c0_b + d_offsets * stride_c0_d
        carry = tl.load(c0_ptrs, mask=d_mask, other=0.0).to(tl.float32)

    chunk_base = pid_b * (num_chunks * D) + d_offsets
    for j in range(num_chunks):
        # Record the state entering chunk j before folding chunk j in.
        tl.store(chunk_carries_ptr + chunk_base + j * D, carry, mask=d_mask)
        a_sum = tl.load(chunk_a_ptr + chunk_base + j * D, mask=d_mask, other=1.0).to(tl.float32)
        c_sum = tl.load(chunk_c_ptr + chunk_base + j * D, mask=d_mask, other=0.0).to(tl.float32)
        # Advance to the state entering chunk j + 1.
        carry = a_sum * carry + c_sum
@triton.jit
def fwd_combine_kernel(
    a_ptr,
    b_ptr,
    chunk_carries_ptr,
    c_out_ptr,
    T,
    D,
    stride_a_b,
    stride_a_t,
    stride_a_d,
    stride_b_b,
    stride_b_t,
    stride_b_d,
    stride_c_b,
    stride_c_t,
    stride_c_d,
    BLOCK_SIZE_D: tl.constexpr,
    BLOCK_SIZE_T: tl.constexpr,
):
    """Forward pass 3: replay the recurrence inside each chunk.

    Each (batch row, feature block, time chunk) program starts from the carry
    that the global scan stored for its chunk. It then applies
    ``c_t = a_t * c_{t-1} + b_t`` step by step in float32, writing every
    ``c_t`` to ``c_out_ptr``.
    """
    pid_b = tl.program_id(0)
    pid_d = tl.program_id(1)
    pid_t = tl.program_id(2)
    d_offsets = pid_d * BLOCK_SIZE_D + tl.arange(0, BLOCK_SIZE_D)
    d_mask = d_offsets < D
    num_chunks = (T + BLOCK_SIZE_T - 1) // BLOCK_SIZE_T
    t_start = pid_t * BLOCK_SIZE_T

    # State entering this chunk, from the [B, num_chunks, D] carry workspace.
    carry_idx = pid_b * (num_chunks * D) + pid_t * D + d_offsets
    carry = tl.load(chunk_carries_ptr + carry_idx, mask=d_mask, other=0.0).to(tl.float32)

    a_base = a_ptr + pid_b * stride_a_b + d_offsets * stride_a_d
    b_base = b_ptr + pid_b * stride_b_b + d_offsets * stride_b_d
    c_out_base = c_out_ptr + pid_b * stride_c_b + d_offsets * stride_c_d
    for i in range(BLOCK_SIZE_T):
        t = t_start + i
        # The last chunk may be partial; steps at or beyond T are skipped.
        if t < T:
            a = tl.load(a_base + t * stride_a_t, mask=d_mask, other=1.0).to(tl.float32)
            b = tl.load(b_base + t * stride_b_t, mask=d_mask, other=0.0).to(tl.float32)
            carry = a * carry + b
            tl.store(c_out_base + t * stride_c_t, carry, mask=d_mask)
| # ────────────────────────────────────────────────────────────── | |
| # BACKWARD PASS KERNELS | |
| # ────────────────────────────────────────────────────────────── | |
@triton.jit
def bwd_accumulate_kernel(
    a_ptr,
    grad_c_out_ptr,
    chunk_a_prod_ptr,
    chunk_g_sum_ptr,
    T,
    D,
    stride_a_b,
    stride_a_t,
    stride_a_d,
    stride_g_b,
    stride_g_t,
    stride_g_d,
    BLOCK_SIZE_D: tl.constexpr,
    BLOCK_SIZE_T: tl.constexpr,
):
    """Backward pass 1: summarize each chunk of the reverse-time recurrence.

    The total gradient with respect to the state obeys
    ``grad_c_t = g_t + a_{t+1} * grad_c_{t+1}``, where ``g`` is the incoming
    output gradient.

    For a chunk ``[t_start, t_end)`` this unrolls to
    ``grad_c_{t_start} = g_sum + a_prod * grad_c_{t_end}``, where:
    - ``a_prod = a_{t_start+1} * ... * a_{t_end}``;
    - ``g_sum`` is the chunk's own contribution, i.e. the value when
      ``grad_c_{t_end} = 0``.

    ``a_T`` does not exist and is taken as 1.0. This is harmless because it
    only ever multiplies the zero gradient arriving from beyond the sequence.

    Results are written in float32 to the ``[B, num_chunks, D]`` workspaces.
    """
    pid_b = tl.program_id(0)
    pid_d = tl.program_id(1)
    pid_t = tl.program_id(2)
    d_offsets = pid_d * BLOCK_SIZE_D + tl.arange(0, BLOCK_SIZE_D)
    d_mask = d_offsets < D
    t_start = pid_t * BLOCK_SIZE_T
    t_end = tl.minimum(t_start + BLOCK_SIZE_T, T)

    a_prod = tl.full((BLOCK_SIZE_D,), 1.0, dtype=tl.float32)
    g_sum = tl.zeros((BLOCK_SIZE_D,), dtype=tl.float32)
    a_base = a_ptr + pid_b * stride_a_b + d_offsets * stride_a_d
    g_base = grad_c_out_ptr + pid_b * stride_g_b + d_offsets * stride_g_d

    # Walk the chunk backwards, from t_end - 1 down to t_start.
    for i in range(t_end - t_start - 1, -1, -1):
        t = t_start + i
        g = tl.load(g_base + t * stride_g_t, mask=d_mask, other=0.0).to(tl.float32)
        # Multiplier a_{t+1}; at t == T - 1 there is no next step, so use 1.0.
        a_next = tl.full((BLOCK_SIZE_D,), 1.0, dtype=tl.float32)
        if t + 1 < T:
            a_next = tl.load(a_base + (t + 1) * stride_a_t, mask=d_mask, other=1.0).to(tl.float32)
        g_sum = g + a_next * g_sum
        a_prod = a_next * a_prod

    num_chunks = (T + BLOCK_SIZE_T - 1) // BLOCK_SIZE_T
    summary_idx = pid_b * (num_chunks * D) + pid_t * D + d_offsets
    tl.store(chunk_a_prod_ptr + summary_idx, a_prod, mask=d_mask)
    tl.store(chunk_g_sum_ptr + summary_idx, g_sum, mask=d_mask)
@triton.jit
def bwd_global_scan_kernel(
    chunk_a_prod_ptr,
    chunk_g_sum_ptr,
    chunk_grad_carries_ptr,
    num_chunks,
    D,
    BLOCK_SIZE_D: tl.constexpr,
):
    """Backward pass 2: reverse sequential scan over the chunk summaries.

    There is one program per (batch row, feature block). It walks the chunks
    from last to first.

    For chunk ``j`` it stores ``grad_c_{t_end}``: the gradient with respect to
    the state at the first step of chunk ``j + 1``, which is zero for the last
    chunk. It then moves the carry to the start of chunk ``j``:
    ``grad_carry = g_sum_j + a_prod_j * grad_carry``.
    """
    pid_b = tl.program_id(0)
    pid_d = tl.program_id(1)
    d_offsets = pid_d * BLOCK_SIZE_D + tl.arange(0, BLOCK_SIZE_D)
    d_mask = d_offsets < D

    # No gradient flows in from beyond the end of the sequence.
    grad_carry = tl.zeros((BLOCK_SIZE_D,), dtype=tl.float32)
    chunk_base = pid_b * (num_chunks * D) + d_offsets
    for j in range(num_chunks - 1, -1, -1):
        # Record grad_c at the end boundary of chunk j before folding it in.
        tl.store(chunk_grad_carries_ptr + chunk_base + j * D, grad_carry, mask=d_mask)
        a_prod = tl.load(chunk_a_prod_ptr + chunk_base + j * D, mask=d_mask, other=1.0).to(
            tl.float32
        )
        g_sum = tl.load(chunk_g_sum_ptr + chunk_base + j * D, mask=d_mask, other=0.0).to(tl.float32)
        # grad_c at the first step of chunk j (= end boundary of chunk j - 1).
        grad_carry = g_sum + a_prod * grad_carry
@triton.jit
def bwd_combine_kernel(
    a_ptr,
    c_out_ptr,
    c_0_ptr,
    grad_c_out_ptr,
    chunk_grad_carries_ptr,
    grad_a_ptr,
    grad_b_ptr,
    grad_c_0_ptr,
    T,
    D,
    stride_a_b,
    stride_a_t,
    stride_a_d,
    stride_c_b,
    stride_c_t,
    stride_c_d,
    stride_g_b,
    stride_g_t,
    stride_g_d,
    stride_gb_b,
    stride_gb_t,
    stride_gb_d,
    stride_c0_b,
    stride_c0_d,
    HAS_C_0: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
    BLOCK_SIZE_T: tl.constexpr,
):
    """Backward pass 3: replay the reverse recurrence inside each chunk.

    The program starts from ``grad_c_{t_end}`` (stored by the backward global
    scan) and walks the chunk from its last step to its first, computing
    ``grad_c_t = g_t + a_{t+1} * grad_c_{t+1}``. At each step it writes:

    * ``grad_b_t = grad_c_t``;
    * ``grad_a_t = c_{t-1} * grad_c_t``, where ``c_{-1}`` is ``c_0`` (or
      zero without ``HAS_C_0``).

    The first chunk also writes ``a_0 * grad_c_{t=0}``, the gradient for the
    initial state, when ``HAS_C_0`` is set.

    ``grad_a`` is addressed with ``a``'s strides, so it must share ``a``'s
    memory layout. ``grad_c_0`` is addressed with ``c_0``'s strides.
    """
    pid_b = tl.program_id(0)
    pid_d = tl.program_id(1)
    pid_t = tl.program_id(2)
    d_offsets = pid_d * BLOCK_SIZE_D + tl.arange(0, BLOCK_SIZE_D)
    d_mask = d_offsets < D
    num_chunks = (T + BLOCK_SIZE_T - 1) // BLOCK_SIZE_T
    t_start = pid_t * BLOCK_SIZE_T
    t_end = tl.minimum(t_start + BLOCK_SIZE_T, T)

    # grad_c_{t_end}: gradient w.r.t. the state at the first step of the next
    # chunk (zero for the last chunk).
    grad_at_tend = tl.load(
        chunk_grad_carries_ptr + pid_b * (num_chunks * D) + pid_t * D + d_offsets,
        mask=d_mask,
        other=0.0,
    ).to(tl.float32)

    a_base = a_ptr + pid_b * stride_a_b + d_offsets * stride_a_d
    c_out_base = c_out_ptr + pid_b * stride_c_b + d_offsets * stride_c_d
    g_base = grad_c_out_ptr + pid_b * stride_g_b + d_offsets * stride_g_d
    ga_base = grad_a_ptr + pid_b * stride_a_b + d_offsets * stride_a_d
    gb_base = grad_b_ptr + pid_b * stride_gb_b + d_offsets * stride_gb_d

    # running_grad holds a_{t+1} * grad_c_{t+1} for the step t about to be
    # processed. There is no a_T past the end of the sequence, so 1.0 is used;
    # the carried gradient is zero there anyway.
    a_tend = tl.full((BLOCK_SIZE_D,), 1.0, dtype=tl.float32)
    if t_end < T:
        a_tend = tl.load(a_base + t_end * stride_a_t, mask=d_mask, other=1.0).to(tl.float32)
    running_grad = a_tend * grad_at_tend

    for i in range(t_end - t_start - 1, -1, -1):
        t = t_start + i
        g_out_t = tl.load(g_base + t * stride_g_t, mask=d_mask, other=0.0).to(tl.float32)
        grad_c_t = g_out_t + running_grad

        # d c_t / d b_t = 1.
        tl.store(gb_base + t * stride_gb_t, grad_c_t, mask=d_mask)

        # d c_t / d a_t = c_{t-1}; at t == 0 that is c_0 (or zero).
        c_prev = tl.zeros((BLOCK_SIZE_D,), dtype=tl.float32)
        if t > 0:
            c_prev = tl.load(c_out_base + (t - 1) * stride_c_t, mask=d_mask, other=0.0).to(
                tl.float32
            )
        elif HAS_C_0:
            c_prev = tl.load(
                c_0_ptr + pid_b * stride_c0_b + d_offsets * stride_c0_d, mask=d_mask, other=0.0
            ).to(tl.float32)
        tl.store(ga_base + t * stride_a_t, c_prev * grad_c_t, mask=d_mask)

        # Propagate to step t - 1: a_t * grad_c_t.
        a_t = tl.load(a_base + t * stride_a_t, mask=d_mask, other=1.0).to(tl.float32)
        running_grad = a_t * grad_c_t

    # After the loop, running_grad == a_{t_start} * grad_c_{t_start}. In the
    # first chunk this is the gradient w.r.t. the initial state c_0. The
    # constexpr check is tested first so the branch is compiled out entirely
    # when no c_0 was given.
    if HAS_C_0:
        if pid_t == 0:
            tl.store(
                grad_c_0_ptr + pid_b * stride_c0_b + d_offsets * stride_c0_d,
                running_grad,
                mask=d_mask,
            )
| # ────────────────────────────────────────────────────────────── | |
| # PYTORCH WRAPPER | |
| # ────────────────────────────────────────────────────────────── | |
class DSRNScanTriton(torch.autograd.Function):
    """Chunked parallel scan for ``c_t = a_t * c_{t-1} + b_t`` with autograd.

    ``forward`` and ``backward`` each launch three Triton kernels:

    1. a per-chunk reduction over ``BLOCK_SIZE_T`` time steps;
    2. a sequential scan over the chunk summaries (one program per batch row
       and feature block);
    3. a per-chunk replay that writes every time step.

    Chunk workspaces are allocated as float32.
    """

    @staticmethod
    def forward(ctx, a, b, c_0=None):
        """Compute every state of the recurrence.

        Args:
            ctx: Autograd context; ``a``, the output and ``c_0`` are saved on it.
            a: Multiplicative coefficients of shape ``(B, T, D)``.
            b: Additive inputs; must have exactly the same shape as ``a``.
            c_0: Optional initial state of shape ``(B, D)``; ``None`` means zeros.

        Returns:
            Tensor like ``a`` whose step ``t`` holds ``a_t * c_{t-1} + b_t``.

        Raises:
            ValueError: If ``b`` or ``c_0`` has a shape the kernels cannot index.
        """
        B, T, D = a.shape
        # The kernels index b and c_0 using a's (T, D) extents and c_0's first
        # two strides. A mismatch would read out of bounds or silently use the
        # wrong elements, so reject it up front.
        if b.shape != a.shape:
            raise ValueError(
                f"DSRNScanTriton: expected b of shape {tuple(a.shape)}, got {tuple(b.shape)}"
            )
        if c_0 is not None and tuple(c_0.shape) != (B, D):
            raise ValueError(
                f"DSRNScanTriton: expected c_0 of shape {(B, D)}, got {tuple(c_0.shape)}"
            )
        device = a.device

        a = a.contiguous()
        b = b.contiguous()
        if c_0 is not None:
            c_0 = c_0.contiguous()
        c_out = torch.empty_like(a)

        BLOCK_SIZE_T = 64
        BLOCK_SIZE_D = triton.next_power_of_2(min(128, D))
        num_chunks = (T + BLOCK_SIZE_T - 1) // BLOCK_SIZE_T
        c0_stride_b, c0_stride_d = (c_0.stride(0), c_0.stride(1)) if c_0 is not None else (0, 0)

        # Per-chunk workspaces with layout [B, num_chunks, D].
        chunk_a = torch.empty((B, num_chunks, D), device=device, dtype=torch.float32)
        chunk_c = torch.empty_like(chunk_a)
        chunk_carries = torch.empty_like(chunk_a)

        # Pass 1: reduce each chunk to an affine summary (A, C).
        grid1 = (B, triton.cdiv(D, BLOCK_SIZE_D), num_chunks)
        fwd_accumulate_kernel[grid1](
            a,
            b,
            chunk_a,
            chunk_c,
            T,
            D,
            a.stride(0),
            a.stride(1),
            a.stride(2),
            b.stride(0),
            b.stride(1),
            b.stride(2),
            BLOCK_SIZE_D,
            BLOCK_SIZE_T,
        )

        # Pass 2: sequential scan over chunk summaries -> carry entering each chunk.
        grid2 = (B, triton.cdiv(D, BLOCK_SIZE_D))
        fwd_global_scan_kernel[grid2](
            chunk_a,
            chunk_c,
            chunk_carries,
            c_0,
            num_chunks,
            D,
            c0_stride_b,
            c0_stride_d,
            HAS_C_0=(c_0 is not None),
            BLOCK_SIZE_D=BLOCK_SIZE_D,
        )

        # Pass 3: replay each chunk from its carry and write all states.
        fwd_combine_kernel[grid1](
            a,
            b,
            chunk_carries,
            c_out,
            T,
            D,
            a.stride(0),
            a.stride(1),
            a.stride(2),
            b.stride(0),
            b.stride(1),
            b.stride(2),
            c_out.stride(0),
            c_out.stride(1),
            c_out.stride(2),
            BLOCK_SIZE_D,
            BLOCK_SIZE_T,
        )

        ctx.save_for_backward(a, c_out, c_0)
        ctx.BLOCK_SIZE_T = BLOCK_SIZE_T
        ctx.BLOCK_SIZE_D = BLOCK_SIZE_D
        return c_out

    @staticmethod
    def backward(ctx, grad_c_out):
        """Propagate ``grad_c_out`` back to ``a``, ``b`` and ``c_0``.

        Returns:
            ``(grad_a, grad_b, grad_c_0)``. ``grad_c_0`` is ``None`` when the
            forward pass was called without an initial state.
        """
        a, c_out, c_0 = ctx.saved_tensors
        B, T, D = a.shape
        device = a.device
        grad_c_out = grad_c_out.contiguous()
        # bwd_combine_kernel writes grad_a with a's strides, so grad_a must
        # share a's layout; empty_like of the contiguous a guarantees that.
        grad_a = torch.empty_like(a)
        grad_b = torch.empty_like(a)
        grad_c_0 = torch.zeros_like(c_0) if c_0 is not None else None

        BLOCK_SIZE_T = ctx.BLOCK_SIZE_T
        BLOCK_SIZE_D = ctx.BLOCK_SIZE_D
        num_chunks = (T + BLOCK_SIZE_T - 1) // BLOCK_SIZE_T
        c0_stride_b, c0_stride_d = (c_0.stride(0), c_0.stride(1)) if c_0 is not None else (0, 0)

        # Per-chunk workspaces with layout [B, num_chunks, D]:
        # - chunk_a_prod: products of a across each chunk;
        # - chunk_g_sum: each chunk's own gradient sum;
        # - chunk_grad_carries: the gradient arriving at each chunk's end boundary.
        chunk_a_prod = torch.empty((B, num_chunks, D), device=device, dtype=torch.float32)
        chunk_g_sum = torch.empty_like(chunk_a_prod)
        chunk_grad_carries = torch.empty_like(chunk_a_prod)

        # Pass 1: per-chunk reverse reduction.
        grid1 = (B, triton.cdiv(D, BLOCK_SIZE_D), num_chunks)
        bwd_accumulate_kernel[grid1](
            a,
            grad_c_out,
            chunk_a_prod,
            chunk_g_sum,
            T,
            D,
            a.stride(0),
            a.stride(1),
            a.stride(2),
            grad_c_out.stride(0),
            grad_c_out.stride(1),
            grad_c_out.stride(2),
            BLOCK_SIZE_D,
            BLOCK_SIZE_T,
        )

        # Pass 2: reverse scan over chunk summaries.
        grid2 = (B, triton.cdiv(D, BLOCK_SIZE_D))
        bwd_global_scan_kernel[grid2](
            chunk_a_prod, chunk_g_sum, chunk_grad_carries, num_chunks, D, BLOCK_SIZE_D
        )

        # Pass 3: replay each chunk backwards and write grad_a / grad_b / grad_c_0.
        bwd_combine_kernel[grid1](
            a,
            c_out,
            c_0,
            grad_c_out,
            chunk_grad_carries,
            grad_a,
            grad_b,
            grad_c_0,
            T,
            D,
            a.stride(0),
            a.stride(1),
            a.stride(2),
            c_out.stride(0),
            c_out.stride(1),
            c_out.stride(2),
            grad_c_out.stride(0),
            grad_c_out.stride(1),
            grad_c_out.stride(2),
            grad_b.stride(0),
            grad_b.stride(1),
            grad_b.stride(2),
            c0_stride_b,
            c0_stride_d,
            HAS_C_0=(c_0 is not None),
            BLOCK_SIZE_D=BLOCK_SIZE_D,
            BLOCK_SIZE_T=BLOCK_SIZE_T,
        )
        return grad_a, grad_b, grad_c_0
def triton_dsrn_parallel_scan(g_t, m_t, c_0=None):
    """Run the gated recurrence ``c_t = (1 - g_t) * c_{t-1} + g_t * m_t``.

    The coefficients are formed in float32 and scanned with
    :class:`DSRNScanTriton`. The result is cast back to ``g_t``'s dtype.

    Args:
        g_t: Gate tensor. Its shape defines the scan shape; DSRNScanTriton
            unpacks it as ``(B, T, D)``.
        m_t: Values to blend in; multiplied elementwise with ``g_t``.
        c_0: Optional initial state, or ``None`` for zeros.

    Returns:
        The sequence of states, in ``g_t``'s original dtype.
    """
    orig_dtype = g_t.dtype
    # Upcast before forming the coefficients. Computing ``1 - g`` in a
    # low-precision dtype (e.g. bfloat16) rounds the decay factor, and that
    # rounding error compounds multiplicatively along the sequence.
    g32 = g_t.float()
    a = 1.0 - g32
    # g32 is float32, so this product is promoted to at least float32.
    # .float() normalizes any wider result back to float32.
    b = (g32 * m_t).float()
    if c_0 is not None:
        c_0 = c_0.float()
    out = DSRNScanTriton.apply(a, b, c_0)
    return out.to(orig_dtype)