
    9jI                     J   d dl mZ d dlZd dlmc mZ d dlmZ d dlm	Z	 d dl
mZ d dlmZ d dlmZmZ d dlmZ d d	lmZmZmZ d d
lmZmZ d ZddZdej8                  fdZdej8                  fdZd Zd Z d Z!d Z" e!       Z# e"       Z$d Z% e%       Z&d Z' e'       Z(d Z)y)    )castN)_prims)DispatchKey)autograd_not_implemented)HigherOrderOperator)CUDARngStateHelpermake_contiguous_strides_for)FakeTensorMode)disable_proxy_modes_tracingProxyTorchDispatchModetrack_tensor_tree)_device_dtypec           	      h    t        d| j                   d| j                   d| j                   d      )Nz"You are trying to functionalize a z RNG operator but zE does not use Philox/counter-based RNG. Therefore, functionalizing a zo RNG operator is not supported. We are discussing the possibility of a Philox-based RNG implementation for CPU.)RuntimeErrortype)devices    V/media/conek/DATA/Code/OCR/venv/lib/python3.12/site-packages/torch/_prims/rng_prims.pythrow_on_non_cudar      sB    

,V[[M9KFKK= YFFLkk] Si	i     c                    t         j                  j                  d| z   |d|      }|j                  |       t	        t         j
                  j                  j                  |       }|j                  }|r||_	        ||fD ]J  }	||	_
        t         j                  j                  j                  |	_        | |z   |	_        ||	_        ||	_        L y )Nz
rngprims:: )mutates_argsschema)torchlibrary	custom_opregister_fakegetattr_opsopsrngprimsdefault_tags__doc___prims_commonRETURN_TYPENEWreturn_typer   	impl_atenprim_meta_impl)
namer   r*   	impl_metadoctagsrngprim_defprim_packetprimps
             r   register_rng_primr4      s    --))tYR * K i(%**..1148KD
4  %	++77;;&=$%r   shapec                 r    t        j                  t        j                  dt        j                              S )Nr   dtype)r   
TensorLiker   tensorint64)r5   s    r   philox_rand_offset_metar<   3   s$     U\\!5;;?@@r   c                    d}| D ]  }||z  }	 t        j                  |t         j                        }d}d}d}t         j                  j	                  t         j                  j                               }|j                  |z  }t        t        |      }	|	|z   dz
  |z  }
t        |
|j                  |z        }
|	dz
  ||
z  |z  z  dz   |z  S )N   r7         )r   scalar_tensorr;   cudaget_device_propertiescurrent_devicemax_threads_per_multi_processorr   intminmulti_processor_count)r5   numel_scalardim_sizenumel
block_sizeunrollcurand4_engine_callsdevice_propertyblocks_per_smnum	grid_sizes              r   philox_rand_offsetrS   9   s     L ! !EKK@EJFjj66uzz7P7P7RSO#CCzQM
sE
Cz!A%*4IIDD}TUI1W*y069:Q>BVVVr   c                     d} d}dt         j                  dt         j                  dt         j                  dt        t        df   d z  dt
        d	t        fd
}dt         j                  dt         j                  dt         j                  dt        t        df   d z  dt
        d	t        fd}t        | |||dt         j                  j                  f       y )Nphilox_randz{(SymInt[] size, Tensor seed, Tensor offset, int[]? stride, Device? device=None, ScalarType? dtype=None) -> (Tensor, Tensor)r5   seedoffsetstride.r   r8   c                     |t        d|       t        |       }t        j                  | |||      }t	        |       }||fS )Nstride must be None, got )r5   stridesr8   r   )AssertionErrorr	   r   
TensorMetar<   )r5   rV   rW   rX   r   r8   random_valuess          r   _philox_rand_metaz/register_philox_rand.<locals>._philox_rand_metaS   sW      #<VH!EFF,U3))uV
 )/v&&r   c                 ^   |t        d|       |j                  dk(  rg }n|g}|j                  dk7  rt        |      t        j                  j                  |      5  t        j                  ||       t        j                  | ||      }d d d        t        |       fS # 1 sw Y   xY w)NrZ   cpurB   )r   r8   )
