
utils.get_device() should include 'mps' #136

@Zihann73

https://pytorch.org/docs/master/tensor_attributes.html#torch-device
Currently this function only returns 'cpu' or 'cuda'. Because of this limitation, I was unable to run some transformer-based models on macOS.
I worked around this by calling model.to('mps') manually.
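Here is a minimal sketch of how get_device() could also report 'mps'. It assumes the function simply returns a device string. The actual utils.get_device() implementation in this repository is not shown in the issue, and the helper name pick_device is hypothetical. The CUDA-then-MPS-then-CPU order is an assumed preference. torch.backends.mps.is_available() exists in PyTorch 1.12 and later, so the lookup is guarded for older versions.

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Hypothetical helper: choose a device string from availability flags.

    Preference order (an assumption): CUDA, then Apple MPS, then CPU.
    """
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def get_device() -> str:
    """Sketch of a get_device() that also considers the 'mps' backend."""
    import torch  # imported lazily so this module loads even without torch

    # torch.backends.mps is absent before PyTorch 1.12, hence the getattr guard.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_available = mps_backend is not None and mps_backend.is_available()
    return pick_device(torch.cuda.is_available(), mps_available)
```

With this change, a call such as model.to(get_device()) would select 'mps' on Apple Silicon Macs without hard-coding the device.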