@@ -40,6 +40,7 @@ def __init__(
4040 batch : Optional [Any ] = None ,
4141 index_in_batch : Optional [int ] = None ,
4242 invocation_result : Optional [_invocation .InvocationResult ] = None ,
43+ copy : bool = False ,
4344 ):
4445 if layout is None :
4546 layout = ""
@@ -52,6 +53,7 @@ def __init__(
5253 self ._index_in_batch = index_in_batch
5354 self ._invocation_result = None
5455 self ._wraps_external_data = False
56+ copied = False
5557
5658 from . import fn
5759
@@ -77,13 +79,15 @@ def __init__(
7779 self .assign (data )
7880 self ._wraps_external_data = data ._wraps_external_data
7981 else :
80- self .assign (data .to_device (device ).evaluate ())
82+ dev = data .to_device (device ).evaluate ()
83+ if dev is not self :
84+ copied = True
85+ self .assign (dev )
86+ self ._wraps_external_data = not copied
8187 else :
8288 self .assign (fn .cast (data , dtype , device = device ).evaluate ())
83- return
8489 elif isinstance (data , TensorSlice ):
8590 self ._slice = data
86- return
8791 elif hasattr (data , "__dlpack__" ):
8892 self ._backend = TensorCPU (data , layout )
8993 self ._wraps_external_data = True
@@ -99,10 +103,12 @@ def __init__(
99103 layout ,
100104 False ,
101105 )
106+ copied = True
102107 self ._wraps_external_data = False
103108 self ._dtype = dtype
104109 else :
105110 self ._backend = TensorCPU (np .array (data ), layout , False )
111+ copied = True
106112 self ._wraps_external_data = False
107113
108114 if device is not None :
@@ -131,6 +137,9 @@ def __init__(
131137 if _eval_mode .EvalMode .current ().value >= _eval_mode .EvalMode .eager .value :
132138 self .evaluate ()
133139
140+ if copy and self ._backend is not None and not copied :
141+ self .assign (self .to_device (self .device , force_copy = True ).evaluate ())
142+
134143 def _is_external (self ) -> bool :
135144 return self ._wraps_external_data
136145
@@ -150,8 +159,8 @@ def device(self) -> Device:
150159 else :
151160 raise RuntimeError ("Device not set" )
152161
153- def to_device (self , device : Device ) -> "Tensor" :
154- if self .device == device :
162+ def to_device (self , device : Device , force_copy : bool = False ) -> "Tensor" :
163+ if self .device == device and not force_copy :
155164 return self
156165 else :
157166 with device :
@@ -548,3 +557,12 @@ def evaluate(self):
548557 from . import fn
549558
550559 return fn .tensor_subscript (self ._tensor , ** args ).evaluate ()
560+
561+
562+ def tensor (
563+ data : Any ,
564+ dtype : Optional [Any ] = None ,
565+ device : Optional [Device ] = None ,
566+ layout : Optional [str ] = None ,
567+ ):
568+ return Tensor (data , dtype = dtype , device = device , layout = layout , copy = True )
0 commit comments