r\   r   r   r   randomfork_rngr   set_torch_state_tensorrandrS   )r5   rV   rW   rX   r   r8   devicesr^   s           r   _philox_randz*register_philox_rand.<locals>._philox_rande   s      #<VH!EFF;;%GhG;;& #F++\\""7+ 	J55dFC!JJuV5IM	J 0777		J 	Js   /B##B,z$Philox based stateless rand operator)r,   r   r*   r-   r.   r/   )
r   SizeTensortuplerF   r   r   r4   Tagnondeterministic_seeded)r,   r   r_   rg   s       r   register_philox_randrm   O   s    D KF'zz'll' ' c3h$&	'
 ' '$8zz8ll8 8 c3h$&	8
 8 82 #2ii//1r   c                    |j                  d      rB|j                  d      }t        |t              rt        j                  |      }|j
                  S | D ch c]3  }t        |t        j                        s|j                  j
                  5 }}t        d |D              ryt        d |D              ryt        d |D              ryt        d |D              ry	y c c}w )
Nr   c              3   &   K   | ]	  }|d k(    yw)rB   Nr   .0devs     r   	<genexpr>zget_device.<locals>.<genexpr>   s     
,S3&=
,   rB   c              3   &   K   | ]	  }|d k(    yw)xpuNr   rp   s     r   rs   zget_device.<locals>.<genexpr>        -cSE\-rt   rv   c              3   &   K   | ]	  }|d k(    yw)hpuNr   rp   s     r   rs   zget_device.<locals>.<genexpr>   rw   rt   ry   c              3   &   K   | ]	  }|d k(    yw)ra   Nr   rp   s     r   rs   zget_device.<locals>.<genexpr>   rw   rt   ra   )get
isinstancestrr   r   r   ri   any)argskwargsr   argrf   s        r   
get_devicer      s    zz(H%fc"\\&)F{{*.P3*S%,,2OszzPGP

,G
,,	-W-	-	-W-	-	-W-	- Qs   C8Cc                      G d dt               }  |         j                  t        j                        t	        d             j                  t        j
                        d        j                  t        j                        d        j                  t        j                        d        j                  t        j                        d        j                  t        j                        fd	       j                  t              fd
       }j                  t              fd       }S )Nc                   (     e Zd Z fdZ fdZ xZS )>register_run_and_save_rng_state_op.<locals>.RunAndSaveRngStatec                 (    t         |   dd       y )Nrun_and_save_rng_stateT	cacheablesuper__init__self	__class__s    r   r   zGregister_run_and_save_rng_state_op.<locals>.RunAndSaveRngState.__init__   s    G5Fr   c                 *    t        |   |g|i |S Nr   __call__)r   opr   r   r   s       r   r   zGregister_run_and_save_rng_state_op.<locals>.RunAndSaveRngState.__call__   s    7#B8888r   __name__
__module____qualname__r   r   __classcell__r   s   @r   RunAndSaveRngStater      s    	G	9 	9r   r   Tdeferred_errorc                 N    t         j                  j                          | |i |fS r   )r   rB   get_rng_stater   r   r   s      r   	impl_cudaz5register_run_and_save_rng_state_op.<locals>.impl_cuda   s$    zz'')2t+>v+>>>r   c                 :    t        j                          | |i |fS r   )r   r   r   s      r   impl_cpuz4register_run_and_save_rng_state_op.<locals>.impl_cpu   s     ""$b$&9&&999r   c                     t        t        d      r&t        j                  j                          | |i |fS t	        d      Nry   z2functionalize a hpu RNG operator is not supported.)hasattrr   ry   r   r   r   s      r   impl_hpuz4register_run_and_save_rng_state_op.<locals>.impl_hpu   s:    5% 99**,b$.A&.AAAOPPr   c                 N    t         j                  j                          | |i |fS r   )r   rv   r   r   s      r   impl_xpuz4register_run_and_save_rng_state_op.<locals>.impl_xpu   s$    yy&&("d*=f*===r   c                 p    	d}t        ||      }||vrt        d|       ||   } || g|i |S N)rB   ra   ry   rv   zBackend not supported for r   r\   )
r   r   r   impl_mapr   implr   r   r   r   s
         r   impl_backend_selectz?register_run_and_save_rng_state_op.<locals>.impl_backend_select   s`     	
 D&)! #=fX!FGGB((((r   c                 L    | 5   |g|i |cd d d        S # 1 sw Y   y xY wr   r   )moder   r   r   r   s       r   impl_fake_tensor_modezAregister_run_and_save_rng_state_op.<locals>.impl_fake_tensor_mode   s/      	<&r;D;F;	< 	< 	<s   #c                 :    |g|i |}t        j                  | j                  j                  |g|      }t        j                  | j                  j                  |      }| j                  j	                  d	||      }t        ||d | j                        S Ncall_functionconstanttracer)pytreetree_mapr   unwrap_proxycreate_proxyr   )
r   r   r   r   out
proxy_argsproxy_kwargs	out_proxyr   r   s
           r   impl_proxy_dispatch_modezDregister_run_and_save_rng_state_op.<locals>.impl_proxy_dispatch_mode   s    !"6t6v6__T[[%=%={T{K
t{{'?'?HKK,,3Z
	 !i$t{{SSr   )r   py_implr   Autogradr   CUDACPUHPUXPUBackendSelectr
   r   )	r   r   r   r   r   r   r   r   r   s	      @@@@@@r   "register_run_and_save_rng_state_opr      s@   90 9 018"";#7#78 !7M ##K$4$45? 6? ##KOO4: 5: ##KOO4Q 5Q
 ##KOO4> 5> ##K$=$=>) ?) ##N3< 4<
 ##$:;T <T "!r   c                    	  G d dt               }  |        	 	j                  t        j                        t	        	d             	j                  t        j
                        d        	j                  t        j                        d        	j                  t        j                        d        	j                  t        j                        d        	j                  t              	fd	       }	j                  t        j                        fd
       }	j                  t              d        }	j                  	fd       }	S )Nc                   (     e Zd Z fdZ fdZ xZS )7register_run_with_rng_state_op.<locals>.RunWithRngStatec                 (    t         |   dd       y )Nrun_with_rng_stateTr   r   r   s    r   r   z@register_run_with_rng_state_op.<locals>.RunWithRngState.__init__       G1TBr   c                 ,    t        |   ||g|i |S r   r   )r   	rng_stater   r   r   r   s        r   r   z@register_run_with_rng_state_op.<locals>.RunWithRngState.__call__   s    7#IrCDCFCCr   r   r   s   @r   RunWithRngStater      s    	C	D 	Dr   r   Tr   c                     t         j                  j                         }t         j                  j                  | j	                                 ||i |}t         j                  j                  |       |S r   )r   rB   r   set_rng_statera   r   r   r   r   current_stater   s         r   r   z1register_run_with_rng_state_op.<locals>.impl_cuda   sR    

002

  1$!&!

  /
r   c                     t        j                         }t        j                  |         ||i |}t        j                  |       |S r   )r   r   r   r   s         r   r   z0register_run_with_rng_state_op.<locals>.impl_cpu   s@    ++-I&$!&!M*
r   c                    t        t        d      rft        j                  j                         }t        j                  j	                  |         ||i |}t        j                  j	                  |       |S t        d      r   )r   r   ry   r   r   r   r   s         r   r   z0register_run_with_rng_state_op.<locals>.impl_hpu   sb    5% !II335MII##I.d%f%CII##M2JOPPr   c                     t         j                  j                         }t         j                  j                  |         ||i |}t         j                  j                  |       |S r   )r   rv   r   r   r   s         r   r   z0register_run_with_rng_state_op.<locals>.impl_xpu  sL    		//1			*$!&!		.
r   c                 |   t               5   	||g|i |}d d d        t        j                  | j                  j                  ||g|      }t        j                  | j                  j                  |      }| j                  j                  d	||      }t        |d | j                        S # 1 sw Y   xY wr   r   r   r   r   r   r   r   )
r   r   r   r   r   r   r   r   r   r   s
            r   r   z@register_run_with_rng_state_op.<locals>.impl_proxy_dispatch_mode  s     )* 	E$YDTDVDC	E__T[[%=%=	2?UPT?UV
t{{'?'?HKK,,/\
	 !i$t{{SS	E 	Es   B22B;c                 r    	
d}t        ||      }||vrt        d|       ||   } || |g|i |S r   r   )r   r   r   r   r   r   r   r   r   r   r   s          r   r   z;register_run_with_rng_state_op.<locals>.impl_backend_select  sb     	
 D&)! #=fX!FGGIr3D3F33r   c                 B    | 5   ||i |cd d d        S # 1 sw Y   y xY wr   r   )r   r   r   r   r   s        r   r   z=register_run_with_rng_state_op.<locals>.impl_fake_tensor_mode)  )      	't&v&	' 	' 	'   c                     | j                  |      }| j                  |      }| j                  |      }| j                         5   	||g|i |}| j                  |      cd d d        S # 1 sw Y   y xY wr   unwrap_tensorsredispatch_to_nextwrap_tensors)
ctxr   r   r   r   unwrapped_rng_stateunwrapped_argsunwrapped_kwargsr   r   s
            r   impl_functionalz7register_run_with_rng_state_op.<locals>.impl_functional0  s    !00;++D1--f5##% 	)$#R*8<LC ##C(		) 	) 	)s   A--A6)r   r   r   r   r   r   r   r   r   r   r   r
   py_functionalize_impl)
r   r   r   r   r   r   r   r   r   r   s
        @@@@@r   register_run_with_rng_state_opr      s\   D- D )*4{334 !3DI  0 01 2 0 1 0Q 1Q 0 1  67
T 8
T  9 9:4 ;4 /' 0' --	) .	) r   c                      G d dt               }  |         j                  t        j                        t	        d             j                  t        j
                        d dd       j                  t        j                        d dfd
       }j                  t              d dd       }j                  t              d dfd	
       }j                  d dfd

       }S )Nc                   .     e Zd Z fdZdd fd
Z xZS )Jregister_graphsafe_run_with_rng_state_op.<locals>.GraphSafeRunWithRngStatec                 $    t         |   d       y )Ngraphsafe_run_with_rng_stater   r   s    r   r   zSregister_graphsafe_run_with_rng_state_op.<locals>.GraphSafeRunWithRngState.__init__E  s    G;<r   Nr   c                .    t        |   |g|d|i|S Nr   r   )r   r   r   r   r   r   s        r   r   zSregister_graphsafe_run_with_rng_state_op.<locals>.GraphSafeRunWithRngState.__call__H  s"    7#BMMMfMMr   r   r   s   @r   GraphSafeRunWithRngStater   D  s    	= 15 	N 	Nr   r   Tr   r   c                    |j                   j                  }t        j                  j                  |   }|j                         }|j                  |        | |i |}|j                  |       |S r   )r   indexr   rB   default_generatorsgraphsafe_get_stategraphsafe_set_state)r   r   r   r   
device_idx	generatorr   r   s           r   r   z;register_graphsafe_run_with_rng_state_op.<locals>.impl_cudaR  sg     %%++
JJ11*=	!557%%i0$!&!%%m4
r   c                ^    t        ||      }|dk7  rt        d|        | g|d|i|S )NrB   z6GraphSafe RNG operations only supported for CUDA, got r   r   )r   r   r   r   r   r   s        r   r   zEregister_graphsafe_run_with_rng_state_op.<locals>.impl_backend_select^  sJ    D&)V HQ  BdBiB6BBr   c                B    | 5   ||i |cd d d        S # 1 sw Y   y xY wr   r   )r   r   r   r   r   s        r   r   zGregister_graphsafe_run_with_rng_state_op.<locals>.impl_fake_tensor_modeg  s'     	't&v&	' 	' 	'r   c                   t               5   	|g|d|i|}d d d        t        j                  | j                  j                  |g|      }t        j                  | j                  j                  d|i|      }| j                  j                  d	||      }t        |d | j                        S # 1 sw Y   xY w)Nr   r   r   r   )
r   r   r   r   r   r   r   r   r   r   s
            r   r   zJregister_graphsafe_run_with_rng_state_op.<locals>.impl_proxy_dispatch_model  s    (* 	Y.rXDXIXQWXC	Y__T[[%=%={T{K
KK$${I&H&H
 KK,,9:|
	 !i$t{{SS	Y 	Ys   B66B?c                    || j                  |      nd }| j                  |      }| j                  |      }| j                         5   	|g|d|i|}| j                  |      cd d d        S # 1 sw Y   y xY wr   r   )
r   r   r   r   r   r   r   r   r   r   s
            r   r   zAregister_graphsafe_run_with_rng_state_op.<locals>.impl_functionaly  s     .7-BCy) 	 ++D1--f5##% 	).#/BFVC ##C(		) 	) 	)s   	A22A;
r   r   r   r   r   r   r   r
   r   r   )r   r   r   r   r   r   r   s        @@r   (register_graphsafe_run_with_rng_state_opr  C  s   N#6 N $<#= > (()=)=> !=dS "))+*:*:;'+ 	 <	 "))+*C*CD15 C EC ")).99= ' :' "))*@A<@ 
T B
T "7726 ) 8) ('r   c                      G d dt               }  |         j                  t        j                        t	        d             j                  t        j
                        d        j                  t        j                        fd       }j                  t              d        }j                  t              fd       }j                  fd	       }S )
u  
    Register a higher-order operator for DTensor distributed random operations.
    Takes pre-computed integer offsets (start_offset_incr, end_offset_incr), an op,
    and args. Internally adjusts the RNG state using the offsets before/after
    running the op.

    The offsets are computed at trace time from the DTensorSpec, so the compiled graph
    contains only plain integers — no DTensorSpec or DeviceMesh objects.
    c                   (     e Zd Z fdZ fdZ xZS )4register_run_dtensor_rng_op.<locals>.RunDTensorRngOpc                 (    t         |   dd       y )Nrun_dtensor_rng_opTr   r   r   s    r   r   z=register_run_dtensor_rng_op.<locals>.RunDTensorRngOp.__init__  r   r   c                 .    t        |   |||g|i |S r   r   )r   start_offset_incrend_offset_incrr   r   r   r   s         r   r   z=register_run_dtensor_rng_op.<locals>.RunDTensorRngOp.__call__  s,    7#!?B9=AG r   r   r   s   @r   RunDTensorRngOpr    s    	C	 	r   r
  Tr   c                    ddl m}  |t        j                  j	                               }|j
                  j                         }|| z   |_        |rt        |d   d      r|d   j                  n3t        j                  dt        j                  j                                }t        j                  j                  |gd      5  t        j                  j                  |j                         	  ||i |}	||z   |_        	 d d d        t        j                  j                  |j                         	S # ||z   |_        w xY w# 1 sw Y   CxY w)Nr   )_PhiloxStater   zcuda:rB   )rf   device_type) torch.distributed.tensor._randomr  r   rB   r   rW   cloner   r   rD   rb   rc   r   state)
r  r	  r   r   r   r  r  
old_offsetr   r   s
             r   r   z.register_run_dtensor_rng_op.<locals>.impl_cuda  s   AUZZ5578\\'')
!$55 Q2 GNNejj&?&?&A%BCD 	
 \\""F8"H 	<JJ$$U[[1<$)&))O;	< 	

  -
  *O;	< 	<s$   *E0D68
E6EEEc                 `    t        ||      }|dk7  rt        d| d       | ||g|i |S )NrB   z2run_dtensor_rng_op only supports CUDA device, got zF. This operator is designed for distributed random operations on CUDA.)r   r   )r  r	  r   r   r   r   r   s         r   r   z8register_run_dtensor_rng_op.<locals>.impl_backend_select  sT    D&)VDVH MW X  *ORQ$Q&QQr   c                 B    | 5   ||i |cd d d        S # 1 sw Y   y xY wr   r   )r   r  r	  r   r   r   s         r   r   z:register_run_dtensor_rng_op.<locals>.impl_fake_tensor_mode  r   r   c                    t               5   
|||g|i |}d d d        t        j                  | j                  j                  |||g|      }t        j                  | j                  j                  |      }| j                  j                  d
||      }	t        |	d | j                        S # 1 sw Y   xY wr   r   )r   r  r	  r   r   r   r   r   r   r   r  s             r   r   z=register_run_dtensor_rng_op.<locals>.impl_proxy_dispatch_mode  s     )* 	$!?B9=AGC	 __KK$$'8/2&UPT&U

 t{{'?'?HKK,,/\
	 !i$t{{SS	 	s   B44B=c                     | j                  |      }| j                  |      }| j                         5   	|||g|i |}| j                  |      cd d d        S # 1 sw Y   y xY wr   r   )
r   r  r	  r   r   r   r   r   r   r  s
            r   r   z4register_run_dtensor_rng_op.<locals>.impl_functional  s|    ++D1--f5##% 	)$!  	
 #C ##C(	) 	) 	)s   AA&r   )r
  r   r   r   r   r   r  s        @@r   register_run_dtensor_rng_opr    s    -  )*4{334 !3DI  0 01 2,  9 9:R ;R /' 0'  67T 8T  --) .) r   c                      t                y r   )rm   r   r   r   register_rng_primsr    s    r   r   )*typingr   r   torch.utils._pytreeutils_pytreer   r   torch._Cr   torch._higher_order_ops.utilsr   
torch._opsr   torch._prims_commonr   r	   torch._subclasses.fake_tensorr
   "torch.fx.experimental.proxy_tensorr   r   r   torch.typesr   r   r   r4   rh   r<   rS   rm   r   r   r   r   r   r  r   r  r  r  r   r   r   <module>r$     s      $ $    B * O 8 
 (%.A::AW::W,6r&?"D_D <= 35 D(N  HI cL 12 r